llvm/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

//===- VectorEmulateNarrowType.cpp - Narrow type emulation ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>

usingnamespacemlir;

// Debug-logging helper macros. In this view every one of them is defined with
// an empty replacement list, so any use expands to nothing.
// NOTE(review): LLVM's debug facilities normally expect DEBUG_TYPE to be a
// string literal -- confirm that the definitions were not lost.
#define DEBUG_TYPE
#define DBGS()
#define DBGSNL()
#define LDBG(X)

/// Returns a compressed mask. An element of the compressed mask is set if any
/// of the `scale` consecutive elements of the original mask that it covers is
/// set. E.g., if `scale` equals 2, the following mask:
///
///   %mask = [1, 1, 1, 0, 0, 0]
///
/// will return the following new compressed mask:
///
///   %mask = [1, 1, 0]
///
/// NOTE(review): `origElements` is presumably the number of elements of the
/// original (uncompressed) mask -- confirm against the callers. The body is
/// empty in this view, so the function falls off the end without returning
/// the declared `FailureOr<Operation *>`.
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                  Location loc, Value mask,
                                                  int origElements, int scale) {}

namespace {

//===----------------------------------------------------------------------===//
// ConvertVectorStore
//===----------------------------------------------------------------------===//

/// Conversion pattern anchored on `vector::StoreOp`.
/// The struct body is empty in this view.
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {};

//===----------------------------------------------------------------------===//
// ConvertVectorMaskedStore
//===----------------------------------------------------------------------===//

/// Conversion pattern anchored on `vector::MaskedStoreOp`.
/// The struct body is empty in this view.
struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {};

//===----------------------------------------------------------------------===//
// ConvertVectorLoad
//===----------------------------------------------------------------------===//

/// Conversion pattern anchored on `vector::LoadOp`.
/// The struct body is empty in this view.
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {};

//===----------------------------------------------------------------------===//
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//

/// Conversion pattern anchored on `vector::MaskedLoadOp`.
/// The struct body is empty in this view.
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {};

//===----------------------------------------------------------------------===//
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//

/// Conversion pattern anchored on `vector::TransferReadOp`.
/// The struct body is empty in this view.
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// RewriteBitCastOfTruncI
//===----------------------------------------------------------------------===//

namespace {

/// Helper struct to keep track of the provenance of a contiguous set of bits
/// in a source vector.
/// NOTE(review): no members are visible in this view.
struct SourceElementRange {};

/// A list of `SourceElementRange` entries, built on top of `SmallVector`.
struct SourceElementRangeList : public SmallVector<SourceElementRange> {};

/// Helper struct to enumerate the source elements and bit ranges that are
/// involved in a bitcast operation.
/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
/// any 1-D vector shape and any source/target bitwidths.
/// This creates and holds a mapping of the form:
/// [dstVectorElementJ] ==
///    [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
///   [0] = {0, [0, 8)}
///   [1] = {0, [8, 16)}
///   [2] = {0, [16, 24)}
/// and `vector.bitcast ... : vector<3xi10> to vector<2xi15>` is decomposed as:
///   [0] = {0, [0, 10)}, {1, [0, 5)}
///   [1] = {1, [5, 10)}, {2, [0, 10)}
struct BitCastBitsEnumerator {};

/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
/// BitCastBitsEnumerator encodes for each element of the target vector the
/// provenance of the bits in the source vector. We can "transpose" this
/// information to build a sequence of shuffles and bitwise ops that will
/// produce the desired result.
///
/// Consider the following motivating example:
/// ```
///   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
/// ```
///
/// BitCastBitsEnumerator contains the following information:
/// ```
///   { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
///   { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
///   { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
///   { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
///   { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
///   { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
///   { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
///   {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
///   {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
///   {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
///   {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
///   {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
///   {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
///   {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
///   {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
///   {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
///   {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
///   {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
///   {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
///   {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
/// ```
///
/// In the above, each row represents one target vector element and each
/// column represents one bit contribution from a source vector element.
/// The algorithm creates vector.shuffle operations (in this case there are 3
/// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
/// algorithm populates the bits as follows:
/// ```
///     src bits 0 ...
/// 1st shuffle |xxxxx   |xx      |...
/// 2nd shuffle |     xxx|  xxxxx |...
/// 3rd shuffle |        |       x|...
/// ```
///
/// The algorithm proceeds as follows:
///   1. for each vector.shuffle, collect the source vectors that participate in
///     this shuffle. One source vector per target element of the resulting
///     vector.shuffle. If there is no source element contributing bits for the
///     current vector.shuffle, take 0 (i.e. row 0 in the above example has only
///     2 columns).
///   2. represent the bitrange in the source vector as a mask. If there is no
///     source element contributing bits for the current vector.shuffle, take 0.
///   3. shift right by the proper amount to align the source bitrange at
///     position 0. This is exactly the low end of the bitrange. For instance,
///     the first element of row 1 is `{ 1: b@[3..5) lshl: 0}` and one needs to
///     shift right by 3 to get the bits contributed by the source element #1
///     into position 0.
///   4. shift left by the proper amount to align to the desired position in
///     the result element vector.  For instance, the contribution of the second
///     source element for the first row needs to be shifted by `5` to form the
///     first i8 result element.
///
/// Eventually, we end up building the sequence
/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
struct BitCastRewriter {};

} // namespace

/// Stream-insertion overload for a vector of `SourceElementRangeList`.
/// Marked `[[maybe_unused]]`, so no warning is issued if it has no callers.
/// NOTE(review): the body is empty in this view although the declared return
/// type is `raw_ostream &` -- confirm the implementation.
[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {}

/// Constructs the enumerator from the source and target vector types of the
/// bitcast.
/// NOTE(review): both the member-initializer list after `:` and the body are
/// empty in this view -- confirm against the full definition.
BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    :{}

/// Constructs the rewriter from the source and target vector types of the
/// bitcast.
/// NOTE(review): both the member-initializer list after `:` and the body are
/// empty in this view -- confirm against the full definition.
BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    :{}

/// Verify that the precondition type meets the common preconditions for any
/// conversion.
///
/// NOTE(review): `rewriter` and `op` are presumably used to report a match
/// failure on `op` -- confirm. The body is empty in this view, so the function
/// falls off the end without returning the declared `LogicalResult`.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {}

/// Member-function precondition check of `BitCastRewriter`; it takes the same
/// parameters as the free function `commonConversionPrecondition`.
/// NOTE(review): presumably forwards to `commonConversionPrecondition` after
/// rewriter-specific checks -- confirm. The body is empty in this view, so the
/// function falls off the end without returning the declared `LogicalResult`.
LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {}

/// Verify that source and destination element types meet the precondition for
/// the supported aligned conversion cases. Alignment means that either the
/// source element type is a multiple of the destination element type or the
/// other way around.
///
/// NOTE: This method assumes that common conversion preconditions are met.
///
/// NOTE(review): the body is empty in this view, so the function falls off the
/// end without returning the declared `LogicalResult`.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType srcType,
                                                   VectorType dstType,
                                                   Operation *op) {}

/// Returns a list of `BitCastRewriter::Metadata` entries for the given
/// shuffled element type.
/// NOTE(review): presumably one entry per shuffle step described in the
/// `BitCastRewriter` documentation -- confirm. The body is empty in this view.
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {}

/// One step of the generic rewrite: takes the initial value, the running
/// result and one `Metadata` entry, and returns a `Value`.
/// NOTE(review): presumably emits one
/// `(shuffle -> and -> shiftright -> shiftleft -> or)` round as described in
/// the `BitCastRewriter` documentation and returns the updated running result
/// -- confirm. The body is empty in this view.
Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {}

/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
///
/// NOTE(review): `srcValue` is presumably a vector of i4 and the result the
/// sign-extended vector of i8 -- confirm. The body is empty in this view, so
/// the function falls off the end without returning the declared `Value`.
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
                                    Value srcValue) {}

