#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include <numeric>
usingnamespacemlir;
usingnamespacemlir::tosa;
// Returns an arith::ConstantOp. The parameters are the source operation, an
// attribute name, the type required for the constant's attribute, and the
// builder used to create it; `T` is a caller-chosen template parameter.
// NOTE(review): the body is collapsed in this view. Presumably the constant's
// value comes from the attribute `attrName` on `op`, read as `T` — confirm
// against the implementation.
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
Type requiredAttrType, OpBuilder &rewriter) { … }
// Returns a single Value for the given operation, computed from `args` and
// `resultTypes` using `rewriter`.
// NOTE(review): the body is collapsed in this view. By its name this emits the
// per-element (scalar) computation placed inside a linalg region for an
// elementwise op; how unsupported ops are reported (e.g. a null Value) is not
// visible here — confirm before relying on it.
static Value createLinalgBodyCalculationForElementwiseOp(
Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
ConversionPatternRewriter &rewriter) { … }
// Takes a tensor Value and a target `rank` and returns a Value.
// NOTE(review): the body is collapsed in this view. Presumably the result is
// `tensor` reshaped so that it has `rank` dimensions; where the extra unit
// dimensions go and what happens when the rank already matches are not
// visible here — confirm.
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
int64_t rank) { … }
// Takes a range of operands and a target `rank` and returns a vector of
// Values.
// NOTE(review): the body is collapsed in this view. Presumably this applies
// expandRank to every operand and returns the results in the same order —
// confirm.
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
Location loc, ValueRange operands,
int64_t rank) { … }
// NOTE(review): `IndexPool` is used by the helpers below as a by-reference
// parameter type, but its declaration is not visible in this view. This line
// looks like the truncated tail of that declaration — confirm against the
// full file before editing.
IndexPool;
// Returns a Value for the integer `index`, with access to a mutable
// `indexPool`.
// NOTE(review): the body is collapsed in this view. The non-const pool
// reference suggests index constants are cached and reused across calls, but
// that is not established here — confirm.
static Value createIndex(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, int64_t index) { … }
// Returns a Value associated with dimension `index` of `tensor`.
// NOTE(review): the body is collapsed in this view. Presumably this is the
// runtime size of that dimension, with `indexPool` used for the dimension
// index constant — confirm.
static Value getTensorDim(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, Value tensor, int64_t index) { … }
// Returns an OpFoldResult for dimension `index` of `tensor`; an OpFoldResult
// can carry either an attribute or a Value.
// NOTE(review): the body is collapsed in this view. Presumably a statically
// known size is returned as an attribute and a dynamic size as a Value —
// confirm.
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, Value tensor,
int64_t index) { … }
// Predicate over a single operation.
// NOTE(review): the body is collapsed in this view. By its name it reports
// whether every operand and result of `operation` has a ranked type; whether
// it also emits a diagnostic on failure is not visible here — confirm.
static bool operandsAndResultsRanked(Operation *operation) { … }
// For dimension `dim` across `operands`, returns a pair of an OpFoldResult
// and a Value.
// NOTE(review): the body is collapsed in this view. Presumably the first
// element is the broadcast target size for that dimension and the second is
// an operand associated with it (later signatures call this a "master
// operand"); when the Value may be null is not visible here — confirm.
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
ValueRange operands, int64_t dim) { … }
// For the given operands, returns a pair of vectors: OpFoldResults and
// Values.
// NOTE(review): the body is collapsed in this view. Presumably these hold one
// computeTargetSize result per dimension (target sizes and their associated
// operands) — confirm.
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, ValueRange operands) { … }
// Takes one operand, a dimension `dim`, a target size and a `masterOperand`,
// and returns a Value.
// NOTE(review): the body is collapsed in this view. Presumably the result is
// `operand` broadcast along `dim` to `targetSize` when its runtime size
// requires it; the role of `masterOperand` and the handling of static
// dimensions are not visible here — confirm.
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, Value operand,
int64_t dim, OpFoldResult targetSize,
Value masterOperand) { … }
// Single-operand overload: takes one operand together with a full
// `targetShape` and a list of `masterOperands`, and returns a Value.
// NOTE(review): the body is collapsed in this view. Presumably it applies
// broadcastDynamicDimension once per dimension of `targetShape` — confirm.
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, Value operand,
ArrayRef<OpFoldResult> targetShape,
ArrayRef<Value> masterOperands) { … }
// Multi-operand overload: takes a range of operands together with a full
// `targetShape` and a list of `masterOperands`, and returns a vector of
// Values.
// NOTE(review): the body is collapsed in this view. Presumably it maps the
// single-operand overload over `operands`, preserving order — confirm.
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, ValueRange operands,
ArrayRef<OpFoldResult> targetShape,
ArrayRef<Value> masterOperands) { … }
// Takes the source operation, its operands, a target shape and a type
// converter, and reports success or failure as a LogicalResult.
// NOTE(review): the body is collapsed in this view. Presumably it builds the
// replacement computation for `operation` over `targetShape` and replaces the
// op through `rewriter` — confirm.
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
Operation *operation, ValueRange operands,
ArrayRef<OpFoldResult> targetShape,
const TypeConverter &converter) { … }
// Shared match-and-rewrite entry point for an elementwise operation; reports
// success or failure as a LogicalResult.
// NOTE(review): the body is collapsed in this view. Presumably it chains the
// helpers declared above (rank check, rank expansion, target-shape
// computation, broadcasting, emission); the exact order and failure
// conditions are not visible here — confirm.
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
ConversionPatternRewriter &rewriter,
const TypeConverter &converter) { … }
// Returns a TypedAttr for the given reduction operation and element type.
// NOTE(review): the body is collapsed in this view. By its name this is the
// initial accumulator value for the reduction; the value chosen per op kind,
// and what is returned for unsupported combinations, are not visible here —
// confirm.
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) { … }
// Returns a single Value for the given reduction operation, computed from
// `args` for element type `elementTy`.
// NOTE(review): the body is collapsed in this view. By its name this emits the
// scalar combine step placed inside a linalg region for a reduce op; the
// layout of `args` and the result for unsupported ops are not visible here —
// confirm.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
ValueRange args,
Type elementTy,
PatternRewriter &rewriter) { … }
// Shared match-and-rewrite entry point for a reduction along `axis`; reports
// success or failure as a LogicalResult.
// NOTE(review): the body is collapsed in this view. Presumably it uses
// createInitialValueForReduceOp and createLinalgBodyCalculationForReduceOp to
// build the replacement; whether `axis` is range-checked is not visible here
// — confirm.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
PatternRewriter &rewriter) { … }
// File-local pattern classes (the anonymous namespace gives them internal
// linkage). Every class body is collapsed in this view, so the notes below
// state only what the class headers show; anything about the emitted IR is an
// assumption to confirm against the implementation.
namespace {
// Conversion pattern templated on the source op type `SrcOp`.
// NOTE(review): presumably delegates to elementwiseMatchAndRewriteHelper —
// confirm.
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> { … };
// Rewrite pattern anchored on tosa::RescaleOp.
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> { … };
// Rewrite pattern anchored on tosa::ResizeOp. It is one of three ResizeOp
// patterns in this namespace; which inputs each one matches is not visible
// here.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> { … };
// Rewrite pattern anchored on tosa::ResizeOp (second of three).
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> { … };
// Rewrite pattern anchored on tosa::ResizeOp (third of three).
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> { … };
// Rewrite pattern templated on the source op type `SrcOp`.
// NOTE(review): by its name it handles identity-like ops — confirm.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> { … };
// Rewrite pattern templated on the source op type `SrcOp`.
// NOTE(review): presumably delegates to reduceMatchAndRewriteHelper —
// confirm.
template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> { … };
// Rewrite pattern anchored on tosa::ReverseOp.
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> { … };
// Conversion pattern anchored on tosa::TileOp.
struct TileConverter : public OpConversionPattern<tosa::TileOp> { … };
// Rewrite pattern anchored on tosa::ArgMaxOp.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> { … };
// Conversion pattern anchored on tosa::GatherOp.
class GatherConverter : public OpConversionPattern<tosa::GatherOp> { … };
// Rewrite pattern anchored on tosa::TableOp.
class TableConverter : public OpRewritePattern<tosa::TableOp> { … };
// Rewrite pattern anchored on RFFT2dOp; declared `final`.
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> { … };
// Rewrite pattern anchored on FFT2dOp; declared `final`.
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> { … };
}
// Public entry point, defined in namespace mlir::tosa. Takes a type converter
// and a pointer to the pattern set to fill in.
// NOTE(review): the body is collapsed in this view. Presumably it registers
// the pattern classes declared in the anonymous namespace above into
// `*patterns`; which op instantiations are added, and whether a null
// `patterns` is tolerated, are not visible here — confirm.
void mlir::tosa::populateTosaToLinalgConversionPatterns(
const TypeConverter &converter, RewritePatternSet *patterns) { … }