llvm/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

//===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities and common canonicalization patterns for
// reshape operations.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include <optional>

namespace mlir {

ReassociationIndices;
ReassociationIndicesRef;
ReassociationExprs;

/// Attribute name for the ArrayAttr which encodes reassociation indices.
/// The original body was empty, so this constexpr function never returned a
/// value; it now returns the attribute name explicitly.
constexpr StringRef getReassociationAttrName() { return "reassociation"; }

/// Compose the reassociation maps used by a pair of reshape ops, where one op
/// is the producer and the other is the consumer. This is only valid when both
/// the producer and the consumer collapse dimensions, or both expand
/// dimensions.
///
/// For example,
///   producerReassociation = [[0, 1], [2], [3, 4]]
///   consumerReassociation = [[0, 1], [2]]
///
/// is folded into
///
///   result = [[0, 1, 2], [3, 4]].
///
/// Returns the composed reassociation, or std::nullopt when the two cannot be
/// composed. NOTE(review): the exact failure conditions live in the
/// implementation, which is not visible here.
std::optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
    ArrayRef<ReassociationIndices> producerReassociations,
    ArrayRef<ReassociationIndices> consumerReassociations,
    MLIRContext *context);

/// Convert reassociation indices to affine expressions created in `context`.
/// The result holds one group of expressions per group of indices in
/// `reassociationIndices`.
SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);

/// Constructs symbol-less affine maps (maps without symbol operands) out of
/// Array<Array<AffineExpr>>, i.e. the given reassociation expressions.
SmallVector<AffineMap, 4>
getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation);

/// Wraps a list of reassociations in an ArrayAttr, using builder `b` to create
/// the attribute.
ArrayAttr
getReassociationIndicesAttribute(OpBuilder &b,
                                 ArrayRef<ReassociationIndices> reassociation);

/// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>, i.e. turn
/// reassociation expressions back into reassociation index groups.
SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
    ArrayRef<ReassociationExprs> reassociationExprs);

/// Return the reassociation maps needed to reshape from `sourceType` to
/// `targetType` when possible. Return std::nullopt when this computation
/// fails.
std::optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);

/// Returns the reassociation maps that collapse `sourceShape` to `targetShape`
/// if possible, and std::nullopt otherwise.
std::optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                   ArrayRef<int64_t> targetShape);

/// Return true if the reassociation specification is valid, false otherwise.
/// When false and `invalidIndex` is non-null, `*invalidIndex` is set to the
/// index of the offending reassociation map.
bool isReassociationValid(ArrayRef<AffineMap> reassociation,
                          int *invalidIndex = nullptr);

/// Shared fold helper for reshape ops. `ReshapeOpTy` is the op being folded.
/// `InverseReshapeOpTy` is presumably the op type performing the opposite
/// reshape (e.g. expand vs. collapse) — confirm against callers. `operands`
/// presumably holds the constant-folded operand attributes passed to a fold
/// hook.
///
/// NOTE(review): the body in this copy is empty. Any instantiation therefore
/// flows off the end of a function returning `OpFoldResult`, which is undefined
/// behavior. The implementation appears to be missing and must be restored
/// before this template is instantiated.
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                  ArrayRef<Attribute> operands) {}

/// Common verifier for reshape-like types. The caller passes the op's `src` or
/// `result` type as `expandedType` and `collapsedType`; which one is which
/// presumably depends on whether the op expands (`isExpansion`). Both types are
/// taken by value, so nothing is filled in for the caller.
///
/// NOTE(review): the body in this copy is empty. Any instantiation therefore
/// flows off the end of a function returning `LogicalResult`, which is
/// undefined behavior. The implementation appears to be missing.
template <typename Op, typename T>
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
                                            T collapsedType, bool isExpansion) {}

