//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE …

using namespace mlir;
using namespace mlir::vector;

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) { … }

// Helper to find an index in an affine map.
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) { … }

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> { … };

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [2]
///     : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> { … };
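/// Illustrative sketch only (hypothetical; the class name and matching details
/// are assumptions for exposition, not this file's implementation): a
/// cancelling shape_cast fold in the spirit of ShapeCastOpFolder above is
/// conceptually this small.
struct ShapeCastCancellationSketch
    : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    // Match shape_cast(shape_cast(%x)) where the outer result type equals the
    // type of %x, and forward %x directly to the users of the outer cast.
    auto producer = op.getSource().getDefiningOp<vector::ShapeCastOp>();
    if (!producer || producer.getSource().getType() != op.getType())
      return failure();
    rewriter.replaceOp(op, producer.getSource());
    return success();
  }
};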
/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> { … };

/// Merges accumulator and result transposes into contract.
///
/// For example:
/// ```mlir
/// %accT = vector.transpose %acc, [0, 2, 1]
///   : vector<2x8x4xf32> to vector<2x4x8xf32>
/// %contract = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %accT
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
/// %0 = vector.transpose %contract, [0, 2, 1]
///   : vector<2x4x8xf32> to vector<2x8x4xf32>
/// ```
/// Becomes:
/// ```mlir
/// %0 = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %acc
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
/// ```
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> { … };

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> { … };
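/// Illustrative sketch only (hypothetical helper, not used by this file):
/// folding a vector.transpose into a vector.contract, as in
/// CombineContractABTranspose above, boils down to rewriting the operand's
/// indexing map through the inverse of the transpose permutation. For the
/// example above, composing the identity LHS map with the inverse of
/// permutation [2, 0, 1] yields (d0, d1, d2) -> (d1, d2, d0).
static AffineMap foldTransposeIntoIndexingMapSketch(AffineMap operandMap,
                                                    ArrayRef<unsigned> perm,
                                                    MLIRContext *ctx) {
  // If the transposed operand was indexed by `operandMap`, the original
  // (un-transposed) operand is indexed by inverse(perm) o operandMap.
  AffineMap permMap = AffineMap::getPermutationMap(perm, ctx);
  return inversePermutation(permMap).compose(operandMap);
}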
/// Reorders cast(broadcast) to broadcast(cast). This brings broadcast ops
/// closer to contraction ops, so that the CombineContractBroadcast pattern can
/// kick in when casting ops sit between these operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> { … };

/// Reorders elementwise(transpose) to transpose(elementwise). This brings
/// transpose ops closer to contraction ops, so that the
/// CombineContractABTranspose pattern can kick in when elementwise ops sit
/// between these operations.
/// Ex:
/// ```
/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
/// %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// Gets converted to:
/// ```
/// %0 = arith.addf %a, %b : vector<4x2xf32>
/// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
/// ```
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> { … };

// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) { … }

// Shuffles vector.bitcast op after vector.extract op.
//
// This transforms IR like:
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : f16 from vector<8xf16>
// Into:
//   %0 = vector.extract %src[1] : f32 from vector<4xf32>
//   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
//   %2 = vector.extract %1[1] : f16 from vector<2xf16>
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> { … };

// Shuffles vector.bitcast op after vector.extract_strided_slice op.
//
// This transforms IR like:
//   %cast = vector.bitcast %arg0 : vector<4xf32> to vector<8xf16>
//   %0 = vector.extract_strided_slice %cast {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
// Into:
//   %0 = vector.extract_strided_slice %arg0 {
//          offsets = [2], sizes = [2], strides = [1]
//        } : vector<4xf32> to vector<2xf32>
//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> { … };

// Shuffles vector.bitcast op before vector.insert op.
//
// This transforms IR like:
//   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
//   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
// Into:
//   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
//   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
//   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
//
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> { … };
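/// Illustrative sketch only (hypothetical helper): the bitcast bubbling
/// patterns above all rescale offsets and sizes by the ratio between the
/// source and result element counts of the bitcast. E.g. for a bitcast from
/// vector<4xf32> to vector<8xf16>, offset 4 in f16 elements becomes offset 2
/// in f32 elements.
static int64_t rescaleBitCastIndexSketch(int64_t index, int64_t srcElems,
                                         int64_t dstElems) {
  assert(srcElems > 0 && dstElems > 0 && "expected non-empty vectors");
  if (dstElems >= srcElems) {
    // The cast produces more (narrower) elements: indices shrink and must be
    // aligned to the ratio.
    int64_t ratio = dstElems / srcElems;
    assert(index % ratio == 0 && "index must be aligned to the ratio");
    return index / ratio;
  }
  // The cast produces fewer (wider) elements: indices grow.
  return index * (srcElems / dstElems);
}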
// Shuffles vector.bitcast op before vector.insert_strided_slice op.
//
// This transforms IR like:
//   %0 = vector.insert_strided_slice %src, %dst {
//          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
//   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
// Into:
//   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
//   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
//   %2 = vector.insert_strided_slice %0, %1 {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> { … };

// Breaks down vector.bitcast op.
//
// This transforms IR like:
//   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
// Into:
//   %cst = vector.splat %c0_f32 : vector<4xf32>
//   %1 = vector.extract_strided_slice %0 {
//          offsets = [0], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
//   %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
//   %4 = vector.insert_strided_slice %2, %cst {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
//   %5 = vector.extract_strided_slice %0 {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
//   %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
//   %7 = vector.insert_strided_slice %6, %4 {
//          offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> { … };

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
/// %r = arith.addi %a, %b : vector<1x4xindex>
/// ```
/// Gets converted to:
/// ```
/// %r = arith.addi %arg1, %arg2 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
///
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
/// ops.
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> { … };

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// If `dim == 0` then the result will be a 0-D vector.
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) { … }

template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { … };

/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> { … };

/// Returns true if all the `i1` elements of `constantOp` are set to `value`.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) { … }
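/// Illustrative sketch only (hypothetical, not this file's implementation):
/// checking that an i1 vector constant is all-true or all-false, as
/// allI1ConstantValuesSetTo above does, typically reduces to inspecting the
/// backing DenseIntElementsAttr, with the splat case answered in O(1).
static bool allI1ValuesEqualSketch(arith::ConstantOp constantOp, bool value) {
  auto dense = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  if (!dense || !dense.getType().getElementType().isInteger(1))
    return false;
  if (dense.isSplat())
    return dense.getSplatValue<bool>() == value;
  return llvm::all_of(dense.getValues<bool>(),
                      [value](bool v) { return v == value; });
}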
/// Folds a select operation between an all-true and all-false vector. For now,
/// only single-element vectors (i.e., vector<1xi1>) are supported. That is:
///
///   %true = arith.constant dense<true> : vector<1xi1>
///   %false = arith.constant dense<false> : vector<1xi1>
///   %result = arith.select %cond, %true, %false : i1, vector<1xi1>
///   =>
///   %result = vector.broadcast %cond : i1 to vector<1xi1>
///
/// InstCombine seems to handle vectors with multiple elements but not the
/// single-element ones.
struct FoldI1Select : public OpRewritePattern<arith::SelectOp> { … };

/// Returns the number of dims that can be folded away from transfer ops. It
/// returns a failure if it cannot determine the number of dims to be folded.
///
/// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
///       `vectorType` is vector<16x16x1x1xf32>
///       (the two innermost dims can be dropped by memref.subview ops)
///
/// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
///       [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
///       (only the innermost unit dim of `srcType` can be dropped)
///
/// Ex 3: returns "0" if `srcType` is memref<512x16x1x1xf32> and
///       `vectorType` is vector<16x16x1x[1]xf32>
///       (the innermost dim in `vectorType` is not a unit dim; it is a
///       "scalable unit" dim)
static FailureOr<size_t> getTransferFoldableInnerUnitDims(MemRefType srcType,
                                                          VectorType vectorType) { … }

/// Drop innermost contiguous unit dimensions from transfer_read operand.
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> { … };

/// Drop innermost contiguous unit dimensions from transfer_write operand.
/// E.g.,
///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
///      {in_bounds = [true, true, true, true, true]}
///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
///
/// will be replaced with
///
///    %subview = memref.subview %arg0
///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
///      to vector<1x16x16xf32>
///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
///      {in_bounds = [true, true, true]}
///      : vector<1x16x16xf32>, memref<1x512x16xf32>
///
/// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> { … };

/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction suitable for MMT (matrix matrix multiplication
/// with the RHS transposed) lowering.
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> { … };

/// Pattern to fold arithmetic extensions on floating point data types into
/// vector contraction operations. linalg.matmul introduces arithmetic
/// extensions on its operands. Please see the MLIR snippet below for more
/// details.
/// ```mlir
///   "linalg.matmul"(%lhs, %rhs, %acc) ({
///     ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
///       %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
///       %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
///       %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
///       %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
///       "linalg.yield"(%acc) : (f32) -> ()
///   })
/// ```
/// This restricts the native usage of mixed-precision NVIDIA Ampere Tensor
/// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and
/// `mma.sync.*.f32.bf16.bf16.f32`. This pattern folds the arithmetic
/// extensions into the vector contraction and enables the use of native
/// mixed-precision Tensor Core instructions.
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> { … };
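/// Illustrative sketch only (hypothetical and simplified to the `arith.extf`
/// case; the real pattern above is templated over the extension op and does
/// more checking): folding the extensions into a contraction is essentially
/// "peel the extension off both operands and rebuild the contract".
struct FoldExtFIntoContractionSketch
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // Restrict to the default <add> combining kind so that rebuilding the op
    // with the default-valued kind attribute stays faithful.
    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    auto lhsExt = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
    auto rhsExt = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
    if (!lhsExt || !rhsExt)
      return failure();
    // Rebuild the contraction directly on the narrower operands; the indexing
    // maps, iterator types, and accumulator are unchanged.
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsExt.getIn(), rhsExt.getIn(), contractOp.getAcc(),
        contractOp.getIndexingMapsAttr(), contractOp.getIteratorTypesAttr());
    return success();
  }
};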
/// Pattern to fold chained reduction to a series of vector additions and a
/// final reduction. This form should require fewer subgroup operations.
///
/// ```mlir
/// %a = vector.reduction <add> %x, %acc
/// %b = vector.reduction <add> %y, %a
/// ==>
/// %a = arith.addf %x, %y
/// %b = vector.reduction <add> %a, %acc
/// ```
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> { … };

// Helper function that drops unit, non-scalable dimensions from a VectorType,
// keeping at least one dimension to avoid generating 0-D vectors. Scalable
// unit dimensions are not dropped: folding such a dimension would require
// "shifting" the scalable flag onto some other fixed-width dim (e.g.
// vector<[1]x4xf32> -> vector<[4]xf32>), which could be implemented in the
// future.
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) { … }

/// For vectors with at least one unit dim, replaces:
///   elementwise(a, b)
/// with:
///   sc_a = shape_cast(a)
///   sc_b = shape_cast(b)
///   res = elementwise(sc_a, sc_b)
///   return shape_cast(res)
/// The newly inserted shape_cast Ops fold (before elementwise Op) and then
/// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
/// required to be rank > 1.
///
/// Ex:
///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
///
/// gets converted to:
///
///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
///  %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
///  %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
///
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
/// `%cast`.
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> { … };

/// A pattern to drop unit dims from vector.transpose.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %transpose = vector.transpose %vector, [3, 0, 1, 2]
///   : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %dropDims = vector.shape_cast %vector
///   : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
/// %transpose = vector.transpose %dropDims, [1, 0]
///   : vector<4x[4]xf32> to vector<[4]x4xf32>
/// %restoreDims = vector.shape_cast %transpose
///   : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
/// ```
struct DropUnitDimsFromTransposeOp final
    : OpRewritePattern<vector::TransposeOp> { … };

/// A pattern to drop unit dims from the iter_args of an scf.for.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
///   ...
///   scf.yield %
/// }
/// ```
///
/// AFTER:
/// ```mlir
/// %drop = vector.shape_cast %init
///   : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
/// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
///   %new_iter = vector.shape_cast %iter
///     : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
///   ...
/// }
/// %res = vector.shape_cast %new_loop
///   : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
/// ```
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> { … };
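/// Illustrative sketch only (hypothetical; mirrors the intent of
/// dropNonScalableUnitDimFromType above): drop fixed-width unit dims, keep
/// scalable unit dims, and never return a 0-D vector.
static VectorType dropUnitDimsSketch(VectorType inVecTy) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableFlags;
  for (auto [size, scalable] :
       llvm::zip_equal(inVecTy.getShape(), inVecTy.getScalableDims())) {
    // Only fixed-width unit dims are droppable; scalable unit dims ([1]) are
    // kept because their scalable flag cannot be moved to another dimension.
    if (size == 1 && !scalable)
      continue;
    newShape.push_back(size);
    newScalableFlags.push_back(scalable);
  }
  // Avoid generating 0-D vectors: keep at least one (unit) dimension.
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableFlags.push_back(false);
  }
  return VectorType::get(newShape, inVecTy.getElementType(), newScalableFlags);
}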
/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
/// `ChainedReduction` pattern.
///
/// ```mlir
/// %a = arith.addf %x, %zero
/// %b = arith.addf %a, %y
/// %c = vector.reduction <add> %b, %acc
/// ==>
/// %b = arith.addf %x, %y
/// %c = vector.reduction <add> %b, %acc
/// ```
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> { … };

/// Example:
/// ```
///   %a = vector.reduction <add> %x : vector<2xf32> into f32
/// ```
/// is transformed into:
/// ```
///   %y = vector.extract %x[0] : f32 from vector<2xf32>
///   %z = vector.extract %x[1] : f32 from vector<2xf32>
///   %a = arith.addf %y, %z : f32
/// ```
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> { … };

/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
/// B)`.
/// Example:
///  %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
///  %lhsT = vector.transpose %lhsBcast, [1, 0]
///    : vector<4x4xi32> to vector<4x4xi32>
///  %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
///  %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
///
/// Becomes:
///
///  %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
///
/// Supports only 1D-to-2D broadcasts. The following cases are not supported:
///  %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
///  %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
///  %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> { … };

} // namespace

void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns) { … }

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) { … }

void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) { … }

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) { … }

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) { … }

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) { … }

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) { … }

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) { … }

void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) { … }

void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) { … }

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) { … }

void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) { … }

void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) { … }

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
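
// Illustrative sketch only (hypothetical, not part of this file): the
// populate* entry points above typically just register their patterns, and
// clients combine them with the greedy rewrite driver. Roughly:
//
//   void mlir::vector::populateShapeCastFoldingPatterns(
//       RewritePatternSet &patterns, PatternBenefit benefit) {
//     patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
//   }
//
// and, in a consuming pass:
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateShapeCastFoldingPatterns(patterns);
//   vector::populateVectorReductionToContractPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));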