//===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities and common canonicalization patterns for
// reshape operations.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include <optional>

namespace mlir {

// NOTE(review): the three lines below appear truncated in this view (bare
// names followed by `;`). They are presumably the type-alias declarations for
// the reassociation types used throughout this header (e.g.
// `using ReassociationIndices = ...;`) -- confirm against the full source.
ReassociationIndices;
ReassociationIndicesRef;
ReassociationExprs;

/// Attribute name for the ArrayAttr which encodes reassociation indices.
constexpr StringRef getReassociationAttrName() { … }

/// Compose reassociation maps that are used in a pair of reshape ops where one
/// is the producer and the other is the consumer. This method is only valid
/// when both the producer and the consumer collapse dimensions, or both
/// expand dimensions.
///
/// For example,
///   producerReassociation = [[0, 1], [2], [3, 4]]
///   consumerReassociation = [[0, 1], [2]]
///
/// is folded into
///
///   result = [[0, 1, 2], [3, 4]].
///
/// The result is a std::optional; NOTE(review): the conditions under which
/// std::nullopt is returned are not visible here -- confirm in the definition.
std::optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
    ArrayRef<ReassociationIndices> producerReassociations,
    ArrayRef<ReassociationIndices> consumerReassociations,
    MLIRContext *context);

/// Convert reassociation indices to affine expressions.
SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);

/// Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector<AffineMap, 4>
getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation);

/// Wraps a list of reassociations in an ArrayAttr.
ArrayAttr
getReassociationIndicesAttribute(OpBuilder &b,
                                 ArrayRef<ReassociationIndices> reassociation);

/// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
    ArrayRef<ReassociationExprs> reassociationExprs);

/// Return the reassociation maps to use to reshape from the given source type
/// to the given target type, when possible. Return std::nullopt when this
/// computation fails.
std::optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);

/// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if
/// possible, and std::nullopt otherwise.
std::optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                   ArrayRef<int64_t> targetShape);

/// Return true if the reassociation specification is valid, false otherwise.
/// When false, the `invalidIndex` integer pointer is optionally filled with the
/// index of the offending reassociation map.
bool isReassociationValid(ArrayRef<AffineMap> reassociation,
                          int *invalidIndex = nullptr);

/// Shared fold hook for a reshape op of type `ReshapeOpTy` whose inverse
/// reshape op has type `InverseReshapeOpTy`; `operands` are the op's operand
/// attributes. NOTE(review): the implementation is not visible in this view;
/// presumably it folds a reshape whose input comes from its inverse reshape
/// and/or constant operands -- confirm against the definition.
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                  ArrayRef<Attribute> operands) { … }

/// Common verifier for reshape-like types.
/// NOTE(review): an earlier version of this comment said the function "fills"
/// `expandedType` and `collapsedType` with the proper `src` or `result` type.
/// Both are taken by value, so nothing can be filled in for the caller.
/// Presumably the caller passes the `src`/`result` types in the order implied
/// by `isExpansion` -- confirm against the definition.
template <typename Op, typename T>
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
                                            T collapsedType,
                                            bool isExpansion) { … }

/// Verify the shapes of the reshaped types using the following rule:
/// if a dimension in the collapsed type is static, then the corresponding
/// dimensions in the expanded shape should be
///    a) static, and
///    b) their product should equal the size of that collapsed dimension.
/// Presumably diagnostics are reported through `emitError` -- confirm.
LogicalResult reshapeLikeShapesAreCompatible(
    function_ref<LogicalResult(const Twine &)> emitError,
    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);

/// Returns true iff the type is a MemRefType and has a non-identity layout.
bool hasNonIdentityLayout(Type type);

/// Kind of reshape op, used as a template parameter of
/// `ComposeReassociativeReshapeOps`. NOTE(review): enumerators are not visible
/// in this view; presumably they distinguish expanding from collapsing
/// reshapes -- confirm.
enum class ReshapeOpKind { … };

/// Pattern to compose producer/consumer reshape ops that both collapse
/// dimensions or both expand dimensions.
template <typename ReshapeOpTy, ReshapeOpKind opKind>
struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> { … };

/// Pattern to compose
/// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
/// In that case both `srcType` and `resultType` can be expressed as a function
/// of `intermediateType`.
/// To demonstrate the approach, assume that
/// `rank(srcType) > rank(resultType)`, i.e. the resulting operation should be
/// `collapse_shape`. Then we can iterate over every set of indices in
/// `reassociation_2` and try to find ids of sets of indices in
/// `reassociation_1` that cover it completely.
///
/// Example:
///
///   %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
///     : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
///   %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
///     : tensor<?x?x?x1xi64> into tensor<?x?xi64>
///
/// can be canonicalized into
///
///   %0 = tensor.collapse_shape %arg [[0, 1], [2]]
///     : tensor<?x?x?xi64> into tensor<?x?xi64>
///
/// because [0] and [1] from the `expand_shape` reassociation completely cover
/// `[0, 1]` from `collapse_shape`. If it is impossible to find such a union of
/// indices, then the pattern fails.
///
/// When `rank(srcType) < rank(resultType)`, we just swap `reassociation_1` and
/// `reassociation_2` and produce `expand_shape`.
template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
          typename DimOpTy, typename TensorTy>
struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> { … };

/// Rewrite pattern rooted at an `ExpandOpTy` op. NOTE(review): the body is not
/// visible in this view; by its name it presumably composes
/// `expand_shape(collapse_shape(...))` in the same way that
/// `ComposeCollapseOfExpandOp` composes the reverse order -- confirm.
template <typename ExpandOpTy, typename CollapseOpTy>
struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> { … };

/// `sliceParams` (the per-dimension offsets, sizes and strides) specify a
/// rectangular, non rank-reducing slice of the collapse_shape output, whose
/// shape is presumably given by `sliceInputShape`. Try to find which
/// dimensions have been sliced and which dimensions are not sliced
/// (offset = 0, size = dim, stride = 1). Note that this is conservative, as it
/// cannot detect whether a dynamic size corresponds to the full tensor
/// dimension or not.
llvm::SmallBitVector getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
                                         ArrayRef<Range> sliceParams);

/// Determine which dimensions are linearized by a `tensor.collapse_shape` op by
/// inspecting its reassociation indices.
llvm::SmallBitVector
getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);

/// Given the parameters for both operations in a `CollapseShape->ExtractSlice`
/// chain and reified source and result shapes of the CollapseShapeOp, this
/// class provides two functions that assist with directly forming the result
/// of the extract slice by "tiling the CollapseShapeOp by 1".
/// Example:
// clang-format off
/// ```
/// %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... into tensor<231x10xf32>
/// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : ... to tensor<10x10xf32>
/// ```
/// This class helps build the below IR to replace %2:
/// ```
/// %dest = tensor.empty() : tensor<10x10xf32>
/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %dest) -> tensor<10x10xf32> {
///    // Row of %1 read by iteration %iv: offset 13, stride 2.
///    %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 13)>(%iv)
///    %3:3 = arith.delinearize_index %linear_index into (3, 7, 11)
///
///    // This function takes %3 (multiIndices) and the parameters for the slice below.
///    %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
///    %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
///          tensor<1x10xf32> into tensor<10x10xf32>
///    scf.yield %6 : tensor<10x10xf32>
/// }
/// ```
// clang-format on
class SliceFromCollapseHelper { … };

/// Parameters required to simplify a collapsing reshape op with a rank-reducing
/// slice operation. See `getSimplifyCollapseShapeWithRankReducingSliceInfo`.
struct CollapseShapeRankReducingSliceSimplificationInfo { … };

/// A collapsing reshape operation can sometimes be simplified or eliminated by
/// inserting a single rank-reducing slice operation between it and the source
/// tensor. The slice op will either take the place of the source, allowing for
/// a new, simpler reshape op to replace the original, or the reshape op will be
/// completely replaced by the slice result.
///
/// This function returns the parameters required to implement this pattern. If
/// the pattern is not applicable, then failure is returned.
///
/// In the examples below, `%dim0` denotes the (dynamic) size of dimension 0 of
/// the source tensor.
///
/// ### Example:
/// ```
/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
///    : tensor<?x1x30x10xf32> into tensor<?x300xf32>
/// ```
/// can be transformed to
/// ```
/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
///                                [%dim0, 1, 30, 10]
///                                [1, 1, 1, 1]
///   : tensor<?x1x30x10xf32> to tensor<?x30x10xf32>
/// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
///   : tensor<?x30x10xf32> into tensor<?x300xf32>
/// ```
///
/// ### Example:
/// ```
/// %result = tensor.collapse_shape %1 [[0, 1], [2]]
///    : tensor<?x1x30xf32> into tensor<?x30xf32>
/// ```
/// can be transformed to
/// ```
/// %result = tensor.extract_slice %1 [0, 0, 0]
///                                   [%dim0, 1, 30]
///                                   [1, 1, 1]
///    : tensor<?x1x30xf32> to tensor<?x30xf32>
/// ```
FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
getSimplifyCollapseShapeWithRankReducingSliceInfo(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices);

/// Result type of `computePackingMetadata`. NOTE(review): its fields are not
/// visible in this view; per that function's documentation it presumably
/// carries the final packed positions and the reshape reassociations --
/// confirm.
struct PackingMetadata { … };

/// Given a vector of `innerDimPos` indices representing desired packing
/// insertion points into a target vector (i.e. pack/unpack.inner_dim_pos),
/// compute the final positions in the target shape as well as the reshape
/// reassociations. `packedRank` is presumably the rank of the packed
/// (target) shape -- confirm.
// Note: This should not be called with a large `innerDimPos` array (or the
// implementation needs to be updated to use an N log N sort instead of
// repeated N^2 counts).
PackingMetadata computePackingMetadata(int64_t packedRank,
                                       ArrayRef<int64_t> innerDimPos);

} // namespace mlir

#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H