#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include <cassert>
#include <cstdint>
#include <numeric>
#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
usingnamespacemlir;
usingnamespacemlir::vector;
//===----------------------------------------------------------------------===//
// Mask classification
//===----------------------------------------------------------------------===//
/// Kind of a mask value, as returned by getMaskFormat().
enum class MaskFormat { … };
/// Classifies the mask value `mask` into a MaskFormat.
static MaskFormat getMaskFormat(Value mask) { … }
//===----------------------------------------------------------------------===//
// General vector utilities
//===----------------------------------------------------------------------===//
/// Region body-builder callback taking (builder, loc).
/// NOTE(review): presumably inserts the region terminator — confirm.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) { … }
/// Returns whether `combiningKind` may be used with elements of `elementType`.
static bool isSupportedCombiningKind(CombiningKind combiningKind,
Type elementType) { … }
/// Computes the transfer permutation map that relates `shapedType` to
/// `vectorType`; by its name, a minor identity map.
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
VectorType vectorType) { … }
/// Predicate over a (transfer_write, transfer_read) pair involving a splat
/// write and a masked read.
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
vector::TransferReadOp read) { … }
/// Read-after-write check between `defWrite` and a later `read`.
bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
vector::TransferReadOp read) { … }
/// Write-after-write check between `write` and an earlier `priorWrite`.
bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
vector::TransferWriteOp priorWrite) { … }
/// Returns whether the index ranges of two transfers are disjoint.
/// `testDynamicValueUsingBounds` opts into bound analysis for dynamic values.
bool mlir::vector::isDisjointTransferIndices(
VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
bool testDynamicValueUsingBounds) { … }
/// Returns whether two transfers access disjoint sets; the flag has the same
/// meaning as in isDisjointTransferIndices.
bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB,
bool testDynamicValueUsingBounds) { … }
/// Advances `position` (in place) within a slice described by `shape` and
/// `offsets`.
/// NOTE(review): failure presumably signals the end of iteration — confirm.
static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
ArrayRef<int64_t> shape,
ArrayRef<int64_t> offsets) { … }
/// Converts `values` into int64_t integers.
/// NOTE(review): presumably requires constant values — confirm.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) { … }
/// Converts `foldResults` into int64_t integers.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) { … }
/// Materializes `foldResults` as SSA values at `loc` using `builder`.
SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> foldResults) { … }
/// Returns the constant multiplier of vscale in `value`, if one is found.
std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) { … }
namespace mlir {
namespace vector {
namespace detail {
/// AttributeStorage subclass; by its name, presumably the storage for
/// bitmask-enum attributes.
struct BitmaskEnumStorage : public AttributeStorage { … };
} // namespace detail
} // namespace vector
} // namespace mlir
namespace {
/// DialectInlinerInterface implementation for the vector dialect.
struct VectorInlinerInterface : public DialectInlinerInterface { … };
} // namespace
//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//
/// Dialect initialization hook.
void VectorDialect::initialize() { … }
/// Dialect hook that materializes a constant operation of type `type` holding
/// `value` at `loc`.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) { … }
/// Integer type used for vector subscripts.
IntegerType vector::getVectorSubscriptType(Builder &builder) { … }
/// Builds an ArrayAttr of subscripts from `values`.
ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
ArrayRef<int64_t> values) { … }
//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//
/// Builds a multi-dimensional reduction of `source` into `acc`.
/// `reductionMask` selects the reduced dimensions; `kind` is the combiner.
void vector::MultiDimReductionOp::build(OpBuilder &builder,
OperationState &result, Value source,
Value acc, ArrayRef<bool> reductionMask,
CombiningKind kind) { … }
/// Folder hook.
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) { … }
/// Shape used when unrolling this op, if available.
std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() { … }
/// Op verifier.
LogicalResult MultiDimReductionOp::verify() { … }
/// Mask type this op expects when it is masked.
Type MultiDimReductionOp::getExpectedMaskType() { … }
namespace {
/// Canonicalization pattern for unit dimensions in multi-dim reductions.
struct ElideUnitDimsInMultiDimReduction
: public OpRewritePattern<MultiDimReductionOp> { … };
} // namespace
void MultiDimReductionOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
/// Builds a reduction of `vector` with combiner `kind` and no accumulator.
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
CombiningKind kind, Value vector,
arith::FastMathFlags fastMathFlags) { … }
/// Builds a reduction of `vector` with combiner `kind` into accumulator `acc`.
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
CombiningKind kind, Value vector, Value acc,
arith::FastMathFlags fastMathFlags) { … }
/// Op verifier.
LogicalResult ReductionOp::verify() { … }
/// Mask type this op expects when it is masked.
Type ReductionOp::getExpectedMaskType() { … }
/// Creates a vector reduction of `vector` corresponding to the atomic RMW
/// kind `op`.
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
OpBuilder &builder, Location loc,
Value vector) { … }
/// Shape used when unrolling this op, if available.
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() { … }
namespace {
/// Canonicalization pattern for single-element reductions.
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> { … };
} // namespace
void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//
/// Builds a contraction of `lhs`/`rhs` into `acc`, taking the indexing maps as
/// affine expressions.
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
Value lhs, Value rhs, Value acc,
ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
ArrayRef<IteratorType> iteratorTypes) { … }
/// Builds a contraction with indexing maps and iterator types given as
/// attributes (default combining kind).
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
Value lhs, Value rhs, Value acc,
ArrayAttr indexingMaps,
ArrayAttr iteratorTypes) { … }
/// Builds a contraction with an explicit combining `kind`.
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
Value lhs, Value rhs, Value acc,
ArrayAttr indexingMaps,
ArrayAttr iteratorTypes, CombiningKind kind) { … }
/// Custom assembly parser.
ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) { … }
/// Custom assembly printer.
void ContractionOp::print(OpAsmPrinter &p) { … }
/// Checks the (lhs dim, rhs dim) pairs in `map` against the operand types.
static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
const std::vector<std::pair<int64_t, int64_t>> &map) { … }
/// Verifies the accumulator/result shape given the contracting and batch
/// dimension maps.
static LogicalResult verifyOutputShape(
ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
Type resType,
const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) { … }
/// Op verifier.
LogicalResult ContractionOp::verify() { … }
/// Mask type this op expects when it is masked.
Type ContractionOp::getExpectedMaskType() { … }
/// Names of the op's trait attributes.
SmallVector<StringRef> ContractionOp::getTraitAttrNames() { … }
/// Index of the result of `map` that equals `targetExpr`.
/// NOTE(review): the not-found return value is not visible here — confirm.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) { … }
/// Pairs of operand dimensions whose iterator type is `targetIteratorType`.
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
IteratorType targetIteratorType, MLIRContext *context) { … }
/// Appends the iteration-space bounds to `iterationBounds`.
void ContractionOp::getIterationBounds(
SmallVectorImpl<int64_t> &iterationBounds) { … }
/// Fills `iterationIndexMap` with one map per indexing map.
void ContractionOp::getIterationIndexMap(
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) { … }
/// (lhs, rhs) dimension pairs of the contracting dimensions.
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() { … }
/// (lhs, rhs) dimension pairs of the batch dimensions.
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() { … }
/// Shape used when unrolling this op, if available.
std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() { … }
/// Canonicalization pattern, parameterized on the add op type, matching an
/// add that involves a contraction.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> { … };
void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//
/// Builds an extractelement from `source` without an explicit position.
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
Value source) { … }
/// Op verifier.
LogicalResult vector::ExtractElementOp::verify() { … }
/// Folder hook.
OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) { … }
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
// Builders, return-type inference, verifier, folders and canonicalization
// patterns for vector::ExtractOp.

/// Builds an extract from `source` at a single static position.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, int64_t position) { … }
/// Builds an extract at a single static-or-dynamic position.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, OpFoldResult position) { … }
/// Builds an extract at a multi-dimensional static position.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, ArrayRef<int64_t> position) { … }
/// Builds an extract at a multi-dimensional static-or-dynamic position.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, ArrayRef<OpFoldResult> position) { … }
/// InferTypeOpInterface hook: computes the result type from the operands.
LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
ExtractOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) { … }
/// Decides whether inferred (`l`) and declared (`r`) result types are
/// compatible.
bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { … }
/// Op verifier.
LogicalResult vector::ExtractOp::verify() { … }
/// Converts an integer ArrayAttr into a vector of `IntType`.
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) { … }
/// In-place fold of an extract whose source is another extract.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) { … }
namespace {
/// State for folding an extract through a chain of insert/transpose ops.
class ExtractFromInsertTransposeChainState { … };
} // namespace
ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
ExtractOp e)
: … { … }
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() { … }
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
Value &res) { … }
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) { … }
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
Value source) { … }
/// Entry point of the insert/transpose chain fold.
Value ExtractFromInsertTransposeChainState::fold() { … }
/// Returns whether `op` involves 0-D vectors.
static bool hasZeroDimVectors(Operation *op) { … }
// Individual folds used by ExtractOp::fold; each returns a Value
// (presumably null when the fold does not apply — confirm).
static Value foldExtractFromBroadcast(ExtractOp extractOp) { … }
static Value foldExtractFromShapeCast(ExtractOp extractOp) { … }
static Value foldExtractFromExtractStrided(ExtractOp extractOp) { … }
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) { … }
static Value foldScalarExtractFromFromElements(ExtractOp extractOp) { … }
/// Folder hook.
OpFoldResult ExtractOp::fold(FoldAdaptor) { … }
namespace {
// Canonicalization patterns for ExtractOp.
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> { … };
class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> { … };
class ExtractOpNonSplatConstantFolder final
: public OpRewritePattern<ExtractOp> { … };
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> { … };
/// Function-style rewrite: extract of a shape_cast.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
PatternRewriter &rewriter) { … }
/// Function-style rewrite: extract of a from_elements.
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
PatternRewriter &rewriter) { … }
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
/// Appends the integer values of `arrayAttr` to `results`.
/// NOTE(review): presumably expects IntegerAttr elements — confirm.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
SmallVectorImpl<int64_t> &results) { … }
//===----------------------------------------------------------------------===//
// FMAOp
//===----------------------------------------------------------------------===//
/// Shape used when unrolling this op, if available.
std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() { … }
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
/// Function-style rewrite of a from_elements op into a splat.
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
PatternRewriter &rewriter) { … }
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
/// Computes the unit dimensions that are broadcast when going from `srcShape`
/// to `dstShape`.
static llvm::SetVector<int64_t>
computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
ArrayRef<int64_t> dstShape) { … }
/// Member form of computeBroadcastedUnitDims for this op.
llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() { … }
/// Creates (or folds) a broadcast of `value` to `dstShape` along
/// `broadcastedDims`.
Value BroadcastOp::createOrFoldBroadcastOp(
OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
const llvm::SetVector<int64_t> &broadcastedDims) { … }
/// Classifies whether `srcType` can be broadcast to `dstVectorType`. When
/// `mismatchingDims` is non-null it may receive the offending dimension pair.
BroadcastableToResult mlir::vector::isBroadcastableTo(
Type srcType, VectorType dstVectorType,
std::pair<VectorDim, VectorDim> *mismatchingDims) { … }
/// Op verifier.
LogicalResult BroadcastOp::verify() { … }
/// Folder hook.
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { … }
namespace {
/// Canonicalization pattern for BroadcastOp.
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { … };
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//
/// Op verifier.
LogicalResult ShuffleOp::verify() { … }
/// InferTypeOpInterface hook: computes the result type from the operands.
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
ShuffleOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) { … }
/// Returns whether `idxArr` is a step sequence starting at `begin` and
/// covering `width` entries.
template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) { … }
/// Folder hook.
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) { … }
namespace {
// Canonicalization patterns for ShuffleOp.
struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> { … };
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> { … };
class ShuffleInterleave : public OpRewritePattern<ShuffleOp> { … };
} // namespace
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// InsertElementOp
//===----------------------------------------------------------------------===//
/// Builds an insertelement of `source` into `dest` without an explicit
/// position.
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest) { … }
/// Op verifier.
LogicalResult InsertElementOp::verify() { … }
/// Folder hook.
OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) { … }
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
/// Builds an insert of `source` into `dest` at a single static position.
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest, int64_t position) { … }
/// Builds an insert at a single static-or-dynamic position.
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest, OpFoldResult position) { … }
/// Builds an insert at a multi-dimensional static position.
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest,
ArrayRef<int64_t> position) { … }
/// Builds an insert at a multi-dimensional static-or-dynamic position.
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest,
ArrayRef<OpFoldResult> position) { … }
/// Op verifier.
LogicalResult InsertOp::verify() { … }
namespace {
// Canonicalization patterns for InsertOp.
class InsertToBroadcast final : public OpRewritePattern<InsertOp> { … };
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> { … };
class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> { … };
} // namespace
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
/// Folder hook.
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { … }
//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//
/// Builds an insert_strided_slice of `source` into `dest` with static
/// `offsets` and `strides`.
void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> strides) { … }
// Verification helpers for integer array attributes. Each takes the op (used
// for diagnostics, presumably) and the attribute name to report.
/// Checks that `arrayAttr` has no more entries than `shape` has dimensions.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
ArrayAttr arrayAttr,
ArrayRef<int64_t> shape,
StringRef attrName) { … }
/// Checks every entry of `arrayAttr` against [min, max]; `halfOpen`
/// (default true) presumably selects [min, max) — confirm.
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
int64_t max, StringRef attrName,
bool halfOpen = true) { … }
/// Checks each entry of `arrayAttr` against the matching dimension of
/// `shape`, with lower bound `min` (default 0).
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
ArrayRef<int64_t> shape, StringRef attrName,
bool halfOpen = true, int64_t min = 0) { … }
/// Checks the element-wise sum of two integer array attributes against
/// `shape`, with lower bound `min` (default 1).
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
bool halfOpen = true, int64_t min = 1) { … }
/// Builds an ArrayAttr of 64-bit integers from `values`.
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
MLIRContext *context) { … }
/// Op verifier.
LogicalResult InsertStridedSliceOp::verify() { … }
namespace {
// Canonicalization patterns for InsertStridedSliceOp.
class FoldInsertStridedSliceSplat final
: public OpRewritePattern<InsertStridedSliceOp> { … };
class FoldInsertStridedSliceOfExtract final
: public OpRewritePattern<InsertStridedSliceOp> { … };
class InsertStridedSliceConstantFolder final
: public OpRewritePattern<InsertStridedSliceOp> { … };
} // namespace
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) { … }
/// Folder hook.
OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) { … }
//===----------------------------------------------------------------------===//
// OuterProductOp
//===----------------------------------------------------------------------===//
/// Builds an outer product of `lhs` and `rhs` accumulated into `acc`.
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
Value lhs, Value rhs, Value acc) { … }
/// Custom assembly printer.
void OuterProductOp::print(OpAsmPrinter &p) { … }
/// Custom assembly parser.
ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) { … }
/// Op verifier.
LogicalResult OuterProductOp::verify() { … }
/// Mask type this op expects when it is masked.
Type OuterProductOp::getExpectedMaskType() { … }
//===----------------------------------------------------------------------===//
// ExtractStridedSliceOp
//===----------------------------------------------------------------------===//
/// Infers the result type of a strided slice of `vectorType` given the
/// `offsets`, `sizes` and `strides` attributes.
static Type inferStridedSliceOpResultType(VectorType vectorType,
ArrayAttr offsets, ArrayAttr sizes,
ArrayAttr strides) { … }
/// Builds an extract_strided_slice of `source` with static offsets, sizes and
/// strides.
void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
Value source, ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) { … }
/// Op verifier.
LogicalResult ExtractStridedSliceOp::verify() { … }
/// In-place fold of a strided extract through a chain of inserts.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { … }
/// Folder hook.
OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) { … }
/// Appends the op's offsets to `results`.
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) { … }
namespace {
// Canonicalization patterns for ExtractStridedSliceOp.
class StridedSliceConstantMaskFolder final
: public OpRewritePattern<ExtractStridedSliceOp> { … };
class StridedSliceSplatConstantFolder final
: public OpRewritePattern<ExtractStridedSliceOp> { … };
class StridedSliceNonSplatConstantFolder final
: public OpRewritePattern<ExtractStridedSliceOp> { … };
class StridedSliceBroadcast final
: public OpRewritePattern<ExtractStridedSliceOp> { … };
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> { … };
} // namespace
void ExtractStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//
/// Builds a transfer_read with the permutation map and in_bounds given as
/// attributes.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices, AffineMapAttr permutationMapAttr,
ArrayAttr inBoundsAttr) { … }
/// Builds a transfer_read with an explicit permutation map and optional
/// in_bounds flags.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices, AffineMap permutationMap,
std::optional<ArrayRef<bool>> inBounds) { … }
/// Builds a transfer_read with an explicit `padding` value.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices, Value padding,
std::optional<ArrayRef<bool>> inBounds) { … }
/// Builds a transfer_read from `source` at `indices` with optional in_bounds
/// flags only.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices,
std::optional<ArrayRef<bool>> inBounds) { … }
/// Verifies `permutationMap`, reporting errors through `emitOpError`.
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
EmitFun emitOpError) { … }
/// Verification shared by transfer ops (takes the VectorTransferOpInterface).
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
VectorType vectorType, VectorType maskType,
VectorType inferredMaskType, AffineMap permutationMap,
ArrayAttr inBounds) { … }
/// Prints the attributes shared by transfer ops.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { … }
/// Custom assembly printer.
void TransferReadOp::print(OpAsmPrinter &p) { … }
/// Infers the mask type of a transfer op from its vector type `vecType` and
/// permutation map `permMap`.
VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
AffineMap permMap) { … }
/// Custom assembly parser.
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { … }
/// Op verifier.
LogicalResult TransferReadOp::verify() { … }
/// Mask type this op expects when it is masked.
Type TransferReadOp::getExpectedMaskType() { … }
/// Returns whether the access along result dim `resultIdx` (indexed by
/// `indicesIdx`) is in bounds.
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { … }
/// In-place fold of the in_bounds attribute of a transfer op.
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) { … }
/// In-place fold of a transfer op's mask when it is a full mask.
template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) { … }
/// Read-after-write fold for `readOp`.
static Value foldRAW(TransferReadOp readOp) { … }
/// Folder hook.
OpFoldResult TransferReadOp::fold(FoldAdaptor) { … }
/// Shape used when unrolling this op, if available.
std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() { … }
/// MemoryEffectOpInterface hook: appends this op's memory effects.
void TransferReadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) { … }
namespace {
/// Canonicalization pattern for a read following a write.
struct TransferReadAfterWriteToBroadcast
: public OpRewritePattern<TransferReadOp> { … };
} // namespace
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//
/// Builds a transfer_write with an explicit `mask`.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value dest, ValueRange indices,
AffineMapAttr permutationMapAttr,
Value mask,
ArrayAttr inBoundsAttr) { … }
/// Builds a transfer_write without a mask.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value dest, ValueRange indices,
AffineMapAttr permutationMapAttr,
ArrayAttr inBoundsAttr) { … }
/// Builds a transfer_write with an explicit permutation map and optional
/// in_bounds flags.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value dest, ValueRange indices,
AffineMap permutationMap,
std::optional<ArrayRef<bool>> inBounds) { … }
/// Builds a transfer_write of `vector` into `dest` at `indices` with optional
/// in_bounds flags only.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value dest, ValueRange indices,
std::optional<ArrayRef<bool>> inBounds) { … }
/// Custom assembly parser.
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
OperationState &result) { … }
/// Custom assembly printer.
void TransferWriteOp::print(OpAsmPrinter &p) { … }
/// Op verifier.
LogicalResult TransferWriteOp::verify() { … }
/// Mask type this op expects when it is masked.
Type TransferWriteOp::getExpectedMaskType() { … }
/// Fold of a write whose written vector comes from a read; results go to
/// `results`.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &results) { … }
/// Write-after-read check between `read` and `write`.
static bool checkSameValueWAR(vector::TransferReadOp read,
vector::TransferWriteOp write) { … }
/// Write-after-read fold; results go to `results`.
static LogicalResult foldWAR(TransferWriteOp write,
SmallVectorImpl<OpFoldResult> &results) { … }
/// Folder hook (multi-result form).
LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
/// Shape used when unrolling this op, if available.
std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() { … }
/// MemoryEffectOpInterface hook: appends this op's memory effects.
void TransferWriteOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) { … }
namespace {
/// Canonicalization pattern for write-after-write.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> { … };
/// Pattern rooted at tensor::InsertSliceOp that involves a transfer_write.
struct SwapExtractSliceOfTransferWrite
: public OpRewritePattern<tensor::InsertSliceOp> { … };
} // namespace
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// LoadOp / StoreOp
//===----------------------------------------------------------------------===//
/// Verifies the memref layout of a vector load/store against `vecTy`.
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
VectorType vecTy,
MemRefType memRefTy) { … }
/// Op verifier.
LogicalResult vector::LoadOp::verify() { … }
/// Folder hook.
OpFoldResult LoadOp::fold(FoldAdaptor) { … }
/// Op verifier.
LogicalResult vector::StoreOp::verify() { … }
/// Folder hook (multi-result form).
LogicalResult StoreOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//
/// Op verifier.
LogicalResult MaskedLoadOp::verify() { … }
namespace {
/// Canonicalization pattern for MaskedLoadOp.
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> { … };
} // namespace
void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
/// Folder hook.
OpFoldResult MaskedLoadOp::fold(FoldAdaptor) { … }
//===----------------------------------------------------------------------===//
// MaskedStoreOp
//===----------------------------------------------------------------------===//
/// Op verifier.
LogicalResult MaskedStoreOp::verify() { … }
namespace {
/// Canonicalization pattern for MaskedStoreOp.
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> { … };
} // namespace
void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
/// Folder hook (multi-result form).
LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//
/// Op verifier.
LogicalResult GatherOp::verify() { … }
/// Mask type this op expects when it is masked.
Type GatherOp::getExpectedMaskType() { … }
/// Shape used when unrolling this op, if available.
std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() { … }
namespace {
/// Canonicalization pattern for GatherOp.
class GatherFolder final : public OpRewritePattern<GatherOp> { … };
} // namespace
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
/// Op verifier.
LogicalResult ScatterOp::verify() { … }
namespace {
/// Canonicalization pattern for ScatterOp.
class ScatterFolder final : public OpRewritePattern<ScatterOp> { … };
} // namespace
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// ExpandLoadOp
//===----------------------------------------------------------------------===//
/// Op verifier.
LogicalResult ExpandLoadOp::verify() { … }
namespace {
/// Canonicalization pattern for ExpandLoadOp.
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> { … };
} // namespace
void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// CompressStoreOp
//===----------------------------------------------------------------------===//
/// Op verifier.
LogicalResult CompressStoreOp::verify() { … }
namespace {
/// Canonicalization pattern for CompressStoreOp.
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> { … };
} // namespace
void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//
/// Returns whether a shape cast between shapes `a` and `b` is valid.
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { … }
/// Verifies a shape cast from `sourceVectorType` to `resultVectorType`,
/// reporting errors on `op`.
static LogicalResult verifyVectorShapeCast(Operation *op,
VectorType sourceVectorType,
VectorType resultVectorType) { … }
/// Op verifier.
LogicalResult ShapeCastOp::verify() { … }
/// Folder hook.
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { … }
namespace {
// Canonicalization patterns for ShapeCastOp.
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> { … };
/// Returns `oldType` with its trailing unit dimensions removed.
static VectorType trimTrailingOneDims(VectorType oldType) { … }
class ShapeCastCreateMaskFolderTrailingOneDim final
: public OpRewritePattern<ShapeCastOp> { … };
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> { … };
} // namespace
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// BitCastOp
//===----------------------------------------------------------------------===//
/// Op verifier.
LogicalResult BitCastOp::verify() { … }
/// Folder hook.
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) { … }
//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//
/// Extracts a shape (up to 8 dimensions inline) from `memRefType`.
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) { … }
/// Builds a type_cast of `source`.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
Value source) { … }
/// Op verifier.
LogicalResult TypeCastOp::verify() { … }
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
/// Builds a transpose of `vector` by `permutation`.
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
Value vector, ArrayRef<int64_t> permutation) { … }
/// Folder hook.
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { … }
/// Op verifier.
LogicalResult vector::TransposeOp::verify() { … }
/// Shape used when unrolling this op, if available.
std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() { … }
namespace {
// Canonicalization patterns for TransposeOp.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> { … };
struct FoldTransposedScalarBroadcast final
: public OpRewritePattern<vector::TransposeOp> { … };
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> { … };
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> { … };
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//
/// Builds a constant mask of `type` described by `kind`.
void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
VectorType type, ConstantMaskKind kind) { … }
/// Op verifier.
LogicalResult ConstantMaskOp::verify() { … }
/// Returns whether this constant mask is all ones.
bool ConstantMaskOp::isAllOnesMask() { … }
//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//
/// Builds a create_mask of `type` from static-or-dynamic `mixedOperands`.
void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
VectorType type,
ArrayRef<OpFoldResult> mixedOperands) { … }
/// Op verifier.
LogicalResult CreateMaskOp::verify() { … }
namespace {
/// Canonicalization pattern for CreateMaskOp.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> { … };
} // namespace
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
//===----------------------------------------------------------------------===//
// MaskOp
//===----------------------------------------------------------------------===//
/// Builds a mask op around `maskableOp` with `mask`. `maskRegionBuilder`
/// populates the mask region.
void MaskOp::build(
OpBuilder &builder, OperationState &result, Value mask,
Operation *maskableOp,
function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) { … }
/// As above, with explicit `resultTypes`.
void MaskOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTypes,
Value mask, Operation *maskableOp,
function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) { … }
/// As above, with explicit `resultTypes` and a `passthru` value.
void MaskOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTypes,
Value mask, Value passthru, Operation *maskableOp,
function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) { … }
/// Custom assembly parser.
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) { … }
/// Custom assembly printer.
void mlir::vector::MaskOp::print(OpAsmPrinter &p) { … }
void MaskOp::ensureTerminator(Region ®ion, Builder &builder, Location loc) { … }
/// Op verifier.
LogicalResult MaskOp::verify() { … }
/// Folder hook (multi-result form).
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
/// Canonicalization pattern for empty mask ops.
/// NOTE(review): declared outside an anonymous namespace, unlike the other
/// patterns in this file, so it has external linkage — confirm intended.
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> { … };
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
/// Returns the operation being masked.
Operation *MaskOp::getMaskableOp() { … }
/// Returns whether this mask op has a passthru value.
bool MaskOp::hasPassthru() { … }
//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//
/// Op verifier.
LogicalResult ScanOp::verify() { … }
/// Adds vector-to-vector canonicalization patterns to `patterns` with the
/// given `benefit`.
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) { … }
//===----------------------------------------------------------------------===//
// SplatOp / StepOp
//===----------------------------------------------------------------------===//
/// Folder hook.
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { … }
/// Folder hook.
OpFoldResult StepOp::fold(FoldAdaptor adaptor) { … }
//===----------------------------------------------------------------------===//
// WarpExecuteOnLane0Op
//===----------------------------------------------------------------------===//
/// Custom assembly printer.
void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) { … }
/// Custom assembly parser.
ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
OperationState &result) { … }
void WarpExecuteOnLane0Op::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { … }
/// Builds the op with `resultTypes`, lane id `laneId` and `warpSize`.
void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, Value laneId,
int64_t warpSize) { … }
/// As above, additionally forwarding `args` whose region block arguments
/// have types `blockArgTypes`.
void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, Value laneId,
int64_t warpSize, ValueRange args,
TypeRange blockArgTypes) { … }
/// Checks that `distributed` is a valid distribution of `expanded` across
/// `warpSize` lanes, reporting errors on `op`.
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
int64_t warpSize, Operation *op) { … }
/// Op verifier.
LogicalResult WarpExecuteOnLane0Op::verify() { … }
/// Returns whether types `lhs` and `rhs` are compatible for this op.
bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) { … }
//===----------------------------------------------------------------------===//
// Reduction and masking utilities
//===----------------------------------------------------------------------===//
/// Combines `v1` into `acc` using the arith op for `kind`, with optional
/// fast-math flags `fastmath` and an optional `mask`.
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value acc,
arith::FastMathFlagsAttr fastmath,
Value mask) { … }
/// Populates the region of a mask op for `maskableOp`.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
Operation *maskableOp) { … }
/// Wraps `maskableOp` with `mask` (and optional `passthru`); returns the
/// resulting operation.
Operation *mlir::vector::maskOperation(OpBuilder &builder,
Operation *maskableOp, Value mask,
Value passthru) { … }
/// Combines `newValue` and `passthru` under `mask`.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
Value newValue, Value passthru) { … }
//===----------------------------------------------------------------------===//
// TableGen-generated definitions
//===----------------------------------------------------------------------===//
// Attribute class definitions.
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
// Op class definitions.
#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"