//===- TosaMakeBroadcastable.cpp ------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Insert reshape to binary op's input if needed to match rank // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace tosa { #define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" } // namespace tosa } // namespace mlir usingnamespacemlir; usingnamespacemlir::tosa; namespace { /// Common code to create the reshape op where necessary to make the rank of the /// operations equal. input1 and input2 will be updated when the rank has /// changed. The caller is expected to use these to rewrite the original /// operator with the RESHAPE now in the graph. /// return failure when (1) no reshape needed, or (2) output_type is specified /// and it has different rank LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, RankedTensorType outputType, Value &input1, Value &input2) { … } template <typename OpTy> struct ConvertTosaOp : public OpRewritePattern<OpTy> { … }; // The MulOp has an extra parameter 'shift' not present in other elementwise // binary ops, that necessitates special handling of its builder. 
template <>
struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> { … };

// tosa::ArithmeticRightShiftOp's builder takes an extra 'round' parameter that
// the other elementwise binary ops lack, so it needs its own specialization.
template <>
struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
    : public OpRewritePattern<tosa::ArithmeticRightShiftOp> { … };

// Specialization of the rank-equalizing rewrite for tosa::SelectOp.
// NOTE(review): presumably required because SelectOp does not fit the generic
// elementwise-binary builder form (it is not a plain two-operand op) — confirm
// against the pattern body.
template <>
struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> { … };

} // namespace

namespace {
/// Pass that enables broadcasting by making all input arrays have the same
/// number of dimensions: it inserts RESHAPE operations on the lower-rank
/// operand.
struct TosaMakeBroadcastable
    : public tosa::impl::TosaMakeBroadcastableBase<TosaMakeBroadcastable> { … };
} // namespace

/// Public factory for the TosaMakeBroadcastable pass. The caller takes
/// ownership of the returned pass through the std::unique_ptr.
std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() { … }