llvm/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://developer.mlplatform.org/w/tosa/
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

usingnamespacemlir;
usingnamespacemlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// Tosa dialect interface includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"

namespace {
// Generated bytecode support code. It is included inside this anonymous
// namespace, so its definitions get internal linkage in this file.
#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"

//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
// NOTE(review): the struct body is empty in this view. It therefore declares
// no constructors (not even inherited ones) and overrides none of the
// DialectInlinerInterface hooks. Presumably the members were elided -- confirm
// against upstream.
struct TosaInlinerInterface : public DialectInlinerInterface {};

/// This class implements the bytecode interface for the Tosa dialect.
// NOTE(review): the struct body is also empty in this view -- confirm against
// upstream.
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {};

} // namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
// NOTE(review): the body is empty in this view. Because the function returns
// SmallVector<Region *>, calling it would flow off the end without a `return`
// (undefined behavior). Presumably the implementation was elided -- confirm
// against upstream.
SmallVector<Region *> tosa::WhileOp::getLoopRegions() {}

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

// NOTE(review): both definitions below have empty bodies in this view.
// `initialize` is void, so as written it does nothing. `materializeConstant`
// returns `Operation *` and would flow off the end without a `return`
// (undefined behavior) if called. Presumably the implementations were
// elided -- confirm against upstream.
void TosaDialect::initialize() {}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {}

//===----------------------------------------------------------------------===//
// Parsers and printers
//===----------------------------------------------------------------------===//

// Parser/printer pair that, judging by the names and signatures, handles a
// value that may be written either as a type (`TypeAttr`) or as an attribute
// -- TODO confirm against the op assembly formats that use them.
// NOTE(review): both bodies are empty in this view. `parseTypeOrAttr` returns
// ParseResult and would flow off the end without a `return` (undefined
// behavior) if called. `printTypeOrAttr` is void and, as written, prints
// nothing.
ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
                                        Attribute &attr) {}

void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
                                 Attribute attr) {}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

// NOTE(review): every verifier in this group has an empty body in this view.
// Each returns LogicalResult, so calling it would flow off the end without a
// `return` (undefined behavior). The `verifyConvOp` template body only matters
// once it is instantiated. Presumably the implementations were elided --
// confirm against upstream.
template <typename T>
static LogicalResult verifyConvOp(T op) {}

LogicalResult tosa::ConstOp::verify() {}

LogicalResult tosa::ArgMaxOp::verify() {}

LogicalResult tosa::AvgPool2dOp::verify() {}

LogicalResult tosa::ClampOp::verify() {}

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

// NOTE(review): every builder in this group is a void function whose body is
// empty in this view, so as written each one leaves `result` untouched.
// Presumably the implementations were elided -- confirm against upstream.

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bitwidth of the output given the bit width of the input & weight content.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, DenseI64ArrayAttr pad,
                                     DenseI64ArrayAttr stride,
                                     DenseI64ArrayAttr dilation) {}

/// Handles tosa.transpose_conv2d which has outpad and output shape
/// attributes.
static void buildTransConvOpWithQuantInfo(
    OpBuilder &builder, OperationState &result, Type outputType, Value input,
    Value weight, Value bias, DenseI64ArrayAttr outpad,
    DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {}

/// The tosa.fully_connected op has its own builder as it does not have
/// strides/dilation/padding.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {}

/// The tosa.matmul op is also intended to be generated when a
/// fully_connected op must be constructed but the weight is not a constant.
/// In this case, the fully_connected op must be expressed using matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {}

/// Both the tosa.avg_pool2d and unary ops use the same
/// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder
/// because it has additional parameters that the unary ops lack.
static void
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input,
                              DenseArrayAttr kernel, DenseArrayAttr stride,
                              DenseArrayAttr pad, TypeAttr accType) {}

/// This builder is called on single-parameter unary operators that have scale
/// relationship between their input and output, expressed by the
/// UnaryOpQuantizationAttr.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                      OperationState &result, Type outputType,
                                      Value input) {}

/// This builder is called on TOSA pad operator that needs to create its own
/// OptionalAttr quantization_attr parameter to scale the padding values
/// correctly. No pad_const is interpreted as zero-padding.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {}

/// This builder is called on TOSA pad operator when an explicit pad_const
/// value is passed in. It also optionally constructs quantization_attr.
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
                                                 OperationState &result,
                                                 Type outputType, Value input,
                                                 Value paddings,
                                                 Value padConst) {}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//

// NOTE(review): every definition in this group has an empty body in this
// view. All of them have non-void return types (LogicalResult or bool), so
// calling any of them would flow off the end without a `return` (undefined
// behavior). Presumably the implementations were elided -- confirm against
// upstream.
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                           SmallVector<int64_t> &outShape) {}

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

// EqualOp uses the generic (operands/attributes/properties/regions)
// signature rather than the adaptor-based one used by the ops above.
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {}

LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FullyConnectedOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult FullyConnectedOp::verify() {}

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

