//===- VectorUtils.h - Vector Utilities -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {

// Forward declarations.
class AffineMap;
class Block;
class Location;
class OpBuilder;
class Operation;
class ShapedType;
class Value;
class VectorType;
class VectorTransferOpInterface;

namespace affine {
class AffineApplyOp;
class AffineForOp;
} // namespace affine

namespace vector {

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);

/// Returns two dims that are greater than one if the transposition is applied
/// on a 2D slice. Otherwise, returns a failure.
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);

/// Return true if `vectorType` is a contiguous slice of `memrefType`.
///
/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
/// checked (the other dims are not relevant). Note that for `vectorType` to be
/// a contiguous slice of `memrefType`, the trailing dims of the latter have
/// to be contiguous - this is checked by looking at the corresponding strides.
///
/// There might be some restriction on the leading dim of `VectorType`:
///
/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
///         of `memrefType` then the leading dim of `vectorType` can be
///         arbitrary.
///
///        Ex. 1.1 contiguous slice, perfect match
///          vector<4x3x2xi32> from memref<5x4x3x2xi32>
///        Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
///          vector<2x3x2xi32> from memref<5x4x3x2xi32>
///
/// Case 2. If an "internal" dim of `vectorType` does not match the
///         corresponding trailing dim in `memrefType` then the remaining
///         leading dims of `vectorType` have to be 1 (the first non-matching
///         dim can be arbitrary).
///
///        Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
///          vector<2x2x2xi32> from memref<5x4x3x2xi32>
///        Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
///          vector<1x2x2xi32> from memref<5x4x3x2xi32>
///        Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
///          vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
///        Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
///          vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
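// A minimal usage sketch for the contiguity check above. The builder `b` and
// the concrete shapes are assumptions made purely for illustration (this
// mirrors Ex. 1.2):
//
// ```c++
// // vector<2x3x2xi32> is a contiguous slice of memref<5x4x3x2xi32>.
// Type i32 = b.getI32Type();
// auto memrefType = MemRefType::get({5, 4, 3, 2}, i32);
// auto vectorType = VectorType::get({2, 3, 2}, i32);
// bool contiguous = isContiguousSlice(memrefType, vectorType); // true
// ```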
/// Returns an iterator for all positions in the leading dimensions of `vType`
/// up to the `targetRank`. If any leading dimension before the `targetRank` is
/// scalable (so cannot be unrolled), it will return an iterator for positions
/// up to the first scalable dimension.
///
/// If no leading dimensions can be unrolled, an empty optional will be
/// returned.
///
/// Examples:
///
///   For vType = vector<2x3x4> and targetRank = 1
///
///   The resulting iterator will yield:
///     [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
///
///   For vType = vector<3x[4]x5> and targetRank = 0
///
///   The scalable dimension blocks unrolling so the iterator yields only:
///     [0], [1], [2]
///
std::optional<StaticTileOffsetRange> createUnrollIterator(VectorType vType,
                                                          int64_t targetRank = 1);

/// Returns a functor (int64_t -> Value) which returns a constant vscale
/// multiple.
///
/// Example:
/// ```c++
///   auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
///   auto c4Vscale = createVscaleMultiple(4); // 4 * vector.vscale
/// ```
inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter,
                                      Location loc) {
  …
}

/// Returns a range over the dims (size and scalability) of a VectorType.
inline auto getDims(VectorType vType) {
  …
}

/// A wrapper for getMixedSizes for vector.transfer_read and
/// vector.transfer_write Ops (for source and destination, respectively).
///
/// Tensor and MemRef types implement their own, very similar version of
/// getMixedSizes. This method will call the appropriate version (depending on
/// `hasTensorSemantics`). It will also automatically extract the operand for
/// which to call it on (source for "read" and destination for "write" ops).
SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
                                            Operation *xfer,
                                            RewriterBase &rewriter);

/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
/// masked (i.e. inside a `vector.mask` Op region). In particular:
///   1. Matches `SourceOp` operation, Op.
///   2.1. If Op is masked, retrieves the masking Op, maskOp, and updates the
///        insertion point to avoid inserting new ops into the `vector.mask` Op
///        region (which only allows one Op).
///   2.2. If Op is not masked, this step is skipped.
///   3. Invokes `matchAndRewriteMaskableOp` on Op and, optionally, on maskOp
///      if one was found in step 2.1.
///
/// This wrapper frees patterns from re-implementing the logic to update the
/// insertion point when a maskable Op is masked. Such patterns are still
/// responsible for providing an updated ("rewritten") version of:
///   a. the source Op when a mask _is not_ present,
///   b. the source Op and the masking Op when a mask _is_ present.
/// To use this pattern, implement `matchAndRewriteMaskableOp`. Note that
/// the return value will depend on the case above.
template <class SourceOp>
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
  …
};
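// An illustrative sketch of a client pattern. The op choice, the pattern name,
// the `lowerContraction` helper and the hook signature shown here are
// assumptions made for the example, not definitive parts of this header:
//
// ```c++
// struct SketchContractionLowering
//     : public MaskableOpRewritePattern<vector::ContractionOp> {
//   using MaskableOpRewritePattern<vector::ContractionOp>::
//       MaskableOpRewritePattern;
//
//   // Returns the value that replaces the source Op (case a) or the masked
//   // computation (case b). The base pattern has already moved the insertion
//   // point out of the `vector.mask` region when `maskingOp` is non-null.
//   FailureOr<Value>
//   matchAndRewriteMaskableOp(vector::ContractionOp op,
//                             MaskingOpInterface maskingOp,
//                             PatternRewriter &rewriter) const override {
//     Value replacement = lowerContraction(rewriter, op);
//     return replacement;
//   }
// };
// ```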
/// Returns true if the input Vector type can be linearized.
///
/// Linearization is meant in the sense of flattening vectors, e.g.:
///   * vector<NxMxKxi32> -> vector<N*M*Kxi32>
/// In this sense, Vectors that are either:
///   * already linearized, or
///   * contain more than one scalable dimension,
/// are not linearizable.
bool isLinearizableVector(VectorType type);

/// Creates a TransferReadOp from `source` with static shape `readShape`. If
/// the vector type for the read is not the same as the type of `source`, then
/// a mask is created on the read, provided use of a mask is requested or the
/// bounds along a dimension differ.
///
/// If `useInBoundsInsteadOfMasking` is false, the in-bounds values are set
/// based on the dimensions of the source and destination tensors, and that is
/// what determines whether masking is applied.
///
/// Note that the internal `vector::TransferReadOp` always reads at index zero
/// for each dimension of the passed-in tensor.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                             ArrayRef<int64_t> readShape, Value padValue,
                             bool useInBoundsInsteadOfMasking);

/// Returns success if `inputVectorSizes` is a valid masking configuration for
/// the given `shape`, i.e., it meets:
///   1. The number of elements in both arrays is equal.
///   2. `inputVectorSizes` does not have dynamic dimensions.
///   3. All the values in `inputVectorSizes` are greater than or equal to the
///      static sizes in `shape`.
LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> inputVectorSizes);

} // namespace vector

/// Constructs a permutation map of invariant memref indices to vector
/// dimension.
///
/// If no index is found to be invariant, 0 is added to the permutation_map and
/// corresponds to a vector broadcast along that dimension.
///
/// The implementation uses the knowledge of the mapping of loops to
/// vector dimension. `loopToVectorDim` carries this information as a map with:
///   - keys representing "vectorized enclosing loops";
///   - values representing the corresponding vector dimension.
/// Note that loopToVectorDim is a whole function map from which only enclosing
/// loop information is extracted.
///
/// Prerequisites: `indices` belong to a vectorizable load or store operation
/// (i.e. at most one invariant index along each AffineForOp of
/// `loopToVectorDim`). `insertPoint` is the insertion point for the vectorized
/// load or store operation.
///
/// Example 1:
/// The following MLIR snippet:
///
/// ```mlir
///    affine.for %i3 = 0 to %0 {
///      affine.for %i4 = 0 to %1 {
///        affine.for %i5 = 0 to %2 {
///          %a5 = load %arg0[%i4, %i5, %i3] : memref<?x?x?xf32>
///    }}}
/// ```
///
/// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into:
///
/// ```mlir
///    affine.for %i3 = 0 to %0 step 32 {
///      affine.for %i4 = 0 to %1 {
///        affine.for %i5 = 0 to %2 step 256 {
///          %4 = vector.transfer_read %arg0, %i4, %i5, %i3
///               {permutation_map: (d0, d1, d2) -> (d2, d1)} :
///               (memref<?x?x?xf32>, index, index) -> vector<32x256xf32>
///    }}}
/// ```
///
/// Meaning that vector.transfer_read will be responsible for reading the
/// slice `%arg0[%i4, %i5:%i5+256, %i3:%i3+32]` into vector<32x256xf32>.
///
/// Example 2:
/// The following MLIR snippet:
///
/// ```mlir
///    %cst0 = arith.constant 0 : index
///    affine.for %i0 = 0 to %0 {
///      %a0 = load %arg0[%cst0, %cst0] : memref<?x?xf32>
///    }
/// ```
///
/// may vectorize with {permutation_map: (d0) -> (0)} into:
///
/// ```mlir
///    affine.for %i0 = 0 to %0 step 128 {
///      %3 = vector.transfer_read %arg0, %c0_0, %c0_0
///           {permutation_map: (d0, d1) -> (0)} :
///           (memref<?x?xf32>, index, index) -> vector<128xf32>
///    }
/// ```
///
/// Meaning that vector.transfer_read will be responsible for reading the slice
/// `%arg0[%c0, %c0]` into vector<128xf32>, which needs a 1-D vector broadcast.
///
AffineMap
makePermutationMap(Block *insertPoint, ArrayRef<Value> indices,
                   const DenseMap<Operation *, unsigned> &loopToVectorDim);
AffineMap
makePermutationMap(Operation *insertPoint, ArrayRef<Value> indices,
                   const DenseMap<Operation *, unsigned> &loopToVectorDim);
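// A minimal sketch combining two of the helpers declared above in the `vector`
// namespace, vector::isValidMaskedInputVector and vector::createReadOrMaskedRead.
// `builder`, `loc`, `source` and the concrete vector sizes are assumptions made
// for illustration:
//
// ```c++
// // Read `source` (a statically shaped tensor) with the requested
// // per-dimension vector sizes, masking where the sizes differ.
// SmallVector<int64_t> vectorSizes = {8, 16};
// auto sourceType = cast<ShapedType>(source.getType());
// if (failed(vector::isValidMaskedInputVector(sourceType.getShape(),
//                                             vectorSizes)))
//   return failure();
// Value padValue = builder.create<arith::ConstantOp>(
//     loc, builder.getZeroAttr(sourceType.getElementType()));
// Value read = vector::createReadOrMaskedRead(
//     builder, loc, source, vectorSizes, padValue,
//     /*useInBoundsInsteadOfMasking=*/false);
// ```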
namespace matcher {

/// Matches vector.transfer_read, vector.transfer_write and ops that return a
/// vector type that is a multiple of the sub-vector type. This allows passing
/// over other smaller vector types in the function and avoids interfering with
/// operations on those.
/// This is a first approximation; it can easily be extended in the future.
///
/// TODO: this could all be much simpler if we added a bit to a vector type to
/// mark that it is a strict super-vector, but that still does not warrant
/// adding even one extra bit in the IR for now.
bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);

} // namespace matcher
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_