llvm/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp

//===- TosaReduceTransposes.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// ----------
// Motivation:
// ----------

// Some legalization pathways introduce redundant tosa.TRANSPOSE
// operations that result in avoidable data movement. For example,
// PyTorch -> TOSA contains a lot of unnecessary transposes due
// to conversions between NCHW and NHWC.

// We wish to remove all the ones that we can, since in general
// it is possible to remove the overwhelming majority.

// -------------------
// High-Level Overview:
// -------------------

// The pass works through the transpose operators in the program. It begins at
// some transpose operator with an associated permutations tensor. It traverses
// upwards through the dependencies of this transpose and verifies that we
// encounter only operators with the TosaElementwiseOperator trait and terminate
// in either constants, reshapes, or transposes.

// We then evaluate whether there are any additional restrictions (the
// transposes it terminates in must invert the one we began at, and the reshapes
// must be ones into which we can fold the transpose), and then we hoist the
// transpose through the intervening operators, folding it at the constants,
// reshapes, and transposes.

// Finally, we ensure that we do not need both the transposed form (the form
// that had the transpose hoisted through it) and the untransposed form (the
// form it had before hoisting), by analyzing the usages of those dependent
// operators of a given transpose we are attempting to hoist and replace.

// If both forms would be needed, then we do not replace the hoisted transpose,
// leaving the new chain dead.
// Otherwise, we do and the old chain (untransposed form) becomes dead. Only one
// chain will ever then be live, resulting in no duplication.

// We then perform a simple one-pass DCE, so no canonicalization is necessary.

// -----------
// Future Work:
// -----------

// (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across
//     hoisted transposes with different permutation tensors.

// (2) Expand the class of foldable upstream ReshapeOp we permit beyond
//     N -> 1x1x...x1xNx1x...x1x1.

// (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
//     those that form the identity.

// (4) Add support for more instructions besides TosaElementwiseOperator as
//     the intervening ones (for example, the reduce_* operators).

// (5) Support hoisting transposes up to an input parameter.

//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/TypeSwitch.h"
#include <memory>
#include <set>
#include <stack>

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

usingnamespacemlir;
usingnamespacemlir::tosa;

//===----------------------------------------------------------------------===//
// TOSA Reduce Transposes Pass.
//===----------------------------------------------------------------------===//

namespace {

// Pass that hoists tosa.transpose operations upward through chains of
// TosaElementwiseOperator ops and folds them at the ConstOp, ReshapeOp and
// TransposeOp found at the roots of those chains (see the overview at the top
// of this file).
//
// Every member defined out-of-line below must be declared here; without these
// declarations each `TosaReduceTransposes::...` definition is ill-formed.
struct TosaReduceTransposes final
    : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
  using TosaReduceTransposesBase::TosaReduceTransposesBase;

