llvm/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE

usingnamespacemlir;
usingnamespacemlir::vector;

/// Returns the integer payloads of `arrayAttr`, each converted to `IntType`.
/// Every element of `arrayAttr` is expected to be an IntegerAttr
/// (`getAsRange<IntegerAttr>` casts each element).
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

// Helper to find an index in an affine map: returns the position `i` of the
// result of `map` for which `map.getDimPosition(i) == index`, or std::nullopt
// when no result refers to dimension `index`.
// Assumes every result of `map` is a plain dimension expression
// (getDimPosition requires it).
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    if (map.getDimPosition(i) == index)
      return i;
  }
  return std::nullopt;
}

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
///
/// Example:
///
///  The following MLIR with cancelling ShapeCastOps:
///
///   %0 = source : vector<5x4x2xf32>
///   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
///   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
///   %3 = user %2 : vector<5x4x2xf32>
///
///  Should canonicalize to the following:
///
///   %0 = source : vector<5x4x2xf32>
///   %1 = user %0 : vector<5x4x2xf32>
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [2]
///     : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {};

/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to (the LHS now keeps its untransposed type):
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {};

/// Merges accumulator and result transposes into contract.
///
/// For example:
/// ```mlir
/// %accT = vector.transpose %acc, [0, 2, 1]
///   : vector<2x8x4xf32> to vector<2x4x8xf32>
/// %contract = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %accT
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
/// %0 = vector.transpose %contract, [0, 2, 1]
///   : vector<2x4x8xf32> to vector<2x8x4xf32>
/// ```
/// Becomes:
/// ```mlir
/// %0 = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %acc
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
/// ```
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to (the broadcast dim d0 disappears from the LHS map):
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {};

/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
/// contraction ops closer, which kicks in CombineContractBroadcast pattern when
/// casting ops are around these operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {};

/// Reorders elementwise(transpose) to transpose(elementwise). This makes
/// transpose ops and contraction ops closer, which kicks in
/// CombineContractABTranspose pattern when elementwise ops are between these
/// operations.
/// Ex:
/// ```
/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
/// %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// Gets converted to:
/// ```
/// %0 = arith.addf %a, %b : vector<4x2xf32>
/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// ```
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {};

// Returns the values in `arrayAttr` as an integer vector. Every element is
// expected to be an IntegerAttr (`getAsRange<IntegerAttr>` casts each one).
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}

/// Shuffles vector.bitcast op after vector.extract op.
///
/// This transforms IR like:
///   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
///   %1 = vector.extract %0[3] : f16 from vector<8xf16>
/// Into:
///   %0 = vector.extract %src[1] : f32 from vector<4xf32>
///   %1 = vector.broadcast %0 : f32 to vector<1xf32>
///   %2 = vector.bitcast %1 : vector<1xf32> to vector<2xf16>
///   %3 = vector.extract %2[1] : f16 from vector<2xf16>
/// (The extracted scalar must be wrapped into a 1-element vector before it
/// can be bitcast; the exact op used for that step is not visible here.)
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {};

/// Shuffles vector.bitcast op after vector.extract_strided_slice op.
///
/// This transforms IR like:
///   %cast = vector.bitcast %src : vector<4xf32> to vector<8xf16>
///   %0 = vector.extract_strided_slice %cast {
///          offsets = [4], sizes = [4], strides = [1]
///        } : vector<8xf16> to vector<4xf16>
/// Into:
///   %0 = vector.extract_strided_slice %src {
///          offsets = [2], sizes = [2], strides = [1]
///        } : vector<4xf32> to vector<2xf32>
///   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {};

/// Shuffles vector.bitcast op before vector.insert op.
///
/// This transforms IR like:
///   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
///   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
/// Into:
///   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
///   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
///   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {};

/// Shuffles vector.bitcast op before vector.insert_strided_slice op.
///
/// This transforms IR like:
///   %0 = vector.insert_strided_slice %src, %dst {
///          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
///   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
/// Into:
///   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
///   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
///   %2 = vector.insert_strided_slice %0, %1 {
///          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {};

/// Breaks down a vector.bitcast op into a sequence of smaller bitcasts whose
/// results are inserted one after another into the result vector.
///
/// This transforms IR like:
///   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
/// Into:
///   %cst = vector.splat %c0_f32 : vector<4xf32>
///   %1 = vector.extract_strided_slice %0 {
///          offsets = [0], sizes = [4], strides = [1]
///        } : vector<8xf16> to vector<4xf16>
///   %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
///   %3 = vector.insert_strided_slice %2, %cst {
///          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
///   %4 = vector.extract_strided_slice %0 {
///          offsets = [4], sizes = [4], strides = [1]
///        } : vector<8xf16> to vector<4xf16>
///   %5 = vector.bitcast %4 : vector<4xf16> to vector<2xf32>
///   %6 = vector.insert_strided_slice %5, %3 {
///          offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {};

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
/// %r = arith.addi %a, %b : vector<1x4xindex>
/// ```
/// Gets converted to:
/// ```
/// %0 = arith.addi %arg1, %arg2 : index
/// %r = vector.broadcast %0 : index to vector<1x4xindex>
/// ```
///
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
/// ops.
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {};

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
// where `b` is the bound operand and `o` the optional offset `*off`
// (presumably n == `dim` and the offset term is omitted when `off` is null -
// TODO confirm against the implementation).
//
// If `dim == 0` then the result will be a 0-D vector.
// `force32BitVectorIndices` presumably selects 32-bit instead of 64-bit lane
// indices for the comparison - TODO confirm.
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
//
// NOTE(review): the body is empty. Flowing off the end of a function that
// returns `Value` is undefined behavior; the implementation appears missing.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {}

