llvm/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <numeric>

usingnamespacemlir;
usingnamespacemlir::tosa;

// Judging by its name and signature, presumably materializes the integer
// attribute named `attrName` on `op` as an `arith.constant` whose type is
// `requiredAttrType`, using `rewriter` to build it; `T` is presumably the
// C++ integer type the attribute value is read as -- TODO confirm against the
// full implementation.
//
// NOTE(review): the body is empty in this view. The return type is non-void,
// so if this template is ever instantiated and called, flowing off the end is
// undefined behavior; presumably the implementation was stripped.
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {}

// Presumably emits the scalar computation for one element of the elementwise
// TOSA op `op`, given the block arguments `args` and the expected element
// `resultTypes`, and returns the computed Value -- TODO confirm; the body is
// not visible here.
//
// NOTE(review): the body is empty in this view. The function returns `Value`,
// so calling it as written is undefined behavior (no return statement).
static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    ConversionPatternRewriter &rewriter) {}

// Presumably returns `tensor` reshaped to rank `rank` (for example by adding
// leading unit dimensions) so that operands of differing rank can be combined
// elementwise -- TODO confirm against the full implementation.
//
// NOTE(review): the body is empty in this view. The function returns `Value`,
// so calling it as written is undefined behavior (no return statement).
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
                        int64_t rank) {}

// Expands every value in `operands` to rank `rank` by delegating to
// `expandRank`, and returns the results in the same order as the inputs.
//
// @param rewriter  builder used to emit any reshaping ops.
// @param loc       location attached to emitted ops.
// @param operands  values to expand.
// @param rank      target rank for every operand.
// @return one expanded value per input operand, in input order.
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
                                           Location loc, ValueRange operands,
                                           int64_t rank) {
  SmallVector<Value> expanded;
  expanded.reserve(operands.size());
  for (Value operand : operands)
    expanded.push_back(expandRank(rewriter, loc, operand, rank));
  return expanded;
}

IndexPool;

// Emit an 'arith.constant' op for the given index if it has not been created
// yet, or return an existing constant. This prevents excessive creation of
// redundant constants, easing readability of emitted code for unit tests.
//
// @param rewriter   builder used to emit the constant on a cache miss.
// @param loc        location attached to a newly created constant.
// @param indexPool  cache of previously emitted index constants; updated on a
//                   miss.
// @param index      integer value of the requested index constant.
// @return the (possibly cached) index-typed constant Value.
//
// The cached Value is returned as-is regardless of the current insertion
// point, so the first use must dominate later ones. Keep a pool scoped to a
// single rewrite.
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
  return it->second;
}

// Emits a `tensor.dim` op that queries dimension `index` of `tensor` at
// runtime. The dimension operand is taken from `indexPool` via `createIndex`,
// so repeated queries share the same index constant.
//
// @return the index-typed SSA value holding the size of dimension `index`.
static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  Value indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
}

// Returns the size of dimension `index` of `tensor` as an OpFoldResult.
// - Static dimensions fold to an index attribute, and no IR is emitted.
// - Dynamic dimensions emit a `tensor.dim` (via `getTensorDim`) and return its
//   Value.
//
// `tensor` must have a ranked shaped type, and `index` must be a valid
// dimension of it.
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}

// Returns true iff every operand and every result of `operation` has a ranked
// tensor type. An operation with no operands and no results trivially
// satisfies this.
static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}

// Compute the runtime dimension size for dimension 'dim' of the output by
// inspecting input 'operands', all of which are expected to have the same rank.
// This function returns a pair {targetSize, masterOperand}.
//
// The runtime size of the output dimension is returned either as a statically
// computed attribute or as a runtime SSA value.
//
// If the target size was inferred directly from one dominating operand, that
// operand is returned in 'masterOperand'. If the target size is inferred from
// multiple operands, 'masterOperand' is a null Value.
//
// NOTE(review): the body is empty in this view. The function returns a
// non-void std::pair, so calling it as written is undefined behavior (no
// return statement). Presumably the implementation was stripped; restore it
// before use.
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {}

// Compute the runtime output size for all dimensions by calling
// `computeTargetSize` once per dimension. Returns a pair
// {targetShape, masterOperands} with one entry per output dimension.
// A null master operand for a dimension means its size was inferred from
// multiple operands rather than from one dominating operand.
//
// `operands` must be non-empty, and all operands must be ranked tensors of the
// same rank.
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty() && "expected at least one operand");
  // All operands share a rank, so the first one fixes the dimension count.
  int64_t rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  targetShape.reserve(rank);
  masterOperands.reserve(rank);
  for (int64_t dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {std::move(targetShape), std::move(masterOperands)};
}

// Presumably broadcasts dimension `dim` of `operand` to `targetSize` when that
// dimension is dynamic, and returns the (possibly new) Value.
// `masterOperand` is the operand, if any, from which `targetSize` was inferred
// (see `computeTargetSize`); a null Value means several operands contributed.
// TODO confirm against the full implementation.
//
// NOTE(review): the body is empty in this view. The function returns `Value`,
// so calling it as written is undefined behavior (no return statement).
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {}