  void runOnOperation() override;

private:
  // Presumably returns `input` with the permutation `perms` applied, or
  // std::nullopt when that cannot be done.
  std::optional<DenseElementsAttr>
  transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);

  // Collects the data dependencies (fan-in) of `op` into `collected`. On a
  // true result the sources are only ConstOp, ReshapeOp and TransposeOp,
  // followed by TosaElementwiseOperator ops.
  bool collectFanIn(Operation *op, SetVector<Operation *> &collected);

  // True if applying `perms1` and then `perms2` yields the identity.
  bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
                               ArrayRef<int32_t> perms2);

  // Overload for ops with the TosaElementwiseOperator trait.
  std::optional<Value>
  buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Overloads for the roots of the data dependency graph.
  std::optional<Value>
  buildMappedToValue(TransposeOp transposeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
  std::optional<Value>
  buildMappedToValue(ReshapeOp reshapeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
  std::optional<Value>
  buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Presumably converts `dependentOps` into their hoisted form, recording the
  // old -> new value mapping in `valuesMap`.
  bool convertDependentOps(SetVector<Operation *> &dependentOps,
                           DenseMap<Value, Value> &valuesMap,
                           IRRewriter &rewriter,
                           ArrayRef<int32_t> hoistedPerms);

  // True if `user` is outside the fan-in cone of every TransposeOp that is in
  // `validTransposes`.
  bool userNotContainedInValidTransposeDependencies(
      Operation *user, std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // True if none of `dependentOps` is used outside the proper fan-in cones of
  // the replaceable TransposeOps with the same `perms`.
  bool dependenciesAreValid(
      ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
      std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // The TransposeOps with `perms` that can be replaced without keeping any old
  // fan-in cone live, computed by iterating to convergence.
  std::set<TransposeOp> getGoodReplacements(
      ArrayRef<int32_t> perms,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);
};

// Presumably returns a copy of `input` with the dimension permutation `perms`
// applied (inferred from the name and signature -- TODO confirm). The
// std::optional return allows std::nullopt to signal failure.
// NOTE(review): the body is empty. Flowing off the end of a function whose
// return type is not void is undefined behavior in C++. The implementation
// appears to have been elided and must be restored before use.
std::optional<DenseElementsAttr>
TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
                                              ArrayRef<int32_t> perms) {}

// Gathers the data dependencies (the fan-in) of `op` into `collected`.
// If the function returns true, the SetVector should contain only ConstOp,
// ReshapeOp and TransposeOp as the sources of the data dependencies, and
// operations with the TosaElementwiseOperator trait after them.
// NOTE(review): the body is empty; falling off the end of this bool function
// is undefined behavior. The implementation appears to have been elided.
bool TosaReduceTransposes::collectFanIn(Operation *op,
                                        SetVector<Operation *> &collected) {}

// Returns true iff `perms2` undoes `perms1`, i.e. transposing by `perms1` and
// then by `perms2` yields the identity: perms2[perms1[i]] == i for every i.
// For permutations this relation is symmetric, so argument order does not
// matter. Arrays of different lengths never compose to the identity.
//
// The TransposeOp verifier guarantees each array is a permutation of
// 0 .. size() - 1. Out-of-range entries are still rejected (returning false)
// rather than indexing `perms2` out of bounds.
bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1,
                                                   ArrayRef<int32_t> perms2) {
  if (perms1.size() != perms2.size())
    return false;

  const int64_t rank = static_cast<int64_t>(perms1.size());
  for (int64_t i = 0; i < rank; ++i) {
    const int64_t target = perms1[i];
    if (target < 0 || target >= rank)
      return false;
    if (perms2[target] != i)
      return false;
  }
  return true;
}

// Primary overload, for operations with the TosaElementwiseOperator trait.
// The other overloads handle the operations that occur at the roots of the
// data dependency graph (ConstOp, ReshapeOp, TransposeOp).
//
// Parameters (roles presumed from names -- TODO confirm):
//   op           - the operation whose hoisted counterpart is being built.
//   valuesMap    - maps values of the original chain to their counterparts in
//                  the hoisted chain; only read here (const reference).
//   rewriter     - presumably used to create the replacement operation.
//   hoistedPerms - the permutation of the transpose being hoisted.
// Returns the mapped value, or std::nullopt if none could be built.
// NOTE(review): the body is empty; falling off the end of a function returning
// std::optional<Value> is undefined behavior. Implementation appears elided.
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    Operation *op, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {}

// Overload for a TransposeOp found at a root of the data dependency graph of
// the transpose being hoisted. The parameters play the same roles as in the
// other buildMappedToValue overloads.
// NOTE(review): the body is empty; falling off the end of a function returning
// std::optional<Value> is undefined behavior. Implementation appears elided.
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {}

// Overload for a ReshapeOp found at a root of the data dependency graph. Per
// the file overview, only reshapes into which the hoisted transpose can be
// folded are acceptable here.
// NOTE(review): the body is empty; falling off the end of a function returning
// std::optional<Value> is undefined behavior. Implementation appears elided.
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {}

// Overload for a ConstOp found at a root of the data dependency graph; per the
// file overview, the hoisted transpose is folded into the constant.
// NOTE(review): the body is empty; falling off the end of a function returning
// std::optional<Value> is undefined behavior. Implementation appears elided.
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {}

// Presumably builds the hoisted form of every operation in `dependentOps`,
// recording old -> new values in `valuesMap`. `valuesMap` is taken by non-const
// reference here, unlike in buildMappedToValue. The bool result presumably
// reports whether all conversions succeeded -- TODO confirm.
// NOTE(review): the body is empty; falling off the end of this bool function
// is undefined behavior. The implementation appears to have been elided.
bool TosaReduceTransposes::convertDependentOps(
    SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {}

// Returns true iff `user` is not a member of the dependency set (fan-in cone)
// of any TransposeOp that is currently in `validTransposes`.
//
//   user            - the operation whose membership is being tested.
//   validTransposes - TransposeOps currently considered replaceable.
//   transposeInfo   - each candidate TransposeOp paired with its collected
//                     dependent operations.
//
// Entries of `transposeInfo` whose TransposeOp is not in `validTransposes` are
// ignored. Neither container is modified.
bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
    Operation *user, std::set<TransposeOp> &validTransposes,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  for (const auto &[transposeOp, dependentOps] : transposeInfo) {
    if (validTransposes.count(transposeOp) && dependentOps.count(user))
      return false;
  }
  return true;
}

// Dependencies are valid for an operation if none of them occur outside of
// the proper fan-in cones of the hoisted TransposeOps with the same perms
// that we can replace.
//
//   perms           - permutation shared by the candidate TransposeOps.
//   dependentOps    - the fan-in cone of the TransposeOp being checked.
//   validTransposes - TransposeOps currently considered replaceable.
//   transposeInfo   - each candidate TransposeOp paired with its fan-in cone.
//
// NOTE(review): the original comment deferred details to the body ("described
// in more detail within"), but the body is empty. Falling off the end of this
// bool function is undefined behavior; the implementation appears elided.
bool TosaReduceTransposes::dependenciesAreValid(
    ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
    std::set<TransposeOp> &validTransposes,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {}

// Gets the set of TransposeOps (with permutation `perms`, drawn from
// `transposeInfo`) that we can replace without causing the old fan-in cone of
// any TransposeOp to remain "live", i.e., to not be dead code.
// This is done by iterating over the set until convergence. An operation used
// outside its own fan-in cone may still be used only inside the fan-in cone of
// another TransposeOp that is being replaced, unless that other TransposeOp is
// in turn found to have a use outside its own cone.
// NOTE(review): the body is empty; falling off the end of a function returning
// std::set<TransposeOp> is undefined behavior. Implementation appears elided.
std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
    ArrayRef<int32_t> perms,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {}

// Pass entry point.
// NOTE(review): the body is empty, so as written this pass leaves the IR
// unchanged. The hoisting, replacement and DCE described at the top of the
// file are not performed here. The implementation appears to have been
// elided -- confirm it is restored.
void TosaReduceTransposes::runOnOperation() {}

} // namespace