/// Rewrite pattern templated on the concrete op type `ConcreteOp`. Judging by
/// its name it materializes the mask of a vector transfer op - TODO confirm
/// against the implementation.
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {};

/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {};

/// Returns true if all the `i1` elements of `constantOp` are set to `value`.
/// Only dense integer constants with an i1 element type are recognized; any
/// other constant conservatively yields false.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  // Non-dense (e.g. scalar) or non-i1 constants are not handled here; report
  // "not all set" rather than guessing.
  if (!denseAttr || !denseAttr.getElementType().isInteger(1))
    return false;
  // A non-splat dense attribute necessarily holds both true and false.
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}

/// Folds a select operation between an all-true and all-false vector. For now,
/// only single element vectors (i.e., vector<1xi1>) are supported. That is:
///
///   %true = arith.constant dense<true> : vector<1xi1>
///   %false = arith.constant dense<false> : vector<1xi1>
///   %result = arith.select %cond, %true, %false : i1, vector<1xi1>
///   =>
///   %result = vector.broadcast %cond : i1 to vector<1xi1>
///
/// InstCombine seems to handle vectors with multiple elements but not the
/// single element ones.
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {};

/// Returns the number of dims that can be folded away from transfer ops. It
/// returns a failure if it cannot determine the number of dims to be folded.
///
/// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
/// `vectorType` is vector<16x16x1x1xf32>
/// (the two innermost dims can be dropped by memref.subview ops)
///
/// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
/// [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
/// (only the innermost unit dim of `srcType` can be dropped)
///
/// Ex 3: returns "0" if `srcType` is memref<512x16x1x1xf32> and
/// `vectorType` is vector<16x16x1x[1]xf32>
/// (the innermost dim in `vectorType` is not a unit dim; it's a "scalable
/// unit")
// NOTE(review): the body is empty. Flowing off the end of a function that
// returns `FailureOr<size_t>` is undefined behavior; the implementation
// appears missing.
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {}

/// Drop innermost contiguous unit dimensions from transfer_read operand.
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {};

/// Drop innermost contiguous unit dimensions from transfer_write operand.
/// E.g.,
///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
///      {in_bounds = [true, true, true, true, true]}
///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
///
/// will be replaced with
///
///    %subview = memref.subview %arg0
///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
///      to vector<1x16x16xf32>
///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
///      {in_bounds = [true, true, true]}
///      : vector<1x16x16xf32>, memref<1x512x16xf32>
///
/// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {};

/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction suitable for MMT (matrix matrix multiplication
/// with the RHS transposed) lowering.
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {};

/// Pattern to fold arithmetic extensions on floating point data types into
/// vector contraction operations. linalg.matmul introduces arithmetic
/// extensions on its operands. Please see the MLIR snippet below for more
/// details.
/// ```mlir
///   "linalg.matmul"(%lhs, %rhs, %acc) ({
///      ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
///        %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
///        %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
///        %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
///        %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
///        "linalg.yield"(%acc) : (f32) -> ()
///     })
/// ```
/// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
/// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
/// This pattern folds the arithmetic extensions into the vector contraction and
/// enables the usage of native mixed precision Tensor Core instructions.
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {};

/// Pattern to fold chained reduction to a series of vector additions and a
/// final reduction. This form should require fewer subgroup operations.
///
/// ```mlir
/// %a = vector.reduction <add> %x, %acc
/// %b = vector.reduction <add> %y, %a
///  ==>
/// %a = arith.addf %x, %y
/// %b = vector.reduction <add> %a, %acc
/// ```
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {};

// Helper function dropping unit non-scalable dimensions from a VectorType,
// keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit
// dimensions are not dropped. Folding such dimensions would require "shifting"
// the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
// vector<[4]xf32>). This could be implemented in the future.
//
// Examples:
//   vector<1x[4]x1xf32> -> vector<[4]xf32>
//   vector<[1]x1xf32>   -> vector<[1]xf32>
//   vector<1x1xf32>     -> vector<1xf32>
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecTy.getShape(), inVecTy.getScalableDims())) {
    // Only fixed-width unit dims are dropped; scalable `[1]` dims are kept.
    if (dim == 1 && !isScalable)
      continue;
    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }
  // Every dim was a fixed unit dim: keep one so the result is not 0-D.
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  }
  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}