// Threads `operand` through `broadcastDynamicDimension` once per dimension, in
// order, so that each dimension is reconciled with the corresponding entry of
// `targetShape`.
//
// `targetShape` and `masterOperands` must each have exactly one entry per
// dimension of `operand` (a ranked tensor).
// @return the value after all per-dimension broadcasts have been applied.
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert(static_cast<int64_t>(targetShape.size()) == rank &&
         "target shape rank mismatch");
  assert(static_cast<int64_t>(masterOperands.size()) == rank &&
         "master operand count mismatch");
  for (int64_t dim : llvm::seq<int64_t>(0, rank))
    operand = broadcastDynamicDimension(rewriter, loc, indexPool, operand, dim,
                                        targetShape[dim], masterOperands[dim]);
  return operand;
}

// Applies the single-operand `broadcastDynamicDimensions` overload to every
// value in `operands`, and returns the results in input order.
//
// A lone operand is returned unchanged. With only one input, every target size
// was inferred from that operand itself, so there is nothing to broadcast
// against.
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  if (operands.size() == 1)
    return SmallVector<Value>(operands.begin(), operands.end());

  SmallVector<Value> broadcasted;
  broadcasted.reserve(operands.size());
  for (Value operand : operands)
    broadcasted.push_back(broadcastDynamicDimensions(
        rewriter, loc, indexPool, operand, targetShape, masterOperands));
  return broadcasted;
}

// Presumably emits the elementwise computation for `operation` over the
// (already broadcast) `operands`, producing a result of shape `targetShape`,
// and uses `converter` to obtain result types. Returns failure when the op
// cannot be lowered. TODO confirm; the body is not visible here.
//
// NOTE(review): the body is empty in this view. The function returns
// `LogicalResult`, so calling it as written is undefined behavior (no return
// statement).
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {}

// Presumably the shared match-and-rewrite entry point for elementwise TOSA ops:
// it takes the op, its (converted) `operands`, the rewriter and the type
// converter, and returns success or failure of the lowering. TODO confirm
// against the full implementation.
//
// NOTE(review): the body is empty in this view. The function returns
// `LogicalResult`, so calling it as written is undefined behavior (no return
// statement).
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {}

// Returns the constant initial value for a given reduction operation `op`.
// The attribute type varies depending on the required element type
// `elementTy`.
//
// NOTE(review): the body is empty in this view. The function returns
// `TypedAttr`, so calling it as written is undefined behavior (no return
// statement).
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {}

// Creates the body calculation for a reduction: it combines the values in
// `args` for reduction op `op` and returns the combined Value. The operations
// emitted vary depending on the input element type `elementTy`.
//
// NOTE(review): the body is empty in this view. The function returns `Value`,
// so calling it as written is undefined behavior (no return statement).
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified `axis`.
//
// NOTE(review): the body is empty in this view. The function returns
// `LogicalResult`, so calling it as written is undefined behavior (no return
// statement).
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {}

namespace {

// NOTE(review): every pattern class in this namespace has an empty body in this
// view. None declares a constructor or a matchAndRewrite override, so as
// written they cannot be constructed with the usual pattern arguments or
// perform any rewrite. They presumably stand in for the real implementations;
// confirm before relying on them.

// Conversion pattern for TOSA ops of type `SrcOp`; the name suggests it
// handles elementwise (pointwise) ops -- TODO confirm.
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {};

// Rewrite pattern rooted at `tosa::RescaleOp`.
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {};

// Handle the resize case where the input is a 1x1 image. This case can
// entirely avoid emitting extract operations, which are much more difficult
// to optimize away.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {};

// TOSA resize with width or height of 1 may be broadcast to a wider
// dimension. This is done by materializing a new tosa.resize without
// the broadcasting behavior, and an explicit broadcast afterwards.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {};

// Rewrite pattern rooted at `tosa::ResizeOp`; presumably the general lowering
// for resizes not covered by the specialized resize patterns above -- TODO
// confirm.
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {};

// At the codegen level any identity operations should be removed. Any cases
// where identity is load-bearing (e.g. cross device computation) should be
// handled before lowering to codegen.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {};

// Rewrite pattern for TOSA ops of type `SrcOp`. The name suggests reduction
// ops, presumably lowered via `reduceMatchAndRewriteHelper` -- TODO confirm.
template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {};

// Rewrite pattern rooted at `tosa::ReverseOp`.
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {};

// This converter translates a tile operation to a reshape, broadcast, reshape.
// The first reshape minimally expands each tiled dimension to include a
// preceding size-1 dim. This dim is then broadcast to the appropriate
// multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {};

// Tosa argmax lowering represents the ArgMax op as a linalg generic op that
// produces two output buffers. (This comment originally named
// linalg.indexed_generic, which no longer exists in upstream Linalg --
// NOTE(review): confirm the current form, presumably linalg.generic with
// linalg.index.)
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and has the result's integer element type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by the generic op, this buffer is discarded as only the index is
// requested.
//
// The generic op updates both the maximum value and index if the current
// value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {};

// Conversion pattern rooted at `tosa::GatherOp`.
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {};

// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the i8 variant, this
// simplifies to a single gather operation.
class TableConverter : public OpRewritePattern<tosa::TableOp> {};

// Rewrite pattern rooted at `RFFT2dOp`; declared final.
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {};

// Rewrite pattern rooted at `FFT2dOp`; declared final.
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {};

} // namespace

// Public entry point that populates `patterns` with the TOSA-to-Linalg
// lowering patterns. `converter` is presumably passed to the
// OpConversionPattern-based patterns that need type conversion -- TODO
// confirm.
//
// NOTE(review): the body is empty in this view, so as written no patterns are
// added to `patterns`. Presumably the registration list was stripped; restore
// it before use.
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    TypeConverter &converter, RewritePatternSet *patterns) {}