// llvm/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===///
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg elementwise-op fusion pass on tensors, along
// with patterns that fuse reshapes and constants into linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
#include <utility>

// Pull in the TableGen-generated definitions for this file's pass (MLIR's
// `GEN_PASS_DEF_<PASS>` convention); this provides the
// `impl::LinalgElementwiseOpFusionPassBase` class used at the end of the file.
namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

usingnamespacemlir;
usingnamespacemlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use, in the fused operation, for the producer
/// operand `producerOpOperand`.
///
/// \param producerOpOperand       operand of the producer being remapped.
/// \param producerResultIndexMap  presumably the indexing map of the fused
///                                producer result inside the producer.
/// \param fusedConsumerArgIndexMap presumably the indexing map with which the
///                                consumer (in fused-op coordinates) reads
///                                that producer result — confirm.
///
/// NOTE(review): the body is empty in this view. Because the function returns
/// an `AffineMap`, flowing off the end is undefined behavior if it is ever
/// called; presumably the implementation was elided — confirm.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {}

// Checks whether the operands listed in `opOperandsToIgnore` can be dropped
// from the fused `producer`/`consumer` pair, i.e. whether the remaining
// operands of the fused op are still sufficient to compute the bounds of the
// op's loops.
//
// NOTE(review): the body is empty in this view; a `bool` function that flows
// off its end is undefined behavior if called — presumably elided, confirm.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {}

/// Returns the set of indices of the `producer`'s results that would be
/// preserved (i.e. still produced) after fusing `producer` into `consumer`
/// through `fusedOperand`.
/// * The implementation of the transformation may not agree with the result
/// of this method: it is a prediction based on an optimized fusion.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// non-void function is undefined behavior if called — presumably elided.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {}

/// Conditions for elementwise fusion of generic operations: returns whether
/// the consumer owning `fusedOperand` can be fused with the operation that
/// produces that operand.
///
/// NOTE(review): the body is empty in this view; a `bool` function that flows
/// off its end is undefined behavior if called — presumably elided, confirm.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {}

/// Generate the region of the fused tensor operation `fusedOp`. The region of
/// the fused op must be empty on entry.
///
/// \param consumerToProducerLoopsMap maps consumer loop indices to producer
///        loop indices (as the name suggests — confirm against the caller).
/// \param fusedOperand the consumer operand through which fusion happens.
/// \param nloops number of loops of the fused op.
/// \param preservedProducerResults indices of producer results kept after
///        fusion (see `getPreservedProducerResults`).
///
/// NOTE(review): the body is empty in this view, so as shown this is a no-op
/// and leaves the fused region empty — presumably elided, confirm.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {}

/// Fuse the elementwise producer of `fusedOperand` into the consumer that owns
/// `fusedOperand`, creating the new ops through `rewriter`. Returns failure
/// when the fusion cannot be performed.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// function returning `FailureOr<...>` is undefined behavior if called —
/// presumably elided, confirm.
FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {}

