llvm/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

//===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Vector dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
#define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"

// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/Vector/IR/VectorEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc"

namespace mlir {
class MLIRContext;
class RewritePatternSet;

namespace arith {
enum class AtomicRMWKind : uint64_t;
} // namespace arith

namespace vector {
class ContractionOp;
class TransferReadOp;
class TransferWriteOp;
class VectorDialect;

namespace detail {
struct BitmaskEnumStorage;
} // namespace detail

/// Predefined constant_mask kinds.
enum class ConstantMaskKind { AllFalse = 0, AllTrue };

/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void buildTerminatedBody(OpBuilder &builder, Location loc);

/// Return whether `srcType` can be broadcast to `dstVectorType` under the
/// semantics of the `vector.broadcast` op.
enum class BroadcastableToResult {
  Success = 0,
  SourceRankHigher = 1,
  DimensionMismatch = 2,
  SourceTypeNotAVector = 3
};

struct VectorDim {
  int64_t dim;
  bool isScalable;
};
BroadcastableToResult
isBroadcastableTo(Type srcType, VectorType dstVectorType,
                  std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr);
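
// Illustrative examples of `isBroadcastableTo` (a sketch of the
// `vector.broadcast` rules, not an exhaustive list):
//
//   f32           -> vector<4x8xf32> : Success (scalar splat)
//   vector<8xf32> -> vector<4x8xf32> : Success (leading dim prepended)
//   vector<4xf32> -> vector<4x8xf32> : DimensionMismatch (trailing dims must
//                                      match or be 1)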

/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit = 1);

/// Collect a set of patterns that fold arithmetic extension on floating point
/// into vector contract for the backends with native support.
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);

/// Collect a set of patterns that fold elementwise ops on vectors into the
/// vector dialect.
void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns);

/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);

/// Returns an integer array attribute containing the given values using
/// the integer type required for subscripts in the vector dialect.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);

/// Returns the value obtained by reducing the vector into a scalar using the
/// operation kind associated with a binary AtomicRMWKind op.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder,
                           Location loc, Value vector);
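
// For example (a sketch; `vec` is an illustrative value), reducing a
// vector<16xf32> with the `addf` kind:
//
//   Value sum = getVectorReductionOp(arith::AtomicRMWKind::addf, builder,
//                                    loc, vec);
//   // Emits: %sum = vector.reduction <add>, %vec : vector<16xf32> into f32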

/// Build the default minor identity map suitable for a vector transfer. This
/// also handles the case memref<... x vector<...>> -> vector<...> in which the
/// rank of the identity map must take the vector element type into account.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
                                      VectorType vectorType);
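
// For example, for shapedType = memref<10x20x30xf32> and
// vectorType = vector<20x30xf32>, the minor identity map is
//   affine_map<(d0, d1, d2) -> (d1, d2)>
// i.e. the vector spans the minor (innermost) dimensions of the memref.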

/// Return true if the transfer_write fully writes the data accessed by the
/// transfer_read.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read);

/// Return true if the `write` op fully overwrites the data written by the
/// `priorWrite` transfer_write op.
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite);
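
// Illustration of both predicates: a transfer_write of vector<4xf32> at
// indices [%c0] followed by a transfer_read of vector<4xf32> at the same
// indices is a same-value RAW pair (the read observes exactly the written
// data); a second transfer_write of vector<4xf32> at [%c0] fully overwrites
// the first, satisfying the WAW check.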

/// Return true if we can prove that the transfer operations access disjoint
/// memory, without requiring the accessed tensor/memref to be the same.
///
/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
/// via ValueBoundsOpInterface.
bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
                               VectorTransferOpInterface transferB,
                               bool testDynamicValueUsingBounds = false);

/// Return true if we can prove that the transfer operations access disjoint
/// memory, requiring the operations to access the same tensor/memref.
///
/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
/// via ValueBoundsOpInterface.
bool isDisjointTransferSet(VectorTransferOpInterface transferA,
                           VectorTransferOpInterface transferB,
                           bool testDynamicValueUsingBounds = false);
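
// For example (illustrative), two transfer_reads of vector<4xf32> from the
// same memref at constant indices [%c0] and [%c4] access the element ranges
// [0, 4) and [4, 8), which are provably disjoint.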

/// Returns the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
                         Value v1, Value acc,
                         arith::FastMathFlagsAttr fastmath = nullptr,
                         Value mask = nullptr);
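
// Sketch (illustrative semantics): with CombiningKind::ADD on f32 operands
// this builds `arith.addf %v1, %acc` (carrying `fastmath` when provided);
// when `mask` is non-null, the combined value is selected against `acc` so
// that masked-off lanes keep the accumulator value.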

/// Returns true if `attr` has "parallel" iterator type semantics.
inline bool isParallelIterator(Attribute attr) {
  auto itrType = llvm::dyn_cast<IteratorTypeAttr>(attr);
  return itrType && itrType.getValue() == IteratorType::parallel;
}

/// Returns true if `attr` has "reduction" iterator type semantics.
inline bool isReductionIterator(Attribute attr) {
  auto itrType = llvm::dyn_cast<IteratorTypeAttr>(attr);
  return itrType && itrType.getValue() == IteratorType::reduction;
}
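
// A typical use (illustrative) is classifying the iterators of a
// `vector.contract` op:
//
//   for (Attribute attr : contractOp.getIteratorTypes())
//     if (isReductionIterator(attr))
//       ...; // `attr` is a reduction dimension.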

/// Returns the integer numbers in `values`. `values` are expected to be
/// defined by constant operations.
SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values);

/// Returns the integer numbers in `foldResults`. `foldResults` are expected
/// to be constant integer attributes or values defined by constant operations.
SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults);

/// Convert `foldResults` into Values. Integer attributes are converted to
/// constant ops.
SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
                               ArrayRef<OpFoldResult> foldResults);

/// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
/// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
/// `std::nullopt`.
std::optional<int64_t> getConstantVscaleMultiplier(Value value);

//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//

/// Infers the mask type for a transfer op given its vector type and
/// permutation map. The mask in a transfer op applies to the tensor/buffer
/// side of the transfer, so its type must match the vector shape *before*
/// any permutation or broadcasting. For example,
///
/// vecType = vector<1x2x3xf32>, permMap = affine_map<(d0, d1, d2) -> (d1, d0)>
///
/// Has inferred mask type:
///
/// maskType = vector<2x1xi1>
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap);

/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as the masked operation.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);

/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid. Otherwise, returns the
/// maskable operation itself.
Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask,
                         Value passthru = Value());
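
// Sketch of the resulting IR when masking a maskable op such as
// `vector.transfer_read` (types are illustrative):
//
//   %0 = vector.mask %mask {
//     vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<16xf32>
//   } : vector<16xi1> -> vector<16xf32>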

/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. This utility is used
/// to propagate the pass-thru value for masked-out or speculatively executed
/// lanes. VP intrinsics do not support pass-thru values, and every masked-out
/// lane is set to poison. LLVM backends are usually able to match op + select
/// patterns and fold them into native target instructions.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue,
                     Value passthru);
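
// Sketch of the emitted select (1-D illustrative types):
//
//   %0 = arith.select %mask, %newValue, %passthru
//          : vector<16xi1>, vector<16xf32>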

} // namespace vector
} // namespace mlir

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorDialect.h.inc"
#include "mlir/Dialect/Vector/IR/VectorOps.h.inc"

#endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H