#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>
usingnamespacemlir;
usingnamespacemlir::linalg;
usingnamespacemlir::transform;
#define DEBUG_TYPE …
#define DBGS() …
#define DBGSNL() …
#define LDBG(X) …
template <typename PatternTy, typename... Args>
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { … }
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
transform::TransformState &state, TransformOpInterface transformOp,
SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) { … }
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
transform::TransformState &state, TransformOpInterface transformOp,
SmallVector<OpFoldResult> &result, Value packedHandle) { … }
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
TransformState &state, TransformOpInterface &transformOp,
ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) { … }
void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
RewritePatternSet &patterns) { … }
void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
RewritePatternSet &patterns) { … }
void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
RewritePatternSet &patterns) { … }
void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
RewritePatternSet &patterns) { … }
void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
RewritePatternSet &patterns) { … }
void transform::BufferizeToAllocationOp::build(OpBuilder &b,
OperationState &result,
Value target,
Attribute memorySpace) { … }
void transform::BufferizeToAllocationOp::build(OpBuilder &b,
OperationState &result,
Value target,
int64_t memorySpace) { … }
namespace {
class NewOpsListener : public RewriterBase::ForwardingListener { … };
}
DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) { … }
void transform::BufferizeToAllocationOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
LogicalResult transform::BufferizeToAllocationOp::verify() { … }
DiagnosedSilenceableFailure
transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
DiagnosedSilenceableFailure
transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
transform::TransformRewriter &rewriter, TransformResults &transformResults,
TransformState &state) { … }
template <typename Range>
static LogicalResult applyTilingToAll(
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
unsigned numLoops, transform::TransformResults &transformResults,
function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
applyFn) { … }
DiagnosedSilenceableFailure
transform::FuseOp::apply(transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) { … }
LogicalResult transform::FuseOp::verify() { … }
void transform::FuseIntoContainingOp::build(OpBuilder &builder,
OperationState &result,
Value producerOp,
Value containingOp) { … }
static Operation *replaceForAllWithNewSignature(
RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
Operation *containingOp, TilingResult &tileAndFuseResult,
int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes) { … }
static std::tuple<SmallVector<Operation *>, Operation *>
tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
Operation *producerOp, Operation *containingOp) { … }
static SmallVector<Operation *>
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
Operation *containingOp) { … }
static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
Operation *producerOp,
Operation *containingOp) { … }
bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() { … }
DiagnosedSilenceableFailure
transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) { … }
void transform::FuseIntoContainingOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
DiagnosedSilenceableFailure
transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure
transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure
transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
GenericOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
LogicalResult transform::InterchangeOp::verify() { … }
DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
transform::TransformRewriter &rewriter, tensor::PackOp target,
transform::ApplyToEachResultList &transformResults,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
transform::TransformRewriter &rewriter, tensor::UnPackOp target,
transform::ApplyToEachResultList &transformResults,
transform::TransformState &state) { … }
void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
Value target, ArrayRef<StringRef> opNames) { … }
void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, Value target,
ArrayRef<StringRef> opNames) { … }
DiagnosedSilenceableFailure
transform::MatchOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) { … }
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
Type targetType, Type lowSizeType, Type,
Type) { … }
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
Type &targetType, Type &lowSizeType,
Type &highSizeType,
Type &splitPointType) { … }
DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
transform::TransformRewriter &rewriter, LinalgOp target,
transform::ApplyToEachResultList &results, TransformState &state) { … }
void transform::MultiTileSizesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
LogicalResult transform::MultiTileSizesOp::verify() { … }
void transform::PackOp::build(OpBuilder &builder, OperationState &result,
Value target,
ArrayRef<OpFoldResult> mixedPackedSizes) { … }
SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() { … }
DiagnosedSilenceableFailure
transform::PackOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &transformResults,
transform::TransformState &state) { … }
void transform::PackOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
LogicalResult transform::PackGreedilyOp::verify() { … }
DiagnosedSilenceableFailure
PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &transformResults,
transform::TransformState &state) { … }
SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() { … }
void transform::PackGreedilyOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
LogicalResult transform::PackTransposeOp::verify() { … }
namespace {
enum class OuterOrInnerPerm { … };
}
template <typename RelayoutOpTy>
bool isValidPackingPermutation(
RelayoutOpTy op, ArrayRef<int64_t> permutation,
OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) { … }
DiagnosedSilenceableFailure
transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &transformResults,
transform::TransformState &state) { … }
void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
ArrayRef<int64_t> paddingDimensions,
ArrayRef<int64_t> padToMultipleOf,
ArrayRef<int64_t> packPaddings,
ArrayRef<Attribute> transposePaddings,
StringRef copyBackOp) { … }
void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
ArrayRef<int64_t> paddingDimensions,
ArrayRef<OpFoldResult> mixedPadToMultipleOf,
ArrayRef<int64_t> packPaddings,
ArrayRef<Attribute> transposePaddings,
StringRef copyBackOp) { … }
void PadOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() { … }
DiagnosedSilenceableFailure
transform::PadOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) { … }
LogicalResult transform::PadOp::verify() { … }
DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &transformResults,
transform::TransformState &state) { … }
LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() { … }
void transform::HoistPadBuildPackingLoopNestOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
DiagnosedSilenceableFailure
transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
tensor::PadOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
LogicalResult transform::HoistPadOp::verify() { … }
DiagnosedSilenceableFailure
transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure
transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
TransformResults &transformResults,
TransformState &state) { … }
void transform::ReplaceOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
LogicalResult transform::ReplaceOp::verify() { … }
DiagnosedSilenceableFailure
transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure
transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure
transform::RewriteInDestinationPassingStyleOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure
SplitOp::apply(transform::TransformRewriter &rewriter,
TransformResults &results, TransformState &state) { … }
void SplitOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { … }
void SplitOp::print(OpAsmPrinter &printer) { … }
LogicalResult SplitOp::verify() { … }
void transform::SplitReductionOp::build(
OpBuilder &builder, OperationState &result, Value target,
int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
bool useScalingAlgorithm, bool useAlloc) { … }
DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
transform::TransformRewriter &rewriter, LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
void transform::TileReductionUsingForOp::build(
OpBuilder &builder, OperationState &result, Value target,
ArrayRef<int64_t> staticTileSizes) { … }
DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
transform::TransformRewriter &rewriter, LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
void transform::TileReductionUsingForallOp::build(
OpBuilder &builder, OperationState &result, Value target,
ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
ArrayAttr mapping) { … }
DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
transform::TransformRewriter &rewriter, LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure
transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
TransformResults &transformResults,
TransformState &state) { … }
LogicalResult transform::ContinuousTileSizesOp::verify() { … }
void transform::ContinuousTileSizesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
Type targetType, Type tile_sizes,
Type) { … }
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
Type &targetType,
Type &tileSizesType,
Type &chunkSizesType) { … }
void transform::TileUsingForOp::build(
OpBuilder &builder, OperationState &result, TypeRange loopTypes,
Value target, ArrayRef<int64_t> staticTileSizes,
ArrayRef<int64_t> interchange,
std::optional<ArrayRef<bool>> scalableSizes) { … }
void transform::TileUsingForOp::build(
OpBuilder &builder, OperationState &result, Value target,
ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
std::optional<ArrayRef<bool>> scalableSizes) { … }
void transform::TileUsingForOp::build(
OpBuilder &builder, OperationState &result, Value target,
ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
std::optional<ArrayRef<bool>> scalableSizes) { … }
void transform::TileUsingForOp::build(
OpBuilder &builder, OperationState &result, TypeRange loopTypes,
Value target, ArrayRef<OpFoldResult> mixedTileSizes,
ArrayRef<int64_t> interchange,
std::optional<ArrayRef<bool>> scalableSizes) { … }
LogicalResult transform::TileUsingForOp::verify() { … }
DiagnosedSilenceableFailure
transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
TransformResults &transformResults,
TransformState &state) { … }
SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() { … }
void transform::TileUsingForOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
void transform::TileUsingForallOp::build(OpBuilder &builder,
OperationState &result, Value target,
ArrayRef<int64_t> staticTileSizes,
transform::TileSizesSpec,
ArrayAttr mapping) { … }
void transform::TileUsingForallOp::build(OpBuilder &builder,
OperationState &result, Value target,
ArrayRef<OpFoldResult> mixedTileSizes,
transform::TileSizesSpec,
ArrayAttr mapping) { … }
void transform::TileUsingForallOp::build(OpBuilder &builder,
OperationState &result, Value target,
ArrayRef<int64_t> staticNumThreads,
transform::NumThreadsSpec,
ArrayAttr mapping) { … }
void transform::TileUsingForallOp::build(OpBuilder &builder,
OperationState &result, Value target,
ArrayRef<OpFoldResult> mixedNumThreads,
transform::NumThreadsSpec,
ArrayAttr mapping) { … }
static SmallVector<OpFoldResult>
normalizeUpperBounds(RewriterBase &rewriter, Location loc,
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> steps) { … }
static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
Location loc, ValueRange ivs,
ArrayRef<OpFoldResult> lbs,
ArrayRef<OpFoldResult> steps) { … }
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
scf::ForallOp loop) { … }
DiagnosedSilenceableFailure transform::tileToForallOpImpl(
RewriterBase &rewriter, transform::TransformState &state,
TransformOpInterface transformOp, Operation *target,
ArrayRef<OpFoldResult> mixedNumThreads,
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
scf::SCFTilingResult &tilingResult) { … }
DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &transformResults,
transform::TransformState &state) { … }
void transform::TileUsingForallOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() { … }
SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() { … }
LogicalResult TileUsingForallOp::verify() { … }
void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) { … }
namespace {
struct VectorizationPattern : public RewritePattern { … };
}
DiagnosedSilenceableFailure
transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure transform::VectorizeOp::apply(
transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) { … }
void transform::VectorizeOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() { … }
LogicalResult transform::VectorizeOp::verify() { … }
DiagnosedSilenceableFailure
transform::HoistRedundantVectorTransfersOp::applyToOne(
transform::TransformRewriter &rewriter, func::FuncOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure
transform::HoistRedundantVectorBroadcastsOp::applyToOne(
transform::TransformRewriter &rewriter, mlir::Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
template <typename OpTy>
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *targetOp,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"