/// Verify that the shapes of the reshaped types satisfy the following rule:
/// if a dimension in the collapsed type is static, then the corresponding
/// dimensions in the expanded shape should be
///    a) static, and
///    b) such that their product equals the collapsed dimension's size.
/// `reassociationMaps` maps each collapsed dimension to its group of expanded
/// dimensions. Violations are reported through `emitError`.
LogicalResult reshapeLikeShapesAreCompatible(
    function_ref<LogicalResult(const Twine &)> emitError,
    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);

/// Returns true iff the type is a MemRefType and has a non-identity layout.
/// Any non-memref type yields false.
bool hasNonIdentityLayout(Type type);

/// Direction of a reshape op: expanding (one dimension split into several) or
/// collapsing (several dimensions merged into one). Used as a template argument
/// to tell `ComposeReassociativeReshapeOps` which case it handles.
enum class ReshapeOpKind { kExpand, kCollapse };

/// Pattern to collapse producer/consumer reshape ops that both collapse
/// dimensions or both expand dimensions. `opKind` states which of the two cases
/// `ReshapeOpTy` represents.
///
/// NOTE(review): the struct body in this copy is empty. It inherits no
/// constructor from `OpRewritePattern` and defines no `matchAndRewrite`, so it
/// cannot work as a pattern; the implementation appears to be missing.
template <typename ReshapeOpTy, ReshapeOpKind opKind>
struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {};

/// Pattern to compose
/// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
/// In that case both `srcType` and `resultType` can be expressed as a function
/// of `intermediateType`.
/// To demonstrate the approach, assume that
/// `rank(srcType) > rank(resultType)`, i.e. the resulting operation should be
/// `collapse_shape`. We then iterate over every set of indices in
/// `reassociation_2` and try to find the ids of the sets of indices in
/// `reassociation_1` that cover it completely.
///
/// Example:
///
///   %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
///     : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
///   %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
///     : tensor<?x?x?x1xi64> into tensor<?x?xi64>
///
/// can be canonicalized into
///
///   %0 = tensor.collapse_shape %arg [[0, 1], [2]]
///     : tensor<?x?x?xi64> into tensor<?x?xi64>
///
/// because [0] and [1] from the `expand_shape` reassociation completely cover
/// `[0, 1]` from the `collapse_shape`. If no such union of indices can be
/// found, the pattern fails.
///
/// When `rank(srcType) < rank(resultType)`, we swap `reassociation_1` and
/// `reassociation_2` and produce an `expand_shape` instead.
///
/// NOTE(review): the struct body in this copy is empty. It has no inherited
/// constructor and no `matchAndRewrite`, so the implementation appears to be
/// missing.
template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
          typename DimOpTy, typename TensorTy>
struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {};

/// Pattern to compose
/// `expand_shape(collapse_shape(%src, reassociation_1), reassociation_2)`,
/// rooted at the `ExpandOpTy` op. It is presumably the counterpart of
/// `ComposeCollapseOfExpandOp` — confirm against the implementation.
///
/// NOTE(review): the struct body in this copy is empty. It has no inherited
/// constructor and no `matchAndRewrite`, so the implementation appears to be
/// missing.
template <typename ExpandOpTy, typename CollapseOpTy>
struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {};

/// `sliceInputShape` is the shape of the collapse_shape output. `sliceParams`
/// holds the per-dimension offset/size/stride ranges of a rectangular,
/// non rank-reducing slice of that output. This function tries to find which
/// dimensions have been sliced and which have not (offset = 0, size = dim,
/// stride = 1). Note that this is conservative: it cannot detect whether a
/// dynamic size corresponds to the full tensor dimension.
/// Returns a bit vector; NOTE(review): presumably one bit per dimension, set
/// when that dimension is sliced — confirm against the implementation.
llvm::SmallBitVector getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
                                         ArrayRef<Range> sliceParams);

