llvm/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/Casting.h"
#include <optional>

usingnamespacemlir;
usingnamespacemlir::vector;

// Helper to reduce vector type by *all* but one rank at back.
// Takes the vector type `tp` and returns a `VectorType`.
// NOTE(review): per the summary the result presumably keeps only the trailing
// dimension of `tp` - confirm against the implementation.
static VectorType reducedVectorTypeBack(VectorType tp) {}

// Helper that picks the proper sequence for inserting.
// Takes two values (`val1`, `val2`), the LLVM type `llvmType`, a `rank` and a
// position `pos`, and returns the resulting `Value`.
// NOTE(review): presumably `val2` is inserted into `val1` at `pos`, and `rank`
// decides which insertion sequence is emitted - confirm against the
// implementation.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       const LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {}

// Helper that picks the proper sequence for extracting.
// Takes the source value `val`, the LLVM type `llvmType`, a `rank` and a
// position `pos`, and returns the extracted `Value`.
// NOTE(review): presumably `rank` decides which extraction sequence is
// emitted, mirroring `insertOne` - confirm against the implementation.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        const LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {}

// Helper that returns data layout alignment of a memref.
// The alignment is reported through the out-parameter `align`.
// NOTE(review): the `LogicalResult` presumably signals whether the alignment
// could be determined - confirm.
// NOTE(review): unlike the neighbouring file-level helpers this function is
// not declared `static`; confirm whether external linkage is intended.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {}

// Check if the last stride is non-unit and has a valid memory space.
// NOTE(review): the wording above is ambiguous. Given the function name,
// success presumably means the memref *is* supported (unit last stride and a
// valid memory space), with a non-unit last stride being the rejected case -
// confirm against the implementation and reword the summary accordingly.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                           const LLVMTypeConverter &converter) {}

// Add an index vector component to a base pointer.
// Takes the original memref type (`memRefType`), the converted memref
// (`llvmMemref`), a `base` value, an `index` value and the vector type
// (`vectorType`), and returns the resulting `Value`.
// NOTE(review): presumably `index` is a vector of offsets applied to `base`,
// yielding a vector of pointers - confirm against the implementation.
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                            const LLVMTypeConverter &typeConverter,
                            MemRefType memRefType, Value llvmMemref, Value base,
                            Value index, VectorType vectorType) {}

/// Convert `foldResult` into a Value. Integer attribute is converted to
/// an LLVM constant op.
/// New operations, if any, are created with `builder` at location `loc`.
/// NOTE(review): presumably a `foldResult` that already holds a Value is
/// returned unchanged - confirm against the implementation.
static Value getAsLLVMValue(OpBuilder &builder, Location loc,
                            OpFoldResult foldResult) {}