// NOTE(review): every definition in this group has an empty body in this
// view. All of them have non-void return types, so calling any of them would
// flow off the end without a `return` (undefined behavior). Presumably the
// implementations were elided -- confirm against upstream.
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult tosa::PadOp::verify() {}

static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {}

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult tosa::SliceOp::verify() {}

LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult tosa::TableOp::verify() {}

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult tosa::TileOp::verify() {}

bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {}

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

llvm::LogicalResult tosa::ReshapeOp::verify() {}

// `perms` is a non-const out-parameter; the result is reported through the
// LogicalResult return value.
LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {}

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult tosa::TransposeOp::verify() {}

LogicalResult TransposeOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {}

// NOTE(review): every definition in this group has an empty body in this
// view. Each returns LogicalResult, so calling any of them would flow off the
// end without a `return` (undefined behavior). Presumably the implementations
// were elided -- confirm against upstream.
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

// Shared reduction shape-inference helper: takes the operand shape, its
// element type and the reduction `axis` attribute, and appends to
// `inferredReturnShapes` (non-const out-parameter).
static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

#define COMPATIBLE_RETURN_TYPES

#define REDUCE_SHAPE_INFER

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
#undef COMPATIBLE_RETURN_TYPES

// NOTE(review): every definition in this group has an empty body in this
// view. Each returns LogicalResult, so calling any of them would flow off the
// end without a `return` (undefined behavior). The `verifyReduceOp` template
// body only matters once it is instantiated. Presumably the implementations
// were elided -- confirm against upstream.
template <typename T>
static LogicalResult verifyReduceOp(T op) {}

LogicalResult tosa::ReduceAllOp::verify() {}
LogicalResult tosa::ReduceAnyOp::verify() {}
LogicalResult tosa::ReduceMaxOp::verify() {}
LogicalResult tosa::ReduceMinOp::verify() {}
LogicalResult tosa::ReduceProdOp::verify() {}
LogicalResult tosa::ReduceSumOp::verify() {}

// Shared shape-inference helper over a range of operand shapes. Results are
// appended to `inferredReturnShapes` (non-const out-parameter).
// NOTE(review): the body is empty in this view. Because the function returns
// LogicalResult, calling it would flow off the end without a `return`
// (undefined behavior). Presumably the implementation was elided -- confirm
// against upstream.
static LogicalResult NAryInferReturnTypes(
    const ValueShapeRange &operands,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

#define NARY_SHAPE_INFER(OP)

NARY_SHAPE_INFER(tosa::AbsOp)
NARY_SHAPE_INFER(tosa::AddOp)
NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
NARY_SHAPE_INFER(tosa::BitwiseAndOp)
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
NARY_SHAPE_INFER(tosa::CastOp)
NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
NARY_SHAPE_INFER(tosa::CosOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
NARY_SHAPE_INFER(tosa::GreaterOp)
NARY_SHAPE_INFER(tosa::IdentityOp)
NARY_SHAPE_INFER(tosa::IntDivOp)
NARY_SHAPE_INFER(tosa::LogOp)
NARY_SHAPE_INFER(tosa::LogicalAndOp)
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
NARY_SHAPE_INFER(tosa::LogicalNotOp)
NARY_SHAPE_INFER(tosa::LogicalOrOp)
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SinOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
NARY_SHAPE_INFER(tosa::ErfOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef PRED_SHAPE_INFER

// NOTE(review): every definition in this group has an empty body in this
// view. Each returns LogicalResult, so calling any of them would flow off the
// end without a `return` (undefined behavior). Presumably the implementations
// were elided -- confirm against upstream.

// Shared pooling shape-inference helper. It takes the input shape plus
// kernel/stride/pad arrays and appends to `inferredReturnShapes` (non-const
// out-parameter).
static LogicalResult poolingInferReturnTypes(
    ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
    ArrayRef<int64_t> pad,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult Conv2DOp::verify() {}

LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult Conv3DOp::verify() {}

LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult DepthwiseConv2DOp::verify() {}

LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

// NOTE(review): every definition in this group has an empty body in this
// view. All have non-void return types, so calling any of them would flow off
// the end without a `return` (undefined behavior). Presumably the
// implementations were elided -- confirm against upstream.
LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {}

std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {}

// NOTE(review): every definition in this group has an empty body in this
// view. The `parse` functions (ParseResult) and ReverseOp::verify
// (LogicalResult) would flow off the end without a `return` (undefined
// behavior) if called. The void printers, as written, print nothing.
// Presumably the implementations were elided -- confirm against upstream.

// The IfOp parser and printer follow the implementation in the SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {}

void IfOp::print(OpAsmPrinter &p) {}

LogicalResult ReverseOp::verify() {}

// The WhileOp parser and printer follow the implementation in the SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {}

// Printing helper that takes the block arguments, their initializer values,
// and an optional `prefix` (defaults to the empty string).
static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {}

void WhileOp::print(OpAsmPrinter &parser) {}

//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"