//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++ //-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include <cstdint> usingnamespacemlir; #define DEBUG_TYPE … #define DBGS() … #define DBGSNL() … #define LDBG(X) … /// Returns a compressed mask. The mask value is set only if any mask is present /// in the scale range. 
E.g., if `scale` equals 2, the following mask: /// /// %mask = [1, 1, 1, 0, 0, 0] /// /// will return the following new compressed mask: /// /// %mask = [1, 1, 0] static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter, Location loc, Value mask, int origElements, int scale) { … } namespace { //===----------------------------------------------------------------------===// // ConvertVectorStore //===----------------------------------------------------------------------===// struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { … }; //===----------------------------------------------------------------------===// // ConvertVectorMaskedStore //===----------------------------------------------------------------------===// struct ConvertVectorMaskedStore final : OpConversionPattern<vector::MaskedStoreOp> { … }; //===----------------------------------------------------------------------===// // ConvertVectorLoad //===----------------------------------------------------------------------===// struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> { … }; //===----------------------------------------------------------------------===// // ConvertVectorMaskedLoad //===----------------------------------------------------------------------===// struct ConvertVectorMaskedLoad final : OpConversionPattern<vector::MaskedLoadOp> { … }; //===----------------------------------------------------------------------===// // ConvertVectorTransferRead //===----------------------------------------------------------------------===// struct ConvertVectorTransferRead final : OpConversionPattern<vector::TransferReadOp> { … }; } // end anonymous namespace //===----------------------------------------------------------------------===// // RewriteBitCastOfTruncI //===----------------------------------------------------------------------===// namespace { /// Helper struct to keep track of the provenance of a contiguous set of bits /// in a source 
vector. struct SourceElementRange { … }; struct SourceElementRangeList : public SmallVector<SourceElementRange> { … }; /// Helper struct to enumerate the source elements and bit ranges that are /// involved in a bitcast operation. /// This allows rewriting a vector.bitcast into shuffles and bitwise ops for /// any 1-D vector shape and any source/target bitwidths. /// This creates and holds a mapping of the form: /// [dstVectorElementJ] == /// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ] /// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as: /// [0] = {0, [0-8)} /// [1] = {0, [8-16)} /// [2] = {0, [16-24)} /// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as: /// [0] = {0, [0, 10)}, {1, [0, 5)} /// [1] = {1, [5, 10)}, {2, [0, 10)} struct BitCastBitsEnumerator { … }; /// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take /// advantage of high-level information to avoid leaving LLVM to scramble with /// peephole optimizations. /// BitCastBitsEnumerator encodes for each element of the target vector the /// provenance of the bits in the source vector. We can "transpose" this /// information to build a sequence of shuffles and bitwise ops that will /// produce the desired result. 
// /// Consider the following motivating example: /// ``` /// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8> /// ``` // /// BitCastBitsEnumerator contains the following information: /// ``` /// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5} /// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7} /// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4} /// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6} /// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3} /// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5} /// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7} /// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4} /// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6} /// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3} /// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5} /// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7} /// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4} /// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6} /// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3} /// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5} /// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7} /// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4} /// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6} /// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3} /// ``` /// /// In the above, each row represents one target vector element and each /// column represents one bit contribution from a source vector element. /// The algorithm creates vector.shuffle operations (in this case there are 3 /// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The /// algorithm populates the bits as follows: /// ``` /// src bits 0 ... /// 1st shuffle |xxxxx |xx |... /// 2nd shuffle | xxx| xxxxx |... /// 3rd shuffle | | x|... /// ``` // /// The algorithm proceeds as follows: /// 1. for each vector.shuffle, collect the source vectors that participate in /// this shuffle. 
One source vector per target element of the resulting /// vector.shuffle. If there is no source element contributing bits for the /// current vector.shuffle, take 0 (i.e. row 0 in the above example has only /// 2 columns). /// 2. represent the bitrange in the source vector as a mask. If there is no /// source element contributing bits for the current vector.shuffle, take 0. /// 3. shift right by the proper amount to align the source bitrange at /// position 0. This is exactly the low end of the bitrange. For instance, /// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to /// shift right by 3 to get the bits contributed by the source element #1 /// into position 0. /// 4. shift left by the proper amount to align to the desired position in /// the result element vector. For instance, the contribution of the second /// source element for the first row needs to be shifted by `5` to form the /// first i8 result element. /// /// Eventually, we end up building the sequence /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the /// bits extracted from the source vector (i.e. the `shuffle -> and` part). struct BitCastRewriter { … }; } // namespace [[maybe_unused]] static raw_ostream & operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) { … } BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType, VectorType targetVectorType) : … { … } BitCastRewriter::BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType) : … { … } /// Verify that the precondition type meets the common preconditions for any /// conversion. 
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter, VectorType preconditionType, Operation *op) { … } LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter, VectorType preconditionType, Operation *op) { … } /// Verify that source and destination element types meet the precondition for /// the supported aligned conversion cases. Alignment means that either the /// source element type is a multiple of the destination element type or the other /// way around. /// /// NOTE: This method assumes that common conversion preconditions are met. static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, VectorType srcType, VectorType dstType, Operation *op) { … } SmallVector<BitCastRewriter::Metadata> BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) { … } Value BitCastRewriter::genericRewriteStep( PatternRewriter &rewriter, Location loc, Value initialValue, Value runningResult, const BitCastRewriter::Metadata &metadata) { … } /// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and /// bitwise ops that take advantage of high-level information to avoid leaving /// LLVM to scramble with peephole optimizations. static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc, Value srcValue) { … } /// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and /// bitwise ops that take advantage of high-level information to avoid leaving /// LLVM to scramble with peephole optimizations. static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc, Value srcValue) { … } /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise /// ops that take advantage of high-level information to avoid leaving LLVM to /// scramble with peephole optimizations. 
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc, Value srcValue) { … } namespace { /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take /// advantage of high-level information to avoid leaving LLVM to scramble with /// peephole optimizations. struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> { … }; } // namespace //===----------------------------------------------------------------------===// // RewriteExtOfBitCast //===----------------------------------------------------------------------===// namespace { /// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that /// take advantage of high-level information to avoid leaving LLVM to scramble /// with peephole optimizations. template <typename ExtOpType> struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> { … }; /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and /// bitwise ops that take advantage of high-level information to avoid leaving /// LLVM to scramble with peephole optimizations. Templated to choose between /// signed and unsigned conversions. 
/// /// For example (signed): /// arith.extsi %in : vector<8xi4> to vector<8xi32> /// is rewritten as /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8> /// %1 = arith.shli %0, 4 : vector<4xi8> /// %2 = arith.shrsi %1, 4 : vector<4xi8> /// %3 = arith.shrsi %0, 4 : vector<4xi8> /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8> /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32> /// /// arith.sitofp %in : vector<8xi4> to vector<8xf32> /// is rewritten as /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8> /// %1 = arith.shli %0, 4 : vector<4xi8> /// %2 = arith.shrsi %1, 4 : vector<4xi8> /// %3 = arith.shrsi %0, 4 : vector<4xi8> /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8> /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32> /// /// Example (unsigned): /// arith.extui %in : vector<8xi4> to vector<8xi32> /// is rewritten as /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8> /// %1 = arith.andi %0, 15 : vector<4xi8> /// %2 = arith.shrui %0, 4 : vector<4xi8> /// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8> /// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32> /// template <typename ConversionOpType, bool isSigned> struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> { … }; /// Rewrite the i8 -> i4 part of any truncation into a deinterleave and /// bitwise ops that take advantage of high-level information to avoid leaving /// LLVM to scramble with peephole optimizations. 
/// /// For example: /// arith.trunci %in : vector<8xi32> to vector<8xi4> /// is rewritten as /// /// %cst = arith.constant dense<15> : vector<4xi8> /// %cst_0 = arith.constant dense<4> : vector<4xi8> /// %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8> /// %2 = arith.andi %0, %cst : vector<4xi8> /// %3 = arith.shli %1, %cst_0 : vector<4xi8> /// %4 = arith.ori %2, %3 : vector<4xi8> /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4> /// struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> { … }; /// Rewrite a sub-byte vector transpose into a sequence of instructions that /// perform the transpose on wider (byte) element types. /// For example: /// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4> /// /// is rewritten as: /// /// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8> /// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8> /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4> /// struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Public Interface Definition //===----------------------------------------------------------------------===// void vector::populateVectorNarrowTypeEmulationPatterns( const arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns) { … } void vector::populateVectorNarrowTypeRewritePatterns( RewritePatternSet &patterns, PatternBenefit benefit) { … } void vector::populateVectorTransposeNarrowTypeRewritePatterns( RewritePatternSet &patterns, PatternBenefit benefit) { … }