/// Determine which dimensions are linearized by a `tensor.collapse_shape` op by
/// inspecting its reassociation indices.
/// NOTE(review): the returned bit vector is presumably indexed by collapsed
/// (result) dimension, with a bit set when its group contains more than one
/// index — confirm against the implementation.
llvm::SmallBitVector
getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);

/// Given the parameters for both operations in a `CollapseShape->ExtractSlice`
/// chain and the reified source and result shapes of the CollapseShapeOp, this
/// class provides two functions that help form the result of the extract slice
/// directly, by "tiling the CollapseShapeOp by 1".
/// Example:
// clang-format off
/// ```
/// %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... into tensor<231x10xf32>
/// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
/// ```
/// This class helps build the IR below to replace %2:
/// ```
/// %dest = tensor.empty() : tensor<10x10xf32>
/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %dest) -> tensor<10x10xf32> {
///    %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 13)>(%iv)
///    %3:3 = affine.delinearize_index %linear_index into (3, 7, 11)
///
///    // This function takes %3 (multiIndices) and the parameters for the slice below.
///    %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
///    %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
///          tensor<1x10xf32> into tensor<10x10xf32>
///    scf.yield %6 : tensor<10x10xf32>
/// }
/// ```
// clang-format on
/// NOTE(review): the class body in this copy is empty, so the two functions
/// described above are not declared here; the implementation appears to be
/// missing.
class SliceFromCollapseHelper {};

/// Parameters required to simplify a collapsing reshape op with a rank-reducing
/// slice operation. See `getSimplifyCollapseShapeWithRankReducingSliceInfo`.
struct CollapseShapeRankReducingSliceSimplificationInfo {};

/// A collapsing reshape operation can sometimes be simplified or eliminated by
/// inserting a single rank-reducing slice operation between it and the source
/// tensor. Either the slice op takes the place of the source, allowing a new,
/// simpler reshape op to replace the original, or the reshape op is replaced
/// completely by the slice result.
///
/// This function returns the parameters required to implement this pattern. If
/// the pattern is not applicable, failure is returned.
///
/// ### Example:
/// ```
/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
///    : tensor<?x1x30x10xf32> into tensor<?x300xf32>
/// ```
/// can be transformed to
/// ```
/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
///                         [%dim0, 1, 30, 10]
///                         [1, 1, 1, 1]
///   : tensor<?x1x30x10xf32> to tensor<?x30x10xf32>
/// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
///   : tensor<?x30x10xf32> into tensor<?x300xf32>
/// ```
///
/// ### Example:
/// ```
/// %result = tensor.collapse_shape %1 [[0, 1], [2]]
///    : tensor<?x1x30xf32> into tensor<?x30xf32>
/// ```
/// can be transformed to
/// ```
/// %result = tensor.extract_slice %1 [0, 0, 0]
///                                   [%dim0, 1, 30]
///                                   [1, 1, 1]
///    : tensor<?x1x30xf32> to tensor<?x30xf32>
/// ```
/// Here `%dim0` denotes the runtime size of dimension 0 of the source.
FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
getSimplifyCollapseShapeWithRankReducingSliceInfo(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices);

/// Result of `computePackingMetadata`: the final positions of the packed
/// dimensions in the target shape, plus the reshape reassociations.
/// NOTE(review): the struct body in this copy is empty, so it cannot carry
/// these results; the member declarations appear to be missing.
struct PackingMetadata {};

/// Given a vector of `innerDimPos` indices (i.e. pack/unpack.inner_dim_pos)
/// representing the desired packing insertion points into a target vector,
/// compute the final positions in the target shape as well as the reshape
/// reassociations. `packedRank` is presumably the rank of the packed shape —
/// confirm against callers.
// Note: This should not be called with a large innerDimPos array (or the
// implementation needs to be updated to use an N log N sort instead of
// repeated N^2 counts).
PackingMetadata computePackingMetadata(int64_t packedRank,
                                       ArrayRef<int64_t> innerDimPos);
} // namespace mlir

#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H