namespace {

/// Trivial Vector to LLVM conversions
// NOTE(review): the statement below is a bare identifier, not a declaration.
// It looks like a truncated `using VectorScaleOpConversion = ...;` alias;
// restore the full alias before building.
VectorScaleOpConversion;

/// Conversion pattern for a vector.bitcast.
/// Matches `vector::BitCastOp` through the `ConvertOpToLLVMPattern` base.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
/// Matches `vector::MatmulOp` through the `ConvertOpToLLVMPattern` base.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
/// Matches `vector::FlatTransposeOp` through the `ConvertOpToLLVMPattern`
/// base.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
/// This overload handles `vector::LoadOp`. `ptr` and `align` are the address
/// and alignment to use for the replacement op; `vectorTy` is the vector type
/// being accessed.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {}

/// Overload of `replaceLoadOrStoreOp` for `vector::MaskedLoadOp`.
/// NOTE(review): the mask and pass-through operands are presumably taken from
/// `adaptor` - confirm against the implementation.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {}

/// Overload of `replaceLoadOrStoreOp` for `vector::StoreOp`.
/// NOTE(review): the value to store is presumably taken from `adaptor` -
/// confirm against the implementation.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {}

/// Overload of `replaceLoadOrStoreOp` for `vector::MaskedStoreOp`.
/// NOTE(review): the mask and the value to store are presumably taken from
/// `adaptor` - confirm against the implementation.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
/// `LoadOrStoreOp` is the concrete op class the pattern is instantiated for.
/// NOTE(review): presumably dispatches to the matching `replaceLoadOrStoreOp`
/// overload - confirm against the implementation.
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {};

/// Conversion pattern for a vector.gather.
/// Matches `vector::GatherOp` through the `ConvertOpToLLVMPattern` base.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {};

/// Conversion pattern for a vector.scatter.
/// Matches `vector::ScatterOp` through the `ConvertOpToLLVMPattern` base.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {};

/// Conversion pattern for a vector.expandload.
/// Matches `vector::ExpandLoadOp` through the `ConvertOpToLLVMPattern` base.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {};

/// Conversion pattern for a vector.compressstore.
/// Matches `vector::CompressStoreOp` through the `ConvertOpToLLVMPattern`
/// base.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {};

/// Reduction neutral classes for overloading.
/// Each class is an empty tag type; passing an instance as the first argument
/// of `createReductionNeutralValue` selects the matching overload.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};

/// Create the reduction neutral zero value.
/// Overload selected by passing a `ReductionNeutralZero` tag as `neutral`.
/// NOTE(review): `llvmType` is presumably the type of the created constant -
/// confirm against the implementation.
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {}

/// Create the reduction neutral integer one value.
/// Overload selected by passing a `ReductionNeutralIntOne` tag as `neutral`.
/// NOTE(review): `llvmType` is presumably the integer type of the created
/// constant - confirm against the implementation.
static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {}

/// Create the reduction neutral fp one value.
/// Overload selected by passing a `ReductionNeutralFPOne` tag as `neutral`.
/// NOTE(review): `llvmType` is presumably the floating-point type of the
/// created constant - confirm against the implementation.
static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {}

/// Create the reduction neutral all-ones value.
/// Overload selected by passing a `ReductionNeutralAllOnes` tag as `neutral`.
/// NOTE(review): `llvmType` is presumably the integer type of the created
/// constant - confirm against the implementation.
static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {}

/// Create the reduction neutral signed int minimum value.
/// Overload selected by passing a `ReductionNeutralSIntMin` tag as `neutral`.
/// NOTE(review): `llvmType` is presumably the integer type of the created
/// constant - confirm against the implementation.
static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {}

/// Create the reduction neutral unsigned int minimum value.
/// Overload selected by passing a `ReductionNeutralUIntMin` tag as `neutral`.
/// NOTE(review): `llvmType` is presumably the integer type of the created
/// constant - confirm against the implementation.
static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {}

/// Create the reduction neutral signed int maximum value.
/// Overload selected by passing a `ReductionNeutralSIntMax` tag as `neutral`.
/// NOTE(review): `llvmType` is presumably the integer type of the created
/// constant - confirm against the implementation.
static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {}

/// Create the reduction neutral unsigned int maximum value.
/// Overload selected by passing a `ReductionNeutralUIntMax` tag as `neutral`.
/// NOTE(review): `llvmType` is presumably the integer type of the created
/// constant - confirm against the implementation.
static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {}

/// Create the reduction neutral fp minimum value.
/// Overload selected by passing a `ReductionNeutralFPMin` tag as `neutral`.
/// NOTE(review): `llvmType` is presumably the floating-point type of the
/// created constant - confirm against the implementation.
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {}

/// Create the reduction neutral fp maximum value.
/// Overload selected by passing a `ReductionNeutralFPMax` tag as `neutral`.
/// NOTE(review): `llvmType` is presumably the floating-point type of the
/// created constant - confirm against the implementation.
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {}

/// Returns `accumulator` if it has a valid value. Otherwise, creates and
/// returns a new accumulator value using `ReductionNeutral`.
/// `ReductionNeutral` is one of the reduction-neutral tag classes; `llvmType`
/// is the type requested for a newly created accumulator.
template <class ReductionNeutral>
static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
                                    Location loc, Type llvmType,
                                    Value accumulator) {}

/// Creates a value with the 1-D vector shape provided in `llvmType`.
/// This is used as effective vector length by some intrinsics supporting
/// dynamic vector lengths at runtime.
/// NOTE(review): how scalable vector types are handled is not visible here -
/// confirm against the implementation.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
                                     Location loc, Type llvmType) {}

/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add, mul, etc. `LLVMRedIntrinOp` is the LLVM vector
/// reduction intrinsic to use and `ScalarOp` is the scalar operation used to
/// add the accumulation value if non-null.
/// `vectorOperand` is the vector being reduced and `accumulator` the optional
/// accumulation value; the reduced `Value` is returned.
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator) {}

