llvm/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and a pass that remove the use of unit-extent
// dimensions to specify broadcasting, in favor of a more canonical
// representation of the computation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

// Instantiate the TableGen-generated pass definition selected by the
// GEN_PASS_DEF_* macro inside namespace `mlir`. This presumably provides
// `impl::LinalgFoldUnitExtentDimsPassBase`, which the pass struct further down
// derives from.
namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE

usingnamespacemlir;
usingnamespacemlir::linalg;

namespace {
/// Pattern to move init operands to ins when all the loops are parallel and
/// the block argument corresponding to the init is used in the region. This is
/// a fix-up for when unit reduction dimensions are all folded away: in that
/// context the op becomes an elementwise generic op. E.g., it converts
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0 : tensor<1x?x1x1xf32>)
///    outs(%1 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %out: f32):
///    %3 = arith.addf %in, %out : f32
///    linalg.yield %3 : f32
///  } -> tensor<1x1xf32>
///
///  into
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = tensor.empty() : tensor<1x1xf32>
///  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///   ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
///   outs(%2 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %in_0: f32, %out: f32):
///    %4 = arith.addf %in, %in_0 : f32
///    linalg.yield %4 : f32
///  } -> tensor<1x1xf32>
struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {};
} // namespace

//===---------------------------------------------------------------------===//
// Drop loops that are unit-extents within Linalg operations.
//===---------------------------------------------------------------------===//

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Update the index accesses of linalg operations having index semantics.
///
/// \param genericOp the op whose region's index accesses are updated.
/// \param unitDims set of dimension positions of `genericOp` treated as
///        unit-extent (presumably loop positions being dropped — TODO confirm).
/// \param rewriter rewriter through which all IR changes are made.
static void
replaceUnitDimIndexOps(GenericOp genericOp,
                       const llvm::SmallDenseSet<unsigned> &unitDims,
                       RewriterBase &rewriter) {}

/// Expand the given `result` so that its type matches the type of `origDest`.
/// The `reassociation` is used when `rankReductionStrategy` is set to
/// `RankReductionStrategy::ReassociativeReshape`.
///
/// \param rewriter rewriter used to create any expansion ops.
/// \param loc location attached to the ops created.
/// \param result the (rank-reduced) value to expand.
/// \param origDest value whose type the expanded result must match.
/// \returns the expanded value.
static Value
expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
            ArrayRef<ReassociationIndices> reassociation,
            ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {}

/// Collapse the given `operand` so that its shape matches `targetShape`.
/// The `reassociation` is used when `rankReductionStrategy` is set to
/// `RankReductionStrategy::ReassociativeReshape`.
///
/// \param rewriter rewriter used to create any collapsing ops.
/// \param loc location attached to the ops created.
/// \param operand the value to collapse.
/// \param targetShape the shape the collapsed value must have.
/// \returns the collapsed value.
static Value collapseValue(
    RewriterBase &rewriter, Location loc, Value operand,
    ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
    ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {}

/// Compute the modified metadata for an operand of an operation
/// whose unit dims are being dropped. Return the new indexing map
/// to use, the shape of the operand in the replacement op
/// and the `reassociation` to use to go from the original operand shape
/// to the modified operand shape.
struct UnitExtentReplacementInfo {};
/// Compute the `UnitExtentReplacementInfo` for `opOperand` of `genericOp`.
///
/// \param context MLIR context used to build the new affine map(s).
/// \param genericOp the op whose unit dims are being dropped.
/// \param opOperand the operand of `genericOp` to compute metadata for.
/// \param oldDimsToNewDimsMap map from original loop dims to the loop dims of
///        the replacement op (presumably only for retained dims — TODO
///        confirm).
/// \param dimReplacements affine expressions substituted for the original
///        dims (presumably indexed by original dim position — TODO confirm).
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
    MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
    ArrayRef<AffineExpr> dimReplacements) {}

/// Drop unit-extent dimensions from `genericOp`, as configured by `options`.
///
/// \param rewriter rewriter through which all IR changes are made.
/// \param genericOp the op to rewrite.
/// \param options controls which dims are dropped and the rank-reduction
///        strategy used.
/// \returns a `DropUnitDimsResult` on success, or failure when the rewrite
///          does not apply.
FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                     const ControlDropUnitDims &options) {}

namespace {
/// Rewrite pattern over `GenericOp` that drops unit-extent loop dims
/// (presumably a thin wrapper around `linalg::dropUnitDims` — TODO confirm).
struct DropUnitDims : public OpRewritePattern<GenericOp> {};
} // namespace

//===---------------------------------------------------------------------===//
// Drop dimensions that are unit-extents within tensor operations.
//===---------------------------------------------------------------------===//

namespace {
/// Rewrite pattern that drops unit-extent dimensions from `tensor.pad` ops
/// (presumably only dims that are not padded — TODO confirm).
struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {};
} // namespace

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {};

/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp,
/// selected through the `InsertOpTy` template parameter.
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting, using reshapes to adjust operand/result ranks
/// (i.e. the `RankReductionStrategy::ReassociativeReshape` strategy).
///
/// \param patterns set the patterns are added to.
/// \param options configuration forwarded to the added patterns.
static void
populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
                                              ControlDropUnitDims &options) {}

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting, using slices rather than reshapes to adjust ranks
/// (presumably the non-reshape `RankReductionStrategy` — TODO confirm).
///
/// \param patterns set the patterns are added to.
/// \param options configuration forwarded to the added patterns.
static void
populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
                                            ControlDropUnitDims &options) {}

/// Public entry point: add the unit-extent-dim folding patterns configured by
/// `options` to `patterns`. Presumably dispatches to the reshape- or
/// slice-based populate function according to the rank-reduction strategy in
/// `options` — TODO confirm.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {}

/// Public entry point: add the pattern that moves init operands to inputs
/// (presumably `MoveInitOperandsToInput`) to `patterns`.
void mlir::linalg::populateMoveInitOperandsToInputPattern(
    RewritePatternSet &patterns) {}

namespace {
/// Pass that removes unit-extent dims within generic ops.
/// Derives from the TableGen-generated pass base class.
struct LinalgFoldUnitExtentDimsPass
    : public impl::LinalgFoldUnitExtentDimsPassBase<
          LinalgFoldUnitExtentDimsPass> {};

} // namespace

namespace {

/// Returns reassociation indices for collapsing/expanding a
/// tensor of rank `rank` at position `pos`.
///
/// \param rank rank of the uncollapsed tensor.
/// \param pos dimension that is merged into a neighbor by the reshape.
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {}

/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
/// If `pos < 0`, then don't collapse.
///
/// \param rewriter rewriter used to create the collapsing op.
/// \param val the shaped value to collapse.
/// \param pos the dimension to remove (presumably required to be a
///        singleton dim, per the function name — TODO confirm).
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
                                    int64_t pos) {}

/// Base class for all rank reduction patterns for contraction ops
/// with unit dimensions. All patterns should convert one named op
/// to another named op. Intended to reduce only one iteration space dim
/// at a time.
/// Reducing multiple dims will happen with recursive application of
/// pattern rewrites.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {};

/// Patterns for unbatching batched contraction ops.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {};

/// Patterns for reducing non-batch dimensions.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {};

} // namespace

/// Public entry point: add the contraction-op rank-reducing patterns (the
/// `RankReduceToUnBatched` / `RankReduceMatmul` families, presumably
/// instantiated for specific named op pairs — TODO confirm) to `patterns`.
void mlir::linalg::populateContractionOpRankReducingPatterns(
    RewritePatternSet &patterns) {}