namespace {
/// Pattern that fuses a generic op with the producers of its operands.
///
/// NOTE(review): the class body is empty in this view (no constructor and no
/// `matchAndRewrite` override), so as shown it cannot perform any rewrite —
/// presumably elided, confirm.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `fuseWithReshapeByExpansion`, which implements
/// the following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `linalgOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor.expand_shape.
///  The indexing_map of the fused tensor in the `linalgOp` and the
///  reassociation map help compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor (d0) is
///    split into two.
///  - The loop used to access the second dimension of the fused tensor (d2) is
///    kept as is.
///  - The loop used to access the third dimension of the fused tensor (d1) is
///    split into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : tensor<?x?x?x?x?x?xf32>,
///                                   tensor<?x?x?x?xf32>)
///        indexing_maps =
///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the operands of the linalg generic now have a higher rank (6-D and
///  4-D), reshapes can be introduced to make them consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
///
/// NOTE(review): the body is empty in this view; a `bool` function that flows
/// off its end is undefined behavior if called — presumably elided, confirm.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {}

namespace {
/// Information needed to expand a generic operation in order to fold a
/// reshape into it.
///
/// NOTE(review): the class body is empty in this view, yet
/// `ExpansionInfo::compute` is defined out-of-line later in this file; without
/// a matching member declaration that definition does not compile —
/// presumably the members were elided, confirm.
class ExpansionInfo {};
} // namespace

/// Populate this `ExpansionInfo` for expanding `linalgOp` so that the reshape
/// feeding/consuming `fusableOpOperand` (described by `reassociationMaps`,
/// `expandedShape` and `collapsedShape`) can be folded into it. Returns
/// failure if the expansion information cannot be computed.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// function returning `LogicalResult` is undefined behavior if called —
/// presumably elided, confirm.
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {}

/// Expanding the body of a linalg operation requires adapting the accessed
/// loop indices. Specifically, accesses to indices of the original operation
/// need to be replaced with linearizations of indices in the expanded op.
/// That requires the shape of the expanded dimensions to be static (at least
/// all but the most significant). For now, check that these are all
/// statically sized. Note that this could be extended to handle the dynamic
/// case, but the implementation below uses `affine.apply`, which seems to have
/// issues when the shapes are not static.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// function returning `LogicalResult` is undefined behavior if called —
/// presumably elided, confirm.
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
                                          const ExpansionInfo &expansionInfo,
                                          PatternRewriter &rewriter) {}

/// Return the indexing map to use in the expanded op for the given
/// `indexingMap` of the original operation.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// function returning `AffineMap` is undefined behavior if called —
/// presumably elided, confirm.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {}

/// Return the type of the operand/result to use in the expanded op, given its
/// type `originalType` and its `indexingMap` in the original op.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// function returning `RankedTensorType` is undefined behavior if called —
/// presumably elided, confirm.
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation that converts the operands of the original operation into
/// operands of the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` that collapses the result of the expanded op to get
/// the value that replaces all uses of the results of the original op.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// non-void function is undefined behavior if called — presumably elided.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {}

/// Update the body of an expanded linalg operation that has index semantics.
/// The indices of the original operation must be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
///
/// NOTE(review): the body is empty in this view, so as shown this is a no-op
/// and `fusedRegion` is left unchanged — presumably elided, confirm.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {}

/// Checks whether a single dynamic dimension is expanded into multiple dynamic
/// dimensions. Presumably this returns failure in that case, since the split
/// of a dynamic size would be ambiguous — confirm against the implementation.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// function returning `LogicalResult` is undefined behavior if called —
/// presumably elided, confirm.
static LogicalResult
validateDynamicDimExpansion(LinalgOp linalgOp,
                            const ExpansionInfo &expansionInfo,
                            PatternRewriter &rewriter) {}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
/// with a linalg op, as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied. Returns the replacement
/// values on success, or `std::nullopt` on failure.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// function returning `std::optional<...>` is undefined behavior if called —
/// presumably elided, confirm.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {}

namespace {

// NOTE(review): all three pattern classes below have empty bodies in this view
// (no constructors, no `matchAndRewrite`), so as shown they cannot rewrite
// anything — presumably elided, confirm.

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {};

/// Presumably the `tensor.pad` counterpart of the pattern above: folds a
/// reshape producing the pad's source by expanding the pad — confirm.
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of `indexingMap` that are folded
/// together (`rangeReassociation`), return the indices of the corresponding
/// domain dimensions. Ensures that all the elements of the returned
/// reassociation are distinct.
///
/// NOTE(review): an older revision of this comment said "Return
/// `std::nullopt` on failure", but the return type `ReassociationIndices` is
/// not optional — confirm how failure is actually reported. The body is also
/// empty in this view; flowing off the end of a non-void function is undefined
/// behavior if called — presumably elided.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {}

/// For a given `dimSequence`, check whether the sequence is preserved in
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// If the sequence does not occur in the map at all, this also returns true.
///
/// NOTE(review): the body is empty in this view; a `bool` function that flows
/// off its end is undefined behavior if called — presumably elided, confirm.
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {}

/// Presumably returns true iff every sequence in `dimSequences` is preserved
/// (in the sense of `isDimSequencePreserved`) in every map of `maps` —
/// confirm against the implementation.
///
/// NOTE(review): the body is empty in this view; a `bool` function that flows
/// off its end is undefined behavior if called — presumably elided.
bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions, since the indexing maps would have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsable, an empty vector is returned.
//
// NOTE(review): the body is empty in this view; flowing off the end of a
// non-void function is undefined behavior if called — presumably elided.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {}

/// Helper class that carries state while collapsing the iteration space of a
/// `linalg.generic` (or other linalg) op.
///
/// NOTE(review): the class body is empty in this view, so it carries no state
/// as shown — presumably the members were elided, confirm.
namespace {
class CollapsingInfo {};
} // namespace

/// Get the iterator types for the collapsed operation, given the original
/// `iteratorTypes` and the collapsed dimensions in `collapsingInfo`.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// non-void function is undefined behavior if called — presumably elided.
static SmallVector<utils::IteratorType>
getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
                            const CollapsingInfo &collapsingInfo) {}

/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// function returning `AffineMap` is undefined behavior if called —
/// presumably elided, confirm.
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
                          const CollapsingInfo &collapsingInfo) {}

/// Return the `reassociation` indices used to collapse an operand (accessed
/// through `indexingMap`) when the iteration space of a generic op is
/// collapsed.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// non-void function is undefined behavior if called — presumably elided.
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
                        const CollapsingInfo &collapsingInfo) {}

/// Get the new value to use for `opOperand` of `op` in the collapsed
/// operation; any new ops are created with `builder` at `loc`.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// function returning `Value` is undefined behavior if called — presumably
/// elided, confirm.
static Value getCollapsedOpOperand(Location loc, LinalgOp op,
                                   OpOperand *opOperand,
                                   const CollapsingInfo &collapsingInfo,
                                   OpBuilder &builder) {}

/// Rewrite the `linalg.index` operations in `block` (the body of the original
/// generic op) so that they produce their values in terms of the collapsed
/// operation's loops. `loopRange` presumably holds the original loop bounds
/// needed to delinearize collapsed indices — confirm.
///
/// NOTE(review): unlike its neighbors this function is not `static`, so it has
/// external linkage; consider making it file-local. The body is empty in this
/// view, so as shown it is a no-op — presumably elided.
void generateCollapsedIndexingRegion(Location loc, Block *block,
                                     const CollapsingInfo &collapsingInfo,
                                     ValueRange loopRange,
                                     RewriterBase &rewriter) {}

/// Presumably computes the collapsed input and output operands of `op` and the
/// corresponding collapsed result types, appending them to `inputOperands`,
/// `outputOperands` and `resultTypes` respectively — confirm against the
/// implementation.
///
/// NOTE(review): not `static`, so it has external linkage; consider making it
/// file-local. The body is empty in this view, so as shown the output vectors
/// are left untouched — presumably elided.
void collapseOperandsAndResults(LinalgOp op,
                                const CollapsingInfo &collapsingInfo,
                                RewriterBase &rewriter,
                                SmallVectorImpl<Value> &inputOperands,
                                SmallVectorImpl<Value> &outputOperands,
                                SmallVectorImpl<Type> &resultTypes) {}

/// Clone `origOp` into a collapsed version of the same operation type `OpTy`,
/// according to `collapsingInfo`. Specialized below for `LinalgOp` and
/// `GenericOp`.
///
/// NOTE(review): the primary template body is empty in this view; flowing off
/// the end of a function returning `OpTy` is undefined behavior if an
/// unspecialized instantiation is ever called.
template <typename OpTy>
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
                        const CollapsingInfo &collapsingInfo) {}

/// Collapse any `LinalgOp` that does not require specialized handling of
/// attributes such as indexing_maps, iterator_types, etc.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// function returning `LinalgOp` is undefined behavior if called —
/// presumably elided, confirm.
template <>
LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
                                      const CollapsingInfo &collapsingInfo) {}

/// Collapse a `GenericOp`. Unlike the generic `LinalgOp` case, this needs its
/// own specialization (presumably to rebuild indexing maps and iterator
/// types — confirm).
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// function returning `GenericOp` is undefined behavior if called —
/// presumably elided, confirm.
template <>
GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
                                        GenericOp origOp,
                                        const CollapsingInfo &collapsingInfo) {}