/// For vectors with at least one unit dim, replaces:
///   elementwise(a, b)
/// with:
///   sc_a = shape_cast(a)
///   sc_b = shape_cast(b)
///   res = elementwise(sc_a, sc_b)
///   return shape_cast(res)
/// The newly inserted shape_cast Ops fold (before elementwise Op) and then
/// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
/// required to be rank > 1.
///
/// Ex:
///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
///
/// gets converted to:
///
///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
///  %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
///  %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
///
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
/// `%cast`.
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {};

/// A pattern to drop unit dims from vector.transpose.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %transpose = vector.transpose %vector, [3, 0, 1, 2]
///    : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %dropDims = vector.shape_cast %vector
///    : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
///  %transpose = vector.transpose %dropDims, [1, 0]
///    : vector<4x[4]xf32> to vector<[4]x4xf32>
///  %restoreDims = vector.shape_cast %transpose
///    : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
///  ```
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct DropUnitDimsFromTransposeOp final
    : OpRewritePattern<vector::TransposeOp> {};

/// A pattern to drop unit dims from the iter_args of an scf.for.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
///    ...
///    scf.yield %next : vector<[4]x1x1x4xf32>
///  }
///  ```
///
///  AFTER:
///  ```mlir
///  %drop = vector.shape_cast %init
///    : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
///  %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
///    %new_iter = vector.shape_cast %iter
///      : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
///    ...
///  }
///  %res = vector.shape_cast %new_loop
///    : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
///  ```
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {};

/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
/// `ChainedReduction` pattern.
///
/// ```mlir
/// %a = arith.addf %x, %zero
/// %b = arith.addf %a, %y
/// %c = vector.reduction <add> %b, %acc
///  ==>
/// %b = arith.addf %x, %y
/// %c = vector.reduction <add> %b, %acc
/// ```
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {};

/// Breaks down a `vector.reduction` into element extractions combined with
/// scalar arithmetic.
///
/// Example:
/// ```
/// %a = vector.reduction <add> %x : vector<2xf32> into f32
/// ```
/// is transformed into:
/// ```
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
/// %a = arith.addf %y, %z : f32
/// ```
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {};

/// Fold `mul(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
/// B)`, where `mul` is the `MulOpType` op (e.g. `arith.mulf` or `arith.muli`).
/// Example:
///  %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
///  %lhsT = vector.transpose %lhsBcast, [1, 0]
///    : vector<4x4xi32> to vector<4x4xi32>
///  %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
///  %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
///
/// Becomes:
///
///  %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
///
/// Supports only 1D-to-2D broadcasts. The following cases are not supported.
/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
// NOTE(review): the body is empty - no constructor and no matchAndRewrite
// override are declared. Confirm the implementation was not dropped.
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {};

} // namespace

// Presumably registers the FoldArithExtIntoContractionOp patterns into
// `patterns` - TODO confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` as
// written.
void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns) {}

// Presumably registers the mask-materialization patterns (e.g.
// VectorCreateMaskOpConversion, MaterializeTransferMask), forwarding
// `force32BitVectorIndices` and `benefit` - TODO confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` and both
// `force32BitVectorIndices` and `benefit` are unused as written.
void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {}

// Presumably registers ShapeCastOpFolder with the given `benefit` - TODO
// confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` as
// written.
void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {}

// Presumably registers the unit-dim dropping patterns (e.g.
// DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
// DropUnitDimsFromScfForOp) - TODO confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` as
// written.
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {}

// Presumably registers the BubbleDown*/BubbleUp* vector.bitcast patterns -
// TODO confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` as
// written.
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {}

// Presumably registers BreakDownVectorBitCast, using `controlFn` to decide
// which bitcasts to break down - TODO confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` and
// `controlFn` is unused as written.
void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {}

// Presumably registers CanonicalizeContractMatmulToMMT, filtered by
// `constraint` - TODO confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` and
// `constraint` is unused as written.
void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {}

// Presumably registers MultiReduceToContract together with the
// CombineContract* / ReorderCastOpsOnBroadcast helpers - TODO confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` as
// written.
void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {}

// Presumably registers DropInnerMostUnitDimsTransferRead and
// DropInnerMostUnitDimsTransferWrite - TODO confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` as
// written.
void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) {}

// Presumably registers the "sink" reorderings ReorderElementwiseOpsOnTranspose
// and ReorderElementwiseOpsOnBroadcast - TODO confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` as
// written.
void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) {}

// Presumably registers ChainedReduction and ReduceRedundantZero - TODO
// confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` as
// written.
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {}

// Presumably registers BreakDownVectorReduction, limited by
// `maxNumElementsToExtract` - TODO confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` and
// `maxNumElementsToExtract` is unused as written.
void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {}

// Presumably registers the FoldArithToVectorOuterProduct patterns - TODO
// confirm.
// NOTE(review): the body is empty, so nothing is added to `patterns` as
// written.
void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {}

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"