/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`. `LLVMRedIntrinOp` is the LLVM
/// vector reduction intrinsic to use and `predicate` is the predicate to use
/// to compare+combine the accumulator value if non-null.
/// `vectorOperand` is the vector being reduced and `accumulator` the optional
/// accumulation value; the reduced `Value` is returned.
template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {}

// NOTE(review): this anonymous namespace is nested inside the enclosing
// anonymous namespace and is therefore redundant.
namespace {
/// Trait keyed on an LLVM vector reduction intrinsic op `Source`. The primary
/// template is only declared; specializations are provided for the four
/// floating-point min/max reduction intrinsics below.
/// NOTE(review): judging by the name, each specialization presumably names the
/// scalar op corresponding to the vector reduction - confirm.
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {};
} // namespace

/// Floating-point counterpart of `createIntegerReductionComparisonOpLowering`:
/// helper for reductions that use the LLVM reduction intrinsic
/// `LLVMRedIntrinOp`, taking fastmath flags `fmf` instead of an integer
/// comparison predicate.
/// NOTE(review): presumably combines a non-null `accumulator` with the reduced
/// value using the scalar op given by `VectorToScalarMapper<LLVMRedIntrinOp>`
/// - confirm against the implementation.
template <class LLVMRedIntrinOp>
static Value createFPReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {}

/// Mask neutral classes for overloading.
/// Each class is an empty tag type; passing an instance as the first argument
/// of `getMaskNeutralValue` selects the matching overload.
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};

/// Get the mask neutral floating point maximum value.
/// Overload selected by a `MaskNeutralFMaximum` tag; the value is returned as
/// an `APFloat` in the given `floatSemantics`.
/// NOTE(review): the concrete value chosen is not visible here - confirm
/// against the implementation.
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMaximum,
                    const llvm::fltSemantics &floatSemantics) {}
/// Get the mask neutral floating point minimum value.
/// Overload selected by a `MaskNeutralFMinimum` tag; the value is returned as
/// an `APFloat` in the given `floatSemantics`.
/// NOTE(review): the concrete value chosen is not visible here - confirm
/// against the implementation.
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMinimum,
                    const llvm::fltSemantics &floatSemantics) {}

/// Create the mask neutral floating point MLIR vector constant.
/// `MaskNeutral` is one of the mask-neutral tag classes.
/// NOTE(review): presumably `llvmType` is the element type and `vectorType`
/// the type of the resulting vector constant - confirm against the
/// implementation.
template <typename MaskNeutral>
static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
                                    Location loc, Type llvmType,
                                    Type vectorType) {}

/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
/// `fmaximum`/`fminimum`.
/// More information: https://github.com/llvm/llvm-project/issues/64940
/// `LLVMRedIntrinOp` is the non-masked reduction intrinsic to use and
/// `MaskNeutral` the tag class selecting the mask neutral value.
/// NOTE(review): presumably masked-off lanes of `vectorOperand` are replaced
/// by the mask neutral value before the regular reduction is applied - confirm
/// against the implementation.
template <class LLVMRedIntrinOp, class MaskNeutral>
static Value
lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
                                Location loc, Type llvmType,
                                Value vectorOperand, Value accumulator,
                                Value mask, LLVM::FastmathFlagsAttr fmf) {}

/// Lowers a reduction to the LLVM intrinsic `LLVMRedIntrinOp`, which (per the
/// function name) requires a start value. `ReductionNeutral` is the tag class
/// describing the neutral element of the reduction and `fmf` carries the
/// fastmath flags.
/// NOTE(review): presumably the start value is `accumulator` when it is
/// non-null and the neutral value otherwise - confirm against the
/// implementation.
template <class LLVMRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
                             Type llvmType, Value vectorOperand,
                             Value accumulator, LLVM::FastmathFlagsAttr fmf) {}

/// Overloaded methods to lower a *predicated* reduction to an LLVM intrinsic
/// that requires a start value. This start value format spans across fp
/// reductions without mask and all the masked reduction intrinsics.
/// This overload takes no explicit mask operand.
/// NOTE(review): presumably an all-true mask is used in that case - confirm
/// against the implementation.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                       Location loc, Type llvmType,
                                       Value vectorOperand, Value accumulator) {}

/// Overload of `lowerPredicatedReductionWithStartValue` that takes an explicit
/// `mask` operand. `LLVMVPRedIntrinOp` is the predicated reduction intrinsic
/// to use and `ReductionNeutral` the tag class describing the neutral element.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {}

/// Overload of `lowerPredicatedReductionWithStartValue` parameterized with
/// separate intrinsic / neutral-element pairs for integer and floating-point
/// reductions.
/// NOTE(review): presumably the pair is chosen from the element type of
/// `llvmType` - confirm against the implementation.
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {}

/// Conversion pattern for all vector reductions.
/// Matches `vector::ReductionOp` through the `ConvertOpToLLVMPattern` base.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {};

/// Base class to convert a `vector.mask` operation while matching traits
/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
/// method performs a second match against the maskable operation `MaskedOp`.
/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
/// implemented by the concrete conversion classes. This method can match
/// against specific traits of the `vector.mask` and the maskable operation. It
/// must replace the `vector.mask` operation.
///
/// `MaskedOp` is the op class of the maskable operation expected inside the
/// `vector.mask` region.
template <class MaskedOp>
class VectorMaskOpConversionBase
    : public ConvertOpToLLVMPattern<vector::MaskOp> {};

/// Conversion pattern for a `vector.mask` whose masked operation is a
/// `vector::ReductionOp` (instantiates `VectorMaskOpConversionBase` for it).
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {};

/// Conversion pattern matching `vector::ShuffleOp`.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {};

/// Conversion pattern matching `vector::ExtractElementOp`.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {};

/// Conversion pattern matching `vector::ExtractOp`.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of rank n >= 2.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd(%va, %va, %va) :
///    (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {};

/// Conversion pattern matching `vector::InsertElementOp`.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {};

/// Conversion pattern matching `vector::InsertOp`.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {};

/// Lower vector.scalable.insert ops to LLVM vector.insert.
/// Matches `vector::ScalableInsertOp` through the `ConvertOpToLLVMPattern`
/// base.
struct VectorScalableInsertOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {};

/// Lower vector.scalable.extract ops to LLVM vector.extract.
/// Matches `vector::ScalableExtractOp` through the `ConvertOpToLLVMPattern`
/// base.
struct VectorScalableExtractOpLowering
    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example (schematic IR; each leading index of the operands is extracted,
/// combined with a lower-rank FMA and inserted into the result):
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
/// NOTE(review): presumably `std::nullopt` is returned otherwise - confirm
/// against the implementation.
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {}

/// Conversion pattern matching `vector::TypeCastOp`.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {};

/// Rewrite pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
/// Unlike most patterns in this file this is a plain `OpRewritePattern`, not a
/// `ConvertOpToLLVMPattern`.
class VectorCreateMaskOpRewritePattern
    : public OpRewritePattern<vector::CreateMaskOp> {};

/// Conversion pattern matching `vector::PrintOp`.
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats to 0-d and 1-d vector result types are lowered by
/// this pattern.
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats to vector result types of rank 2 or higher are
/// lowered by VectorSplatNdOpLowering; the 0-d and 1-d cases are handled by
/// VectorSplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {};

/// Conversion pattern for a `vector.interleave`.
/// This supports fixed-sized vectors and scalable vectors.
/// Matches `vector::InterleaveOp` through the `ConvertOpToLLVMPattern` base.
struct VectorInterleaveOpLowering
    : public ConvertOpToLLVMPattern<vector::InterleaveOp> {};

/// Conversion pattern for a `vector.deinterleave`.
/// This supports fixed-sized vectors and scalable vectors.
/// Matches `vector::DeinterleaveOp` through the `ConvertOpToLLVMPattern` base.
struct VectorDeinterleaveOpLowering
    : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {};

/// Conversion pattern for a `vector.from_elements`.
/// Matches `vector::FromElementsOp` through the `ConvertOpToLLVMPattern` base.
struct VectorFromElementsLowering
    : public ConvertOpToLLVMPattern<vector::FromElementsOp> {};

/// Conversion pattern for a `vector.step`.
/// Matches `vector::StepOp` through the `ConvertOpToLLVMPattern` base.
struct VectorStepOpLowering : public ConvertOpToLLVMPattern<vector::StepOp> {};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
/// `converter` is the LLVM type converter the patterns are built with and
/// `patterns` the set they are added to.
/// NOTE(review): presumably `reassociateFPReductions` allows floating-point
/// reductions to be reassociated and `force32BitVectorIndices` selects 32-bit
/// indices for index vectors - confirm against the implementation.
void mlir::populateVectorToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool force32BitVectorIndices) {}

/// Populate the given list with the matrix-related patterns that convert from
/// Vector to LLVM.
/// NOTE(review): presumably these are `VectorMatmulOpConversion` and
/// `VectorFlatTransposeOpConversion` - confirm against the implementation.
void mlir::populateVectorToLLVMMatrixConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {}