// llvm/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp

//===- TosaMakeBroadcastable.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Inserts a RESHAPE on the lower-rank input(s) of elementwise ops when needed
// so that all operands have the same rank.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

usingnamespacemlir;
usingnamespacemlir::tosa;

namespace {

/// Common code to create the RESHAPE op where necessary to make the ranks of a
/// binary operation's two operands equal. The caller is expected to use the
/// updated operands to rewrite the original operator with the RESHAPE now in
/// the graph.
///
/// \param rewriter   Rewriter used to materialize the reshape.
/// \param loc        Location attached to the newly created op(s).
/// \param outputType Ranked result type of the op being rewritten. It may be
///                   null, in which case the result-rank check is skipped.
/// \param input1     In/out: first operand. It is overwritten with the
///                   rank-equalized value only when this returns success().
/// \param input2     In/out: second operand. Same update rule as input1.
/// \returns failure() when
///          (1) either operand is not a ranked tensor,
///          (2) the ranks already match, so no reshape is needed,
///          (3) outputType is given and its rank differs from the rank the
///              operands would be equalized to, or
///          (4) rank equalization itself fails.
///          On failure, input1 and input2 are left untouched.
LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
                                   RankedTensorType outputType, Value &input1,
                                   Value &input2) {
  auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
  auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());
  if (!input1Ty || !input2Ty)
    return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");

  const int64_t input1Rank = input1Ty.getRank();
  const int64_t input2Rank = input2Ty.getRank();
  if (input1Rank == input2Rank)
    return rewriter.notifyMatchFailure(loc,
                                       "cannot rewrite as its already correct");

  // The lower-rank operand is reshaped up to the higher rank. That means the
  // result-rank check can run before any IR is created, and a failing match
  // never leaves a dangling reshape behind.
  const int64_t equalizedRank = std::max(input1Rank, input2Rank);
  if (outputType && outputType.getRank() != equalizedRank)
    return rewriter.notifyMatchFailure(
        loc, "the reshaped type doesn't agree with the ranked output type");

  // Work on copies so the caller's values are only updated on success.
  Value reshaped1 = input1;
  Value reshaped2 = input2;
  if (failed(EqualizeRanks(rewriter, loc, reshaped1, reshaped2)))
    return rewriter.notifyMatchFailure(loc, "failed to reshape inputs");

  input1 = reshaped1;
  input2 = reshaped2;
  return success();
}

template <typename OpTy>
struct ConvertTosaOp : public OpRewritePattern<OpTy> {};

// The MulOp has an extra parameter 'shift' not present in other elementwise
// binary ops, that necessitates special handling of its builder.
template <>
struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {};

// The ArithmeticRightShiftOp has an extra parameter 'round' not present in
// other elementwise binary ops, that necessitates special handling of its
// builder.
template <>
struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
    : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {};

template <>
struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {};
} // namespace

namespace {
/// Pass that enables broadcast by making all input arrays have the same
/// number of dimensions. Insert RESHAPE operations to lower rank operand
struct TosaMakeBroadcastable
    : public tosa::impl::TosaMakeBroadcastableBase<TosaMakeBroadcastable> {};
} // namespace

std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {}