#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/Casting.h"
#include <optional>
usingnamespacemlir;
usingnamespacemlir::vector;
static VectorType reducedVectorTypeBack(VectorType tp) { … }
static Value insertOne(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter, Location loc,
Value val1, Value val2, Type llvmType, int64_t rank,
int64_t pos) { … }
static Value extractOne(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter, Location loc,
Value val, Type llvmType, int64_t rank, int64_t pos) { … }
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) { … }
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
const LLVMTypeConverter &converter) { … }
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter &typeConverter,
MemRefType memRefType, Value llvmMemref, Value base,
Value index, VectorType vectorType) { … }
static Value getAsLLVMValue(OpBuilder &builder, Location loc,
OpFoldResult foldResult) { … }
namespace {
VectorScaleOpConversion;
class VectorBitCastOpConversion
: public ConvertOpToLLVMPattern<vector::BitCastOp> { … };
class VectorMatmulOpConversion
: public ConvertOpToLLVMPattern<vector::MatmulOp> { … };
class VectorFlatTransposeOpConversion
: public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { … };
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
vector::LoadOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) { … }
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
vector::MaskedLoadOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) { … }
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
vector::StoreOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) { … }
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
vector::MaskedStoreOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) { … }
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { … };
class VectorGatherOpConversion
: public ConvertOpToLLVMPattern<vector::GatherOp> { … };
class VectorScatterOpConversion
: public ConvertOpToLLVMPattern<vector::ScatterOp> { … };
class VectorExpandLoadOpConversion
: public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { … };
class VectorCompressStoreOpConversion
: public ConvertOpToLLVMPattern<vector::CompressStoreOp> { … };
class ReductionNeutralZero { … };
class ReductionNeutralIntOne { … };
class ReductionNeutralFPOne { … };
class ReductionNeutralAllOnes { … };
class ReductionNeutralSIntMin { … };
class ReductionNeutralUIntMin { … };
class ReductionNeutralSIntMax { … };
class ReductionNeutralUIntMax { … };
class ReductionNeutralFPMin { … };
class ReductionNeutralFPMax { … };
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) { … }
static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) { … }
static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) { … }
static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) { … }
static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) { … }
static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) { … }
static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) { … }
static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) { … }
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) { … }
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) { … }
template <class ReductionNeutral>
static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value accumulator) { … }
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) { … }
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator) { … }
template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { … }
namespace {
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> { … };
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> { … };
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> { … };
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> { … };
}
template <class LLVMRedIntrinOp>
static Value createFPReductionComparisonOpLowering(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) { … }
class MaskNeutralFMaximum { … };
class MaskNeutralFMinimum { … };
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMaximum,
const llvm::fltSemantics &floatSemantics) { … }
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMinimum,
const llvm::fltSemantics &floatSemantics) { … }
template <typename MaskNeutral>
static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Type vectorType) { … }
template <class LLVMRedIntrinOp, class MaskNeutral>
static Value
lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand, Value accumulator,
Value mask, LLVM::FastmathFlagsAttr fmf) { … }
template <class LLVMRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
Type llvmType, Value vectorOperand,
Value accumulator, LLVM::FastmathFlagsAttr fmf) { … }
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
Location loc, Type llvmType,
Value vectorOperand, Value accumulator) { … }
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, Value mask) { … }
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
Value vectorOperand, Value accumulator, Value mask) { … }
class VectorReductionOpConversion
: public ConvertOpToLLVMPattern<vector::ReductionOp> { … };
template <class MaskedOp>
class VectorMaskOpConversionBase
: public ConvertOpToLLVMPattern<vector::MaskOp> { … };
class MaskedReductionOpConversion
: public VectorMaskOpConversionBase<vector::ReductionOp> { … };
class VectorShuffleOpConversion
: public ConvertOpToLLVMPattern<vector::ShuffleOp> { … };
class VectorExtractElementOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractElementOp> { … };
class VectorExtractOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractOp> { … };
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { … };
class VectorInsertElementOpConversion
: public ConvertOpToLLVMPattern<vector::InsertElementOp> { … };
class VectorInsertOpConversion
: public ConvertOpToLLVMPattern<vector::InsertOp> { … };
struct VectorScalableInsertOpLowering
: public ConvertOpToLLVMPattern<vector::ScalableInsertOp> { … };
struct VectorScalableExtractOpLowering
: public ConvertOpToLLVMPattern<vector::ScalableExtractOp> { … };
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { … };
static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) { … }
class VectorTypeCastOpConversion
: public ConvertOpToLLVMPattern<vector::TypeCastOp> { … };
class VectorCreateMaskOpRewritePattern
: public OpRewritePattern<vector::CreateMaskOp> { … };
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { … };
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { … };
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { … };
struct VectorInterleaveOpLowering
: public ConvertOpToLLVMPattern<vector::InterleaveOp> { … };
struct VectorDeinterleaveOpLowering
: public ConvertOpToLLVMPattern<vector::DeinterleaveOp> { … };
struct VectorFromElementsLowering
: public ConvertOpToLLVMPattern<vector::FromElementsOp> { … };
struct VectorStepOpLowering : public ConvertOpToLLVMPattern<vector::StepOp> { … };
}
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions, bool force32BitVectorIndices) { … }
void mlir::populateVectorToLLVMMatrixConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) { … }