llvm/mlir/lib/Dialect/Vector/IR/VectorOps.cpp

//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"

#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"

#include <cassert>
#include <cstdint>
#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
static MaskFormat getMaskFormat(Value mask) {}

/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {}

// Helper for verifying combining kinds in contractions and reductions.
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {}

AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {}

/// Check if `write` is of a constant splat and the masked `read` is padded with
/// the same splat value -- meaning it could be the same value as the initial
/// constant splat.
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
                                                 vector::TransferReadOp read) {}

bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {}

bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {}

bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {}

bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {}

// Helper to iterate over n-D vector slice elements. Calculate the next
// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
// Modifies the `position` in place. Returns a failure when `position` becomes
// the end position.
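// For example (an illustrative walk with zero offsets): with shape = [2, 3]
// and offsets = [0, 0], successive calls advance `position` as
//   [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2],
// and the call after [1, 2] returns failure.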
static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
                                      ArrayRef<int64_t> shape,
                                      ArrayRef<int64_t> offsets) {}

/// Returns the integer numbers in `values`. `values` are expected to be
/// constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {}

/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
/// be constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {}

/// Convert `foldResults` into Values. Integer attributes are converted to
/// constant ops.
SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
                                       ArrayRef<OpFoldResult> foldResults) {}

std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {}

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {};
} // namespace detail
} // namespace vector
} // namespace mlir

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for handling inlining with vector dialect
/// operations.
struct VectorInlinerInterface : public DialectInlinerInterface {};
} // namespace

void VectorDialect::initialize() {}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {}

IntegerType vector::getVectorSubscriptType(Builder &builder) {}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {}

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {}

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {}

LogicalResult MultiDimReductionOp::verify() {}

/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {}

namespace {
// Only unit dimensions that are being reduced are folded. If a dimension is
// unit but not reduced, it is not folded, thereby keeping the output type the
// same. If any reduced dimension is not a unit dimension, this transformation
// does nothing. This is just a generalization of ElideSingleElementReduction
// for ReduceOp.
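// For example (an illustrative case, not the only one handled):
//   %0 = vector.multi_reduction <add>, %src, %acc [1]
//          : vector<4x1xf32> to vector<4xf32>
// can be rewritten as a vector.shape_cast of %src to vector<4xf32> combined
// with %acc, because the reduced dimension holds a single element.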
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {};
} // namespace

void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {}

LogicalResult ReductionOp::verify() {}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {}

Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {}

std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {}

namespace {
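// Elides a vector.reduction over a single-element vector by extracting that
// element, e.g. (an illustrative sketch of the intent):
//   %0 = vector.reduction <add>, %v : vector<1xf32> into f32
// becomes:
//   %0 = vector.extract %v[0] : f32 from vector<1xf32>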
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {};
} // namespace

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<IteratorType> iteratorTypes) {}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {}

ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {}

void ContractionOp::print(OpAsmPrinter &p) {}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {}

LogicalResult ContractionOp::verify() {}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type ContractionOp::getExpectedMaskType() {}

SmallVector<StringRef> ContractionOp::getTraitAttrNames() {}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {}

/// Return a fused vector::ContractionOp which represents a pattern such as:
///
/// ```mlir
///    %c0 = arith.constant 0: ...
///    %c = vector.contract %a, %b, %c0: ...
///    %e = add %c, %d: ...
/// ```
///
/// rewritten as:
///
/// ```mlir
///    %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source) {}

LogicalResult vector::ExtractElementOp::verify() {}

OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, OpFoldResult position) {}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<OpFoldResult> position) {}

LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {}

bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {}

LogicalResult vector::ExtractOp::verify() {}

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
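/// For example (illustrative):
///   %0 = vector.extract %arg0[1] : vector<3x4xf32> from vector<2x3x4xf32>
///   %1 = vector.extract %0[2] : vector<4xf32> from vector<3x4xf32>
/// folds in place to:
///   %1 = vector.extract %arg0[1, 2] : vector<4xf32> from vector<2x3x4xf32>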
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {}

namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result vector.
/// As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {};
} // namespace

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    {}

// Case 1. If we hit a transpose, just compose the map and iterate.
// Invariant: insert + transpose do not change rank, so we can always compose.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {}

// Case 2: the insert position matches extractPosition exactly, early return.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {}

/// Case 3: if the inserted position is a prefix of extractPosition,
/// extract a portion of the source of the insertion.
/// This method updates the internal state.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {}

/// Try to fold in place to extract(source, extractPosition) and return the
/// folded result. Return null if folding is not possible (e.g. due to an
/// internal transposition in the result).
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {}

/// Iterate over producing insert and transpose ops until we find a fold.
Value ExtractFromInsertTransposeChainState::fold() {}

/// Returns true if the operation has a 0-D vector type operand or result.
static bool hasZeroDimVectors(Operation *op) {}

/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {}

/// Fold an ExtractOp with a source coming from a ShapeCastOp.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {}

/// Fold an ExtractOp from ExtractStridedSliceOp.
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {}

/// Fold an extract_op fed by a chain of insert_strided_slice ops.
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {}

/// Try to fold the extraction of a scalar from a vector defined by
/// vector.from_elements. E.g.:
///
/// %0 = vector.from_elements %a, %b : vector<2xf32>
/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
/// ==> fold to %a
static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {}

OpFoldResult ExtractOp::fold(FoldAdaptor) {}

namespace {

// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
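// E.g. (illustrative):
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[1] : vector<3x4xf32> from vector<2x3x4xf32>
// becomes:
//   %e = vector.broadcast %src : vector<4xf32> to vector<3x4xf32>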
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {};

// Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {};

// Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
class ExtractOpNonSplatConstantFolder final
    : public OpRewritePattern<ExtractOp> {};

// Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {};

// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
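// E.g. (illustrative):
//   %0 = vector.shape_cast %arg : vector<2x2xf32> to vector<1x4xf32>
//   %1 = vector.extract %0[0] : vector<4xf32> from vector<1x4xf32>
// becomes:
//   %1 = vector.shape_cast %arg : vector<2x2xf32> to vector<4xf32>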
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {}

/// Try to canonicalize the extraction of a subvector from a vector defined by
/// vector.from_elements. E.g.:
///
/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {}
} // namespace

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {}

//===----------------------------------------------------------------------===//
// FMAOp
//===----------------------------------------------------------------------===//

std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

/// Rewrite a vector.from_elements into a vector.splat if all elements are the
/// same SSA value. E.g.:
///
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
/// ==> rewrite to vector.splat %a : vector<3xf32>
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
                                                PatternRewriter &rewriter) {}

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

/// Return the dimensions of the result vector that were formerly ones in the
/// source vector and thus correspond to "dim-1" broadcasting.
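/// E.g. (illustrative) with srcShape = [1, 4] and dstShape = [3, 5, 4]: the
/// shapes align on the right, the source dim of size 1 lines up with the
/// result dim of size 5, so the result is {1}; the leading result dim 0 is a
/// new dimension rather than a "dim-1" broadcast.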
static llvm::SetVector<int64_t>
computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
                           ArrayRef<int64_t> dstShape) {}

llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {}

/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in the dstShape are broadcasted.
/// This requires (and asserts) that the broadcast is free of dim-1
/// broadcasting.
/// Since vector.broadcast only allows expanding leading dimensions, an extra
/// vector.transpose may be inserted to make the broadcast possible.
/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
/// the helper will assert. This means:
///   1. `dstShape` must not be empty.
///   2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
///   3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
///      must match the `value` shape.
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {}

BroadcastableToResult mlir::vector::isBroadcastableTo(
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {}

LogicalResult BroadcastOp::verify() {}

OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {}

namespace {

// Fold broadcast1(broadcast2(x)) into broadcast1(x).
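// E.g. (illustrative):
//   %0 = vector.broadcast %x : vector<4xf32> to vector<2x4xf32>
//   %1 = vector.broadcast %0 : vector<2x4xf32> to vector<3x2x4xf32>
// becomes:
//   %1 = vector.broadcast %x : vector<4xf32> to vector<3x2x4xf32>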
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

LogicalResult ShuffleOp::verify() {}

LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {}

template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {}

OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {}

namespace {

// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
// to a broadcast.
struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {};

/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
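/// E.g. (illustrative, both operands splats of the same value):
///   %0 = vector.splat %x : vector<4xf32>
///   %1 = vector.splat %x : vector<4xf32>
///   %2 = vector.shuffle %0, %1 [0, 4, 1, 5] : vector<4xf32>, vector<4xf32>
/// becomes:
///   %2 = vector.splat %x : vector<4xf32>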
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {};

/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
/// vector.interleave.
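/// E.g. (illustrative) a vector.shuffle of two vector<4xf32> operands with
/// mask [0, 4, 1, 5, 2, 6, 3, 7] alternates elements of the two sources and
/// can be expressed as a single vector.interleave of those sources.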
class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {};

} // namespace

void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// InsertElementOp
//===----------------------------------------------------------------------===//

void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest) {}

LogicalResult InsertElementOp::verify() {}

OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, OpFoldResult position) {}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<int64_t> position) {}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<OpFoldResult> position) {}

LogicalResult InsertOp::verify() {}

namespace {

// If insertOp is only inserting unit dimensions it can be transformed to a
// broadcast.
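// E.g. (illustrative):
//   %0 = vector.insert %src, %dst [0, 0]
//          : vector<4xf32> into vector<1x1x4xf32>
// becomes:
//   %0 = vector.broadcast %src : vector<4xf32> to vector<1x1x4xf32>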
class InsertToBroadcast final : public OpRewritePattern<InsertOp> {};

/// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {};

// Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {};

} // namespace

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {}

// Eliminates insert operations that produce values identical to their source
// value. This happens when the source and destination vectors have identical
// sizes.
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//

void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                 Value source, Value dest,
                                 ArrayRef<int64_t> offsets,
                                 ArrayRef<int64_t> strides) {}

// TODO: Should be moved to Tablegen ConfinedAttr attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {}

// Returns true if all integers in `arrayAttr` are in the admissible interval:
// [min, max) when `halfOpen` is true, or [min, max] otherwise.
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {}

// Returns true if all integers in `arrayAttr` are in the admissible interval
// determined by the corresponding dimension of `shape`: [min, shape[i]) when
// `halfOpen` is true, or [min, shape[i]] otherwise.
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {}

// Returns true if, for all indices i = 0..shape.size()-1, the sum
//   val = `arrayAttr1[i]` + `arrayAttr2[i]`
// is in the admissible interval: [min, shape[i]) when `halfOpen` is true, or
// [min, shape[i]] otherwise.
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {}

static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {}

LogicalResult InsertStridedSliceOp::verify() {}

namespace {
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
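/// E.g. (illustrative):
///   %0 = vector.splat %x : vector<2xf32>
///   %1 = vector.splat %x : vector<4xf32>
///   %2 = vector.insert_strided_slice %0, %1 {offsets = [1], strides = [1]}
///          : vector<2xf32> into vector<4xf32>
/// becomes:
///   %2 = vector.splat %x : vector<4xf32>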
class FoldInsertStridedSliceSplat final
    : public OpRewritePattern<InsertStridedSliceOp> {};

/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
/// to dst.
class FoldInsertStridedSliceOfExtract final
    : public OpRewritePattern<InsertStridedSliceOp> {};

// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
// ConstantOp.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {};

} // namespace

void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {}

OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// OuterProductOp
//===----------------------------------------------------------------------===//

/// Build an op without mask, use the type of `acc` as the return type.
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
                           Value lhs, Value rhs, Value acc) {}

void OuterProductOp::print(OpAsmPrinter &p) {}

ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {}

LogicalResult OuterProductOp::verify() {}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type OuterProductOp::getExpectedMaskType() {}

//===----------------------------------------------------------------------===//
// ExtractStridedSliceOp
//===----------------------------------------------------------------------===//

// Inference works as follows:
//   1. Add 'sizes' from prefix of dims in 'offsets'.
//   2. Add sizes from 'vectorType' for remaining dims.
// Scalable flags are inherited from 'vectorType'.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {}

void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> sizes,
                                  ArrayRef<int64_t> strides) {}

LogicalResult ExtractStridedSliceOp::verify() {}

// When the source of an ExtractStridedSliceOp comes from a chain of
// InsertStridedSliceOps, try to use the source of one of those inserts if we
// can detect that the extracted vector is a subset of one of the inserted
// vectors.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {}

OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {}

void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {}

namespace {

// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
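// E.g. (illustrative):
//   %0 = vector.constant_mask [2] : vector<4xi1>
//   %1 = vector.extract_strided_slice %0
//          {offsets = [0], sizes = [2], strides = [1]}
//          : vector<4xi1> to vector<2xi1>
// becomes:
//   %1 = vector.constant_mask [2] : vector<2xi1>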
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {};

// Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
class StridedSliceSplatConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {};

// Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
// ConstantOp.
class StridedSliceNonSplatConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {};

// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStridedSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {};

/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
class StridedSliceSplat final
    : public OpRewritePattern<ExtractStridedSliceOp> {};

/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
///
/// Example:
///     Before:
///         %1 = vector.extract_strided_slice %arg0 {
///                offsets = [0, 0, 0, 0, 0],
///                sizes = [1, 1, 1, 1, 8],
///                strides = [1, 1, 1, 1, 1]
///              } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
///     After:
///         %0 = vector.extract %arg0[0, 0, 0, 0]
///                : vector<8xi8> from vector<8x1x1x2x8xi8>
///         %1 = vector.shape_cast %0
///                : vector<8xi8> to vector<1x1x1x1x8xi8>
///
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {};

} // namespace

void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//

/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMapAttr permutationMapAttr,
                           /*optional*/ ArrayAttr inBoundsAttr) {}

/// 2. Builder that sets padding to zero and an empty mask (variant without
/// attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds) {}

/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, Value padding,
                           std::optional<ArrayRef<bool>> inBounds) {}

/// 4. Builder that sets padding to zero and permutation map to
/// 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices,
                           std::optional<ArrayRef<bool>> inBounds) {}

template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {}

static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {}

static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {}

void TransferReadOp::print(OpAsmPrinter &p) {}

VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                 AffineMap permMap) {}

ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {}

LogicalResult TransferReadOp::verify() {}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type TransferReadOp::getExpectedMaskType() {}

template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {}

template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {}

template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {}

///  Fold:
///  ```
///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
///    : tensor<4x4xf32>, vector<1x4xf32>
///  ```
///  -> Folds into
///  ```
///  %v0
///  ```
static Value foldRAW(TransferReadOp readOp) {}

OpFoldResult TransferReadOp::fold(FoldAdaptor) {}

std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {}

Speculation::Speculatability TransferReadOp::getSpeculatability() {}

namespace {
/// Store to load forwarding for transfer operations with permutation maps.
/// Even if the permutation maps are different we can still propagate the store
/// into the load if the size of the dimensions read and written match. Then we
/// can replace the transfer_read + transfer_write by vector.broadcast and
/// vector.transpose.
/// Example:
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
///  {in_bounds = [true, true],
///   permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
///   vector<4x1xf32>, tensor<4x4x4xf32>
///  %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
///   {in_bounds = [true, true, true, true],
///   permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
///   tensor<4x4x4xf32>, vector<1x100x4x5xf32>
/// ```
/// To:
/// ```
/// %0 = vector.broadcast %v0 : vector<4x1xf32> to vector<100x5x4x1xf32>
/// %r = vector.transpose %0, [3, 0, 2, 1] :
///   vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
/// ```
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {};
} // namespace

void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//

/// 1. Builder with type inference.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ Value mask,
                            /*optional*/ ArrayAttr inBoundsAttr) {}

/// 2. Builder with type inference that sets an empty mask (variant with attrs).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ ArrayAttr inBoundsAttr) {}

/// 3. Builder with type inference that sets an empty mask (variant without
/// attrs).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap,
                            std::optional<ArrayRef<bool>> inBounds) {}

/// 4. Builder with type inference that sets an empty mask and sets permutation
///    map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            std::optional<ArrayRef<bool>> inBounds) {}

ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {}

void TransferWriteOp::print(OpAsmPrinter &p) {}

LogicalResult TransferWriteOp::verify() {}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes.
Type TransferWriteOp::getExpectedMaskType() {}

/// Fold:
/// ```
///    %t1 = ...
///    %v = vector.transfer_read %t0[%c0...], %cf0 {in_bounds = [true...]} :
///      tensor<static_sizesxf32>, vector<static_sizesxf32>
///    %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
///      vector<static_sizesxf32>, tensor<static_sizesxf32>
/// ```
///
/// into:
///
/// ```
///    %t0
/// ```
///
/// The producer of t1 may or may not be DCE'd depending on whether it is a
/// block argument or has side effects.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {}

static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {}
/// Fold transfer_write write after read:
/// ```
///    %t0 = ...
///    %v = vector.transfer_read %t0[%c0...] :
///      tensor<static_sizesxf32>, vector<static_sizesxf32>
///    %t1 = vector.transfer_write %v, %t0[%c0...] :
///      vector<static_sizesxf32>, tensor<static_sizesxf32>
/// ```
///
/// into:
///
/// ```
///    %t0
/// ```
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {}

LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {}

std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {}

Speculation::Speculatability TransferWriteOp::getSpeculatability() {}

namespace {
/// Remove dead transfer write from the SSA chain so that it can be eliminated
/// by DCE
/// ```
///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// into:
///
/// ```
///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
/// any other uses.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {};

/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
/// overwritten and inserted into another tensor. After this rewrite, the
/// operations bufferize in-place since all of them work on the same slice.
///
/// For example:
/// ```mlir
///   %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
///        : vector<8x16xf32>, tensor<8x16xf32>
///   %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
///        : tensor<8x16xf32> to tensor<?x?xf32>
///   %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
///        : tensor<?x?xf32> into tensor<27x37xf32>
/// ```
/// folds to
/// ```mlir
///   %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
///        : tensor<27x37xf32> to tensor<?x?xf32>
///   %1 = vector.transfer_write %vec, %0[%c0, %c0]
///        : vector<8x16xf32>, tensor<?x?xf32>
///   %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
///        : tensor<?x?xf32> into tensor<27x37xf32>
/// ```
struct SwapExtractSliceOfTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {};

} // namespace

void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 VectorType vecTy,
                                                 MemRefType memRefTy) {}

LogicalResult vector::LoadOp::verify() {}

OpFoldResult LoadOp::fold(FoldAdaptor) {}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

LogicalResult vector::StoreOp::verify() {}

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {}

//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//

LogicalResult MaskedLoadOp::verify() {}

namespace {
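// Folds a vector.maskedload with an all-true mask into a plain vector.load
// and one with an all-false mask into its pass-through value; a summary of
// the pattern's intent, based on the mask classification helper above.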
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {};
} // namespace

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {}

OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {}

//===----------------------------------------------------------------------===//
// MaskedStoreOp
//===----------------------------------------------------------------------===//

LogicalResult MaskedStoreOp::verify() {}

namespace {
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {};
} // namespace

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

LogicalResult GatherOp::verify() {}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type GatherOp::getExpectedMaskType() {}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {}

namespace {
class GatherFolder final : public OpRewritePattern<GatherOp> {};
} // namespace

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

LogicalResult ScatterOp::verify() {}

namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {};
} // namespace

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// ExpandLoadOp
//===----------------------------------------------------------------------===//

LogicalResult ExpandLoadOp::verify() {}

namespace {
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {};
} // namespace

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// CompressStoreOp
//===----------------------------------------------------------------------===//

LogicalResult CompressStoreOp::verify() {}

namespace {
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {};
} // namespace

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//

/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
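/// E.g. (illustrative) isValidShapeCast({8, 4}, {2, 4, 4}) is true because
/// 8 == 2 * 4 and 4 == 4, while isValidShapeCast({8, 4}, {4, 4, 2}) is false.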
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {}

static LogicalResult verifyVectorShapeCast(Operation *op,
                                           VectorType sourceVectorType,
                                           VectorType resultVectorType) {}

LogicalResult ShapeCastOp::verify() {}

OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {}

namespace {
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {};

/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
///   vector<4x1x1xi1> --> vector<4xi1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {}

/// Folds qualifying shape_cast(create_mask) into a new create_mask
///
/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
/// dimension. If the input vector comes from `vector.create_mask` for which
/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
/// to fold shape_cast into create_mask.
///
/// BEFORE:
///    %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
///    %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
/// AFTER:
///    %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {};

/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
/// This only applies when the shape of the broadcast source
/// 1. is a suffix of the shape of the result (i.e. when broadcast without
///    reshape is expressive enough to capture the result in a single op), or
/// 2. has the same element count as the shape cast result.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {};

} // namespace

void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// VectorBitCastOp
//===----------------------------------------------------------------------===//

LogicalResult BitCastOp::verify() {}

OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//

static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {}

/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {}

LogicalResult TypeCastOp::verify() {}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {}

OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {}

LogicalResult vector::TransposeOp::verify() {}

std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {}

namespace {

// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
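// E.g. (illustrative):
//   %0 = vector.transpose %v, [1, 2, 0]
//          : vector<2x3x4xf32> to vector<3x4x2xf32>
//   %1 = vector.transpose %0, [1, 2, 0]
//          : vector<3x4x2xf32> to vector<4x2x3xf32>
// becomes:
//   %1 = vector.transpose %v, [2, 0, 1]
//          : vector<2x3x4xf32> to vector<4x2x3xf32>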
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {};

// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
struct FoldTransposedScalarBroadcast final
    : public OpRewritePattern<vector::TransposeOp> {};

// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {};

/// Folds transpose(create_mask) into a new transposed create_mask.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {};

} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//

void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {}

LogicalResult ConstantMaskOp::verify() {}

bool ConstantMaskOp::isAllOnesMask() {}

//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//

void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {}

LogicalResult CreateMaskOp::verify() {}

namespace {

/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
///
/// Ex 1:
///   %c2 = arith.constant 2 : index
///   %c3 = arith.constant 3 : index
///   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
/// Becomes:
///    vector.constant_mask [3, 2] : vector<4x3xi1>
///
/// Ex 2:
///   %c_neg_1 = arith.constant -1 : index
///   %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
/// becomes:
///   vector.constant_mask [0] : vector<[8]xi1>
///
/// Ex 3:
///   %c8 = arith.constant 8 : index
///   %c16 = arith.constant 16 : index
///   %0 = vector.vscale
///   %1 = arith.muli %0, %c16 : index
///   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
/// becomes:
///   %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {};

} // namespace

void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// MaskOp
//===----------------------------------------------------------------------===//

void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {}

ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {}

void mlir::vector::MaskOp::print(OpAsmPrinter &p) {}

void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {}

LogicalResult MaskOp::verify() {}

/// Folds vector.mask ops with an all-true mask.
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {}

// Elides empty vector.mask operations with or without return values.
// Propagates the values yielded by the vector.yield terminator, if any, or
// erases the op otherwise.
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {};

void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {}

// MaskingOpInterface definitions.

/// Returns the operation masked by this 'vector.mask'.
Operation *MaskOp::getMaskableOp() {}

/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() {}

//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

LogicalResult ScanOp::verify() {}

void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {}

//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// StepOp
//===----------------------------------------------------------------------===//

OpFoldResult StepOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// WarpExecuteOnLane0Op
//===----------------------------------------------------------------------===//

void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {}

ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {}

void WarpExecuteOnLane0Op::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize, ValueRange args,
                                 TypeRange blockArgTypes) {}

/// Helper to check whether the distributed vector type is consistent with the
/// expanded type and the distributed size.
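/// E.g. (illustrative) for warpSize = 32, an expanded vector<64xf32> is
/// consistent with a distributed vector<2xf32>, since 64 == 2 * 32.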
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {}

LogicalResult WarpExecuteOnLane0Op::verify() {}

bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {}

Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {}

//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//

/// Create the vector.yield-terminated region of a vector.mask op with
/// `maskableOp` as the masked operation.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {}

/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid. Otherwise, returns
/// the maskable operation itself.
Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {}

/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. This utility is
/// used to propagate the pass-thru value of vector.mask or for cases where
/// only the pass-thru value propagation is needed. VP intrinsics do not
/// support pass-thru values and every masked-out lane is set to poison. LLVM
/// backends are usually able to match op + select patterns and fold them into
/// native target instructions.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"