/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
///
/// NOTE(review): `srcValue` is presumably a vector of i4 and the result the
/// zero-extended vector of i8 -- confirm. The body is empty in this view, so
/// the function falls off the end without returning the declared `Value`.
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
                                      Value srcValue) {}

/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
/// ops that take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
///
/// NOTE(review): `srcValue` is presumably a vector of i8 and the result the
/// truncated vector of i4 -- confirm. The body is empty in this view, so the
/// function falls off the end without returning the declared `Value`.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                Value srcValue) {}

namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
///
/// The pattern is anchored on `vector::BitCastOp`; the struct body is empty in
/// this view.
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {};
} // namespace

//===----------------------------------------------------------------------===//
// RewriteExtOfBitCast
//===----------------------------------------------------------------------===//

namespace {
/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
/// take advantage of high-level information to avoid leaving LLVM to scramble
/// with peephole optimizations.
///
/// The pattern is anchored on the `ExtOpType` template parameter; the struct
/// body is empty in this view.
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {};

/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations. Templated to choose between
/// signed and unsigned conversions.
///
/// For example (signed):
///    arith.extsi %in : vector<8xi4> to vector<8xi32>
///      is rewritten as
///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
///        %1 = arith.shli %0, 4 : vector<4xi8>
///        %2 = arith.shrsi %1, 4 : vector<4xi8>
///        %3 = arith.shrsi %0, 4 : vector<4xi8>
///        %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
///        %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
///
///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
///      is rewritten as
///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
///        %1 = arith.shli %0, 4 : vector<4xi8>
///        %2 = arith.shrsi %1, 4 : vector<4xi8>
///        %3 = arith.shrsi %0, 4 : vector<4xi8>
///        %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
///
/// Example (unsigned):
///    arith.extui %in : vector<8xi4> to vector<8xi32>
///      is rewritten as
///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
///        %1 = arith.andi %0, 15 : vector<4xi8>
///        %2 = arith.shrui %0, 4 : vector<4xi8>
///        %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
///
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {};

/// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
///
/// For example:
///    arith.trunci %in : vector<8xi32> to vector<8xi4>
///      is rewritten as
///
///        %cst = arith.constant dense<15> : vector<4xi8>
///        %cst_0 = arith.constant dense<4> : vector<4xi8>
///        %0, %1 = vector.deinterleave %in : vector<8xi8> -> vector<4xi8>
///        %2 = arith.andi %0, %cst : vector<4xi8>
///        %3 = arith.shli %1, %cst_0 : vector<4xi8>
///        %4 = arith.ori %2, %3 : vector<4xi8>
///        %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
///
/// NOTE(review): the deinterleave operand is shown as `%in` with type
/// vector<8xi8>; presumably it stands for `%in` after truncation to i8 --
/// confirm.
///
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {};

/// Rewrite a sub-byte vector transpose into a sequence of instructions that
/// perform the transpose on wider (byte) element types.
/// For example:
///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
///
///   is rewritten as:
///
///   %0 = arith.extsi %a : vector<8x16xi4> to vector<8x16xi8>
///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
///   %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
///
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {};

} // namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

/// Public entry point taking the narrow-type emulation type converter and a
/// pattern set.
/// NOTE(review): presumably registers the `ConvertVector*` conversion patterns
/// defined in this file into `patterns` -- confirm. The body is empty in this
/// view, so as shown it adds nothing.
void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {}

/// Public entry point taking a pattern set and a pattern benefit.
/// NOTE(review): presumably registers the bitcast/ext/trunc rewrite patterns
/// defined in this file into `patterns` with the given `benefit` -- confirm.
/// The body is empty in this view, so as shown it adds nothing.
void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {}

/// Public entry point taking a pattern set and a pattern benefit.
/// NOTE(review): presumably registers `RewriteVectorTranspose` into `patterns`
/// with the given `benefit` -- confirm. The body is empty in this view, so as
/// shown it adds nothing.
void vector::populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {}