#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include <numeric>
usingnamespacemlir;
usingnamespacemlir::tosa;
#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
namespace {
#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
struct TosaInlinerInterface : public DialectInlinerInterface { … };
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface { … };
}
SmallVector<Region *> tosa::WhileOp::getLoopRegions() { … }
void TosaDialect::initialize() { … }
Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) { … }
ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &attr) { … }
void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
Attribute attr) { … }
template <typename T>
static LogicalResult verifyConvOp(T op) { … }
LogicalResult tosa::ConstOp::verify() { … }
LogicalResult tosa::ArgMaxOp::verify() { … }
LogicalResult tosa::AvgPool2dOp::verify() { … }
LogicalResult tosa::ClampOp::verify() { … }
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Type outputType, Value input, Value weight,
Value bias, DenseI64ArrayAttr pad,
DenseI64ArrayAttr stride,
DenseI64ArrayAttr dilation) { … }
static void buildTransConvOpWithQuantInfo(
OpBuilder &builder, OperationState &result, Type outputType, Value input,
Value weight, Value bias, DenseI64ArrayAttr outpad,
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) { … }
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Type outputType, Value input, Value weight,
Value bias) { … }
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
OperationState &result, Type outputType,
Value a, Value b) { … }
static void
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Type outputType, Value input,
DenseArrayAttr kernel, DenseArrayAttr stride,
DenseArrayAttr pad, TypeAttr accType) { … }
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
OperationState &result, Type outputType,
Value input) { … }
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Type outputType, Value input,
Value paddings) { … }
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
OperationState &result,
Type outputType, Value input,
Value paddings,
Value padConst) { … }
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
SmallVector<int64_t> &outShape) { … }
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ArgMaxOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
RFFT2dOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
FFT2dOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ConcatOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { … }
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
FullyConnectedOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult FullyConnectedOp::verify() { … }
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
MatMulOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
PadOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult tosa::PadOp::verify() { … }
static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) { … }
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
SliceOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult tosa::SliceOp::verify() { … }
LogicalResult tosa::TableOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TableOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult tosa::TableOp::verify() { … }
LogicalResult tosa::TileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult tosa::TileOp::verify() { … }
bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { … }
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ReshapeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
llvm::LogicalResult tosa::ReshapeOp::verify() { … }
LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) { … }
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TransposeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult tosa::TransposeOp::verify() { … }
LogicalResult TransposeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { … }
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
GatherOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ResizeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ScatterOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
static LogicalResult ReduceInferReturnTypes(
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
#define COMPATIBLE_RETURN_TYPES …
#define REDUCE_SHAPE_INFER …
REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
#undef COMPATIBLE_RETURN_TYPES
template <typename T>
static LogicalResult verifyReduceOp(T op) { … }
LogicalResult tosa::ReduceAllOp::verify() { … }
LogicalResult tosa::ReduceAnyOp::verify() { … }
LogicalResult tosa::ReduceMaxOp::verify() { … }
LogicalResult tosa::ReduceMinOp::verify() { … }
LogicalResult tosa::ReduceProdOp::verify() { … }
LogicalResult tosa::ReduceSumOp::verify() { … }
static LogicalResult NAryInferReturnTypes(
const ValueShapeRange &operands,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
#define NARY_SHAPE_INFER(OP) …
NARY_SHAPE_INFER(tosa::AbsOp)
NARY_SHAPE_INFER(tosa::AddOp)
NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
NARY_SHAPE_INFER(tosa::BitwiseAndOp)
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
NARY_SHAPE_INFER(tosa::CastOp)
NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
NARY_SHAPE_INFER(tosa::CosOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
NARY_SHAPE_INFER(tosa::GreaterOp)
NARY_SHAPE_INFER(tosa::IdentityOp)
NARY_SHAPE_INFER(tosa::IntDivOp)
NARY_SHAPE_INFER(tosa::LogOp)
NARY_SHAPE_INFER(tosa::LogicalAndOp)
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
NARY_SHAPE_INFER(tosa::LogicalNotOp)
NARY_SHAPE_INFER(tosa::LogicalOrOp)
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SinOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
NARY_SHAPE_INFER(tosa::ErfOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
// Retire the n-ary shape-inference helper macro now that all its expansions
// are done. This matches the #undef of REDUCE_SHAPE_INFER and
// COMPATIBLE_RETURN_TYPES above, and keeps the macro from leaking into the
// generated op definitions included at the end of this file.
#undef NARY_SHAPE_INFER
static LogicalResult poolingInferReturnTypes(
ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
ArrayRef<int64_t> pad,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult Conv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
Conv2DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult Conv2DOp::verify() { … }
LogicalResult Conv3DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
Conv3DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult Conv3DOp::verify() { … }
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
AvgPool2dOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult MaxPool2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
MaxPool2dOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
DepthwiseConv2DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult DepthwiseConv2DOp::verify() { … }
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TransposeConv2DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
IfOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
LogicalResult WhileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
WhileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { … }
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() { … }
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { … }
void IfOp::print(OpAsmPrinter &p) { … }
LogicalResult ReverseOp::verify() { … }
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { … }
static void printInitializationList(OpAsmPrinter &parser,
Block::BlockArgListType blocksArgs,
ValueRange initializers,
StringRef prefix = "") { … }
void WhileOp::print(OpAsmPrinter &parser) { … }
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"