/// Create the collapsed version of `op` as described by `collapsingInfo`,
/// presumably dispatching to the matching `cloneToCollapsedOp`
/// specialization — confirm.
///
/// NOTE(review): not `static`, so it has external linkage. The body is empty
/// in this view; flowing off the end of a function returning `LinalgOp` is
/// undefined behavior if called — presumably elided.
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
                           RewriterBase &rewriter) {}

/// Collapse the iteration dimensions of `op` that are grouped together in
/// `foldedIterationDims`. This implements fusion with reshape operations by
/// collapsing dimensions. Returns failure if the collapse cannot be done.
///
/// NOTE(review): the body is empty in this view; flowing off the end of a
/// function returning `FailureOr<...>` is undefined behavior if called —
/// presumably elided, confirm.
FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
    LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
    RewriterBase &rewriter) {}

namespace {

// NOTE(review): all pattern classes below have empty bodies in this view (no
// constructors, no `matchAndRewrite`), so as shown they cannot rewrite
// anything — presumably elided, confirm.

/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// contracting dimensions of the loop.
class FoldWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<GenericOp> {};

/// Presumably the `tensor.pad` counterpart of the pattern above: folds a
/// reshape producing the pad's source by collapsing dimensions — confirm.
class FoldPadWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<tensor::PadOp> {};

/// Pattern to collapse the iteration dimensions of a `LinalgType` op.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {};

} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//

namespace {
/// Pattern to fold a generic op with a splat constant or scalar constant
/// operand. Does not handle cases where the constant is not single-valued.
///
/// NOTE(review): the class body is empty in this view, so as shown it cannot
/// rewrite anything — presumably elided, confirm.
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {};

} // namespace

//===---------------------------------------------------------------------===//
// Miscellaneous patterns that help fusion.
//===---------------------------------------------------------------------===//

namespace {
/// Forces the `outs` operands of linalg operations to use `tensor.empty` when
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for all
/// linalg structured ops.
///
/// NOTE(review): this class and the one below have empty bodies in this view,
/// so as shown they cannot rewrite anything — presumably elided, confirm.
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {};

/// Fold linalg.fill into linalg.generic.
struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {};
} // namespace

/// Add to `patterns` the patterns that fold reshape ops into linalg ops by
/// expanding the iteration space; `controlFoldingReshapes` presumably decides
/// per candidate whether folding is allowed — confirm.
///
/// NOTE(review): the body is empty in this view, so as shown no patterns are
/// added — presumably elided.
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {}

/// Add to `patterns` the patterns that fold reshape ops into linalg ops by
/// collapsing the iteration space; `controlFoldingReshapes` presumably decides
/// per candidate whether folding is allowed — confirm.
///
/// NOTE(review): the body is empty in this view, so as shown no patterns are
/// added — presumably elided.
void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {}

/// Add to `patterns` the elementwise-op fusion patterns of this file;
/// `controlElementwiseOpsFusion` presumably decides per candidate operand
/// whether fusion is allowed — confirm.
///
/// NOTE(review): the body is empty in this view, so as shown no patterns are
/// added — presumably elided.
void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpsFusion) {}

/// Add to `patterns` the patterns that collapse iteration dimensions of linalg
/// ops; `controlCollapseDimensions` presumably returns which dimensions of a
/// given op may be collapsed — confirm.
///
/// NOTE(review): the body is empty in this view, so as shown no patterns are
/// added — presumably elided.
void mlir::linalg::populateCollapseDimensions(
    RewritePatternSet &patterns,
    const GetCollapsableDimensionsFn &controlCollapseDimensions) {}

//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//

namespace {

/// Pass that fuses generic ops on tensors. Used only for testing.
// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
// patterns added here heavily depends on the cost function used. Having an
// opinionated pass of this form is not recommended. Deprecate this pass in
// favor of test passes that check the functionality of each of the patterns
// added here individually.
//
// NOTE(review): the struct body is empty in this view (no `runOnOperation`
// override), so as shown the pass does nothing itself — presumably elided.
struct LinalgElementwiseOpFusionPass
    : public impl::LinalgElementwiseOpFusionPassBase<
          LinalgElementwiseOpFusionPass> {};

} // namespace