#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
usingnamespacemlir;
usingnamespacemlir::linalg;
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
int64_t dim) { … }
static Value getSlice(OpBuilder &b, Location loc, Value source,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) { … }
Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
int64_t dim) { … }
OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
int64_t dim) { … }
RegionBuilderFn;
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<NamedAttribute> attrs,
RegionBuilderFn regionBuilder) { … }
static void buildStructuredOp(OpBuilder &b, OperationState &state,
std::optional<TypeRange> resultTensorTypes,
ValueRange inputs, ValueRange outputs,
ArrayRef<NamedAttribute> attributes,
RegionBuilderFn regionBuilder) { … }
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
SmallVectorImpl<Type> &outputTypes,
bool addOperandSegmentSizes = true) { … }
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
ValueRange outputs) { … }
static ParseResult parseNamedStructuredOpRegion(
OpAsmParser &parser, Region ®ion, unsigned numRegionArgs,
TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
RegionBuilderFn regionBuilder) { … }
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes) { … }
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result,
unsigned numRegionArgs,
RegionBuilderFn regionBuilder) { … }
static void printNamedStructuredOpResults(OpAsmPrinter &p,
TypeRange resultTypes) { … }
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
ValueRange inputs, ValueRange outputs) { … }
namespace {
class RegionBuilderHelper { … };
}
namespace {
struct EraseSelfCopy : OpRewritePattern<CopyOp> { … };
}
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
namespace {
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> { … };
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> { … };
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> { … };
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> { … };
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
tensor::PackOp packOp) { … }
struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> { … };
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> { … };
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> { … };
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> { … };
}
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
static void buildGenericRegion(
OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs,
ValueRange outputs,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { … }
void GenericOp::getAsmBlockArgumentNames(Region ®ion,
OpAsmSetValueNameFn setNameFn) { … }
void GenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) { … }
void GenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) { … }
void GenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) { … }
void GenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<utils::IteratorType> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) { … }
void GenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<utils::IteratorType> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) { … }
void GenericOp::print(OpAsmPrinter &p) { … }
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) { … }
static void getGenericEffectsImpl(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects,
LinalgOp linalgOp) { … }
void GenericOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) { … }
LogicalResult GenericOp::verify() { … }
namespace {
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> { … };
}
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) { … }
static ParseResult parseDstStyleOp(
OpAsmParser &parser, OperationState &result,
function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
nullptr) { … }
void MapOp::getAsmBlockArgumentNames(Region ®ion,
OpAsmSetValueNameFn setNameFn) { … }
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { … }
void MapOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) { … }
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
const OperationName &payloadOpName,
const NamedAttrList &payloadOpAttrs,
ArrayRef<Value> operands,
bool initFirst = false) { … }
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { … }
static Operation *findPayloadOp(Block *body, bool initFirst = false) { … }
void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { … }
void MapOp::print(OpAsmPrinter &p) { … }
LogicalResult MapOp::verify() { … }
SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() { … }
ArrayAttr MapOp::getIndexingMaps() { … }
void MapOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) { … }
void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
OpAsmSetValueNameFn setNameFn) { … }
void ReduceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
void ReduceOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
ValueRange inits, ArrayRef<int64_t> dimensions,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) { … }
SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() { … }
ArrayAttr ReduceOp::getIndexingMaps() { … }
void ReduceOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) { … }
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
NamedAttrList &attributes,
StringRef attributeName) { … }
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { … }
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
ArrayRef<int64_t> attributeValue) { … }
void ReduceOp::print(OpAsmPrinter &p) { … }
LogicalResult ReduceOp::verify() { … }
static void buildIdentityRegion(OpBuilder &builder, Location loc,
Region ®ion, ValueRange inputs,
ValueRange outputs) { … }
void TransposeOp::build(::mlir::OpBuilder &builder,
::mlir::OperationState &result, Value input, Value init,
DenseI64ArrayAttr permutation,
ArrayRef<NamedAttribute> attributes) { … }
void TransposeOp::build(::mlir::OpBuilder &builder,
::mlir::OperationState &result, Value input, Value init,
ArrayRef<int64_t> permutation,
ArrayRef<NamedAttribute> attributes) { … }
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { … }
void TransposeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
void TransposeOp::print(OpAsmPrinter &p) { … }
LogicalResult TransposeOp::verify() { … }
SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() { … }
ArrayAttr TransposeOp::getIndexingMaps() { … }
void TransposeOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) { … }
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &result) { … }
struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> { … };
struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> { … };
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
void BroadcastOp::build(::mlir::OpBuilder &builder,
::mlir::OperationState &result, Value input, Value init,
DenseI64ArrayAttr dimensions,
ArrayRef<NamedAttribute> attributes) { … }
void BroadcastOp::build(::mlir::OpBuilder &builder,
::mlir::OperationState &result, Value input, Value init,
ArrayRef<int64_t> dimensions,
ArrayRef<NamedAttribute> attributes) { … }
ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) { … }
void BroadcastOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
void BroadcastOp::print(OpAsmPrinter &p) { … }
LogicalResult BroadcastOp::verify() { … }
SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() { … }
ArrayAttr BroadcastOp::getIndexingMaps() { … }
void BroadcastOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) { … }
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) { … }
void linalg::YieldOp::print(OpAsmPrinter &p) { … }
ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) { … }
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) { … }
LogicalResult linalg::YieldOp::verify() { … }
LogicalResult IndexOp::verify() { … }
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
unsigned rank,
MLIRContext *context) { … }
SmallVector<AffineExpr, 4>
mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
MLIRContext *context) { … }
SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
ArrayRef<AffineExpr> b) { … }
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) { … }
std::string mlir::linalg::generateLibraryCallName(Operation *op) { … }
namespace {
struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> { … };
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> { … };
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) { … }
static void createNewOperandWithStaticSizes(
Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
bool &changeNeeded) { … }
struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> { … };
}
LogicalResult SoftmaxOp::verify() { … }
SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) { … }
SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() { … }
FailureOr<TilingResult>
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) { … }
LogicalResult SoftmaxOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) { … }
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) { … }
LogicalResult
SoftmaxOp::reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) { … }
void SoftmaxOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) { … }
static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
int64_t dim, bool allParallel = false) { … }
template <typename T>
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
int64_t dim) { … }
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
Value max, Value output, int64_t dim) { … }
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
Value denominator, Value output, int64_t dim) { … }
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) { … }
LogicalResult WinogradFilterTransformOp::verify() { … }
SmallVector<Range>
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) { … }
SmallVector<utils::IteratorType>
WinogradFilterTransformOp::getLoopIteratorTypes() { … }
LogicalResult WinogradFilterTransformOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) { … }
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) { … }
LogicalResult WinogradInputTransformOp::verify() { … }
SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) { … }
SmallVector<utils::IteratorType>
WinogradInputTransformOp::getLoopIteratorTypes() { … }
LogicalResult WinogradInputTransformOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) { … }
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) { … }
LogicalResult WinogradOutputTransformOp::verify() { … }
SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) { … }
SmallVector<utils::IteratorType>
WinogradOutputTransformOp::getLoopIteratorTypes() { … }
LogicalResult WinogradOutputTransformOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) { … }
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) { … }
void LinalgDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const { … }
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) { … }