//===- TosaReduceTransposes.cpp -------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // ---------- // Motivation: // ---------- // Some legalization pathways introduce redundant tosa.TRANSPOSE // operations that result in avoidable data movement. For example, // PyTorch -> TOSA contains a lot of unnecessary transposes due // to conversions between NCHW and NHWC. // We wish to remove all the ones that we can, since in general // it is possible to remove the overwhelming majority. // ------------------- // High-Level Overview: // ------------------- // The pass works through the transpose operators in the program. It begins at // some transpose operator with an associated permutations tensor. It traverses // upwards through the dependencies of this transpose and verifies that we // encounter only operators with the TosaElementwiseOperator trait and terminate // in either constants, reshapes, or transposes. // We then evaluate whether there are any additional restrictions (the // transposes it terminates in must invert the one we began at, and the reshapes // must be ones in which we can fold the transpose into), and then we hoist the // transpose through the intervening operators, folding it at the constants, // reshapes, and transposes. // Finally, we ensure that we do not need both the transposed form (the form // that had the transpose hoisted through it) and the untransposed form (which // it was prior), by analyzing the usages of those dependent operators of a // given transpose we are attempting to hoist and replace. // If they are such that it would require both forms to be necessary, then we do // not replace the hoisted transpose, causing the new chain to be dead. 
// Otherwise, we do and the old chain (untransposed form) becomes dead. Only one // chain will ever then be live, resulting in no duplication. // We then perform a simple one-pass DCE, so no canonicalization is necessary. // ----------- // Future Work: // ----------- // (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across // hoisted // transposes with different permutation tensors. // (2) Expand the class of foldable upstream ReshapeOp we permit beyond // N -> 1x1x...x1xNx1x...x1x1. // (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond // those that form the identity. // (4) Add support for more instructions besides TosaElementwiseOperator as // the intervening ones (for example, the reduce_* operators). // (5) Support hoisting transposes up to an input parameter. //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/IR/Iterators.h" #include "mlir/IR/Matchers.h" #include "llvm/ADT/TypeSwitch.h" #include <memory> #include <set> #include <stack> namespace mlir { namespace tosa { #define GEN_PASS_DEF_TOSAREDUCETRANSPOSES #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" } // namespace tosa } // namespace mlir using namespace mlir; using namespace mlir::tosa; //===----------------------------------------------------------------------===// // TOSA Reduce Transposes Pass. 
//===----------------------------------------------------------------------===// namespace { struct TosaReduceTransposes final : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> { … }; std::optional<DenseElementsAttr> TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms) { … } // The SetVector should only contain ConstOp, ReshapeOp, TransposeOp // as the sources of the data dependencies, and TosaElementWiseOperator // after that, if the function returns true. bool TosaReduceTransposes::collectFanIn(Operation *op, SetVector<Operation *> &collected) { … } // Assuming that due to the verification of TransposeOp perms arrays are // permutations of 0 - perms.size() - 1. bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1, ArrayRef<int32_t> perms2) { … } // Primary overload for those with TosaElementwiseOperator trait. // The other ones handle the case of the operations that occur at the // roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp). 
// Generic Operation* overload of buildMappedToValue.
//   op:           the operation to build a transposed counterpart for.
//   valuesMap:    map from original values to already-built mapped values.
//   rewriter:     rewriter used to create new IR.
//   hoistedPerms: permutation of the transpose being hoisted.
// Returns an optional Value.
// NOTE(review): presumably std::nullopt means the op could not be converted;
// confirm against the body.
std::optional<Value> TosaReduceTransposes::buildMappedToValue( Operation *op, const DenseMap<Value, Value> &valuesMap, IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { … }

// TransposeOp overload: handles a transpose found at a root of the fan-in.
// Per the file overview, such a transpose must invert the hoisted one so that
// the pair can be folded.
// NOTE(review): presumably returns std::nullopt when that does not hold --
// confirm against the body.
std::optional<Value> TosaReduceTransposes::buildMappedToValue( TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap, IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { … }

// ReshapeOp overload: handles a reshape found at a root of the fan-in.
// Per the file overview, only reshapes that the transpose can be folded into
// are accepted; the "Future Work" notes describe the currently permitted
// class as N -> 1x1x...x1xNx1x...x1x1.
std::optional<Value> TosaReduceTransposes::buildMappedToValue( ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap, IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { … }

// ConstOp overload: handles a constant found at a root of the fan-in, where
// the file overview says the hoisted transpose is folded.
// NOTE(review): presumably relies on transposeDenseAttribute to permute the
// constant data -- confirm against the body.
std::optional<Value> TosaReduceTransposes::buildMappedToValue( ConstOp constOp, const DenseMap<Value, Value> &valuesMap, IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { … }

// Converts the collected fan-in operations for one hoisted transpose.
//   dependentOps: operations in the fan-in of the transpose.
//   valuesMap:    taken by mutable reference.
//   rewriter:     rewriter used to create new IR.
//   hoistedPerms: permutation of the transpose being hoisted.
// NOTE(review): presumably fills valuesMap with each original-to-mapped
// result and returns false if any operation cannot be converted -- confirm
// against the body.
bool TosaReduceTransposes::convertDependentOps( SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap, IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { … }

// Predicate over a single `user` operation, evaluated against the set of
// currently valid transposes and the per-transpose fan-in information.
// NOTE(review): judging by the name, returns true when `user` lies outside
// the dependency sets of every transpose in `validTransposes`; confirm the
// exact semantics against the body.
bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies( Operation *user, std::set<TransposeOp> &validTransposes, std::vector<std::pair<TransposeOp, SetVector<Operation *>>> &transposeInfo) { … }

// Dependencies are valid for an operation if none of them occur outside
// of the proper fan-in cones of the hoisted TransposeOp with the same perms
// that we can replace. Described in more detail within.
bool TosaReduceTransposes::dependenciesAreValid( ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps, std::set<TransposeOp> &validTransposes, std::vector<std::pair<TransposeOp, SetVector<Operation *>>> &transposeInfo) { … }

// Getting the set of TransposeOp that we can replace without causing
// the old fan-in cones of any TransposeOp to remain "live", i.e., not being
// dead code.
// This is done by iterating the set until convergence, since
// if you are used outside your own fan-in cone, it's possible to be used
// in another fan-in cone of a TransposeOp that is being replaced -- unless
// we find that that one has a usage outside of it too.
//   perms:         the permutation shared by the candidate transposes.
//   transposeInfo: each candidate TransposeOp paired with its fan-in set.
// Returns the set of TransposeOp that are safe to replace.
std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements( ArrayRef<int32_t> perms, std::vector<std::pair<TransposeOp, SetVector<Operation *>>> &transposeInfo) { … }

// Pass entry point.
// Per the file overview, the pass works through the transpose operators,
// hoists each one through its fan-in where the restrictions allow, replaces
// only those transposes whose old chain then becomes dead, and finishes with
// a simple one-pass DCE.
// NOTE(review): the body is not visible in this view; the step order above is
// taken from the header comment, not from the implementation.
void TosaReduceTransposes::runOnOperation() { … }

} // namespace