//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements target-independent rewrites and utilities to lower the // 'vector.contract' operation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/VectorInterfaces.h" #define DEBUG_TYPE … usingnamespacemlir; usingnamespacemlir::vector; //===----------------------------------------------------------------------===// // Helper functions //===----------------------------------------------------------------------===// // Helper to find an index in an affine map. static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) { … } // Helper to construct iterator types with one index removed. 
static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes, int64_t index) { … } // Helper to construct an affine map with one index removed. static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter) { … } // Helper method to possibly drop a dimension in a load. // TODO static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter) { … } // Helper method to possibly drop a dimension in a store. // TODO static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter) { … } /// Helper to create the arithmetic operation associated with a kind of contraction. static std::optional<Value> createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask = Value()) { … } /// Return the positions of the reductions in the given map. static SmallVector<int64_t> getReductionIndex(AffineMap map, ArrayAttr iteratorTypes) { … } /// Look for a given dimension in an affine map and return its position. Return /// std::nullopt if the dimension is not in the map results. static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) { … } /// Creates an AddIOp if `isInt` is true, otherwise an arith::AddFOp, using /// operands `x` and `y`. static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter) { … } /// Creates a MulIOp if `isInt` is true, otherwise a MulFOp, using /// operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter) { … } namespace { /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul /// semantics to: /// ``` /// %flattened_a = vector.shape_cast %a /// %flattened_b = vector.shape_cast %b /// %flattened_d = vector.matmul %flattened_a, %flattened_b /// %d = vector.shape_cast %flattened_d /// %e = add %c, %d /// ``` /// `vector.matmul` later lowers to `llvm.matrix.multiply`. /// /// This only kicks in when VectorTransformsOptions is set to `Matmul`. /// vector.transpose operations are inserted if the vector.contract op is not a /// row-major matrix multiply. Scalable vectors are not supported. class ContractionOpToMatmulOpLowering : public vector::MaskableOpRewritePattern<vector::ContractionOp> { … }; /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul /// semantics to a reduction_size-unrolled sequence: /// ``` /// %at = vector.transpose %a, [1, 0] /// %bRow0 = vector.extract %b[0] /// %atRow0 = vector.extract %at[0] /// %c0 = vector.outerproduct %atRow0, %bRow0, %c /// ... /// %bRowK = vector.extract %b[K] /// %atRowK = vector.extract %at[K] /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 /// ``` /// /// This only kicks in when VectorTransformsOptions is set to OuterProduct but /// otherwise supports any layout permutation of the matrix-multiply (the /// row-major sequence above is only the canonical case). class ContractionOpToOuterProductOpLowering : public MaskableOpRewritePattern<vector::ContractionOp> { … }; /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul /// semantics to an output-size-unrolled sequence: /// ``` /// %out = arith.constant ... : vector<MxNxelt_type> /// %bt = vector.transpose %b, [1, 0] /// %aRow0 = vector.extract %a[0] /// %btRow0 = vector.extract %bt[0] /// %c00 = vector.reduce %aRow0, %btRow0 /// %out00 = vector.insert %c00, %out[0, 0] /// ...
/// %aRowLast = vector.extract %a[M-1] /// %btRowLast = vector.extract %bt[N-1] /// %cLastLast = vector.reduce %aRowLast, %btRowLast /// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1] /// ``` /// /// This only kicks in when VectorTransformsOptions is set to Dot and /// the vector.contract op is a row-major matmul or matvec. class ContractionOpToDotLowering : public MaskableOpRewritePattern<vector::ContractionOp> { … }; /// Progressive lowering of ContractionOp. /// /// One: /// %x = vector.contract with at least one free/batch dimension /// is replaced by: /// %a = vector.contract with one less free/batch dimension /// %b = vector.contract with one less free/batch dimension /// .. /// %x = combine %a %b .. /// until a pure contraction is reached (no free/batch dimensions), /// which is replaced by a dot-product. /// /// This only kicks in when either VectorTransformsOptions is set /// to Dot or when other contraction patterns fail. class ContractionOpLowering : public MaskableOpRewritePattern<vector::ContractionOp> { … }; /// Generate a vector implementation for matmat, matvec and tmatvec. /// This unrolls outer-products along the reduction dimension. struct UnrolledOuterProductGenerator : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> { … }; /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul /// semantics to a reduction_size-unrolled sequence: /// ``` /// %at = vector.transpose %a, [1, 0] /// %bRow0 = vector.extract %b[0] /// %atRow0 = vector.extract %at[0] /// %c0 = vector.outerproduct %atRow0, %bRow0, %c /// ... /// %bRowK = vector.extract %b[K] /// %atRowK = vector.extract %at[K] /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 /// ``` /// /// This only kicks in when VectorTransformsOptions is set to OuterProduct but /// otherwise supports any layout permutation of the matrix-multiply.
FailureOr<Value> ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp( vector::ContractionOp op, MaskingOpInterface maskOp, PatternRewriter &rewriter) const { … } FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp( vector::ContractionOp op, MaskingOpInterface maskOp, PatternRewriter &rewriter) const { … } /// Lower vector.contract with all unit-size reduction dimensions to /// elementwise ops when possible. struct ContractOpToElementwise : public MaskableOpRewritePattern<vector::ContractionOp> { … }; /// Progressive lowering of ContractionOp. /// One: /// %x = vector.contract with at least one free/batch dimension /// is replaced by: /// %a = vector.contract with one less free/batch dimension /// %b = vector.contract with one less free/batch dimension /// .. /// %x = combine %a %b .. /// until a pure contraction is reached (no free/batch dimensions), /// which is replaced by a dot-product. /// /// This only kicks in when either VectorTransformsOptions is set /// to Dot or when other contraction patterns fail. // // TODO: break down into transpose/reshape/cast ops // when they become available to avoid code dup // TODO: investigate lowering order impact on performance FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp( vector::ContractionOp op, MaskingOpInterface maskOp, PatternRewriter &rewriter) const { … } // Lower one parallel dimension. // Incidentally also tolerates unit-size (hence trivial) reduction dimensions. // TODO: consider reusing existing contract unrolling FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, vector::ContractionOp op, int64_t lhsIndex, int64_t rhsIndex, Value mask) const { … } // Lower one reduction dimension. FailureOr<Value> ContractionOpLowering::lowerReduction( PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const { … } /// Progressive lowering of OuterProductOp.
/// One: /// %x = vector.outerproduct %lhs, %rhs, %acc /// is replaced by: /// %z = zero-result /// %0 = vector.extract %lhs[0] /// %1 = vector.broadcast %0 /// %2 = vector.extract %acc[0] /// %3 = vector.fma %1, %rhs, %2 /// %4 = vector.insert %3, %z[0] /// .. /// %x = vector.insert %.., %..[N-1] /// class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> { … }; /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul /// semantics to: /// ``` /// %mta = maybe_transpose /// %mtb = maybe_transpose /// %flattened_a = vector.shape_cast %mta /// %flattened_b = vector.shape_cast %mtb /// %flattened_d = vector.matmul %flattened_a, %flattened_b /// %mtd = vector.shape_cast %flattened_d /// %d = maybe_untranspose %mtd /// %e = add %c, %d /// ``` /// `vector.matmul` later lowers to `llvm.matrix.multiply`. /// /// This only kicks in when VectorTransformsOptions is set to `Matmul`. /// vector.transpose operations are inserted if the vector.contract op is not a /// row-major matrix multiply. /// /// Scalable vectors are not supported. FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( vector::ContractionOp op, MaskingOpInterface maskOp, PatternRewriter &rew) const { … } } // namespace void mlir::vector::populateVectorContractLoweringPatterns( RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit, bool disableOuterProductLowering) { … } void mlir::vector::populateVectorOuterProductLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { … }