#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
#include <optional>
usingnamespacemlir;
usingnamespacemlir::detail;
#define DEBUG_TYPE …
template <typename... Args>
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { … }
template <typename... Args>
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { … }
static OpBuilder::InsertPoint computeInsertPoint(Value value) { … }
namespace {
struct ConversionValueMapping { … };
}
Value ConversionValueMapping::lookupOrDefault(Value from,
Type desiredType) const { … }
Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const { … }
bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) { … }
namespace {
struct RewriterState { … };
class IRRewrite { … };
class BlockRewrite : public IRRewrite { … };
class CreateBlockRewrite : public BlockRewrite { … };
class EraseBlockRewrite : public BlockRewrite { … };
class InlineBlockRewrite : public BlockRewrite { … };
class MoveBlockRewrite : public BlockRewrite { … };
class BlockTypeConversionRewrite : public BlockRewrite { … };
class ReplaceBlockArgRewrite : public BlockRewrite { … };
class OperationRewrite : public IRRewrite { … };
class MoveOperationRewrite : public OperationRewrite { … };
class ModifyOperationRewrite : public OperationRewrite { … };
class ReplaceOperationRewrite : public OperationRewrite { … };
class CreateOperationRewrite : public OperationRewrite { … };
enum MaterializationKind { … };
class UnresolvedMaterializationRewrite : public OperationRewrite { … };
}
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) { … }
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener { … };
}
}
const ConversionConfig &IRRewrite::getConfig() const { … }
void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { … }
void BlockTypeConversionRewrite::rollback() { … }
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { … }
void ReplaceBlockArgRewrite::rollback() { … }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { … }
void ReplaceOperationRewrite::rollback() { … }
void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) { … }
void CreateOperationRewrite::rollback() { … }
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
const TypeConverter *converter, MaterializationKind kind)
: … { … }
void UnresolvedMaterializationRewrite::rollback() { … }
void ConversionPatternRewriterImpl::applyRewrites() { … }
RewriterState ConversionPatternRewriterImpl::getCurrentState() { … }
void ConversionPatternRewriterImpl::resetState(RewriterState state) { … }
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { … }
LogicalResult ConversionPatternRewriterImpl::remapValues(
StringRef valueDiagTag, std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
SmallVectorImpl<Value> &remapped) { … }
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { … }
bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { … }
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) { … }
Block *ConversionPatternRewriterImpl::applySignatureConversion(
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) { … }
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType, const TypeConverter *converter) { … }
void ConversionPatternRewriterImpl::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) { … }
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) { … }
void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { … }
void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) { … }
void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
Block *block, Block *srcBlock, Block::iterator before) { … }
void ConversionPatternRewriterImpl::notifyMatchFailure(
Location loc, function_ref<void(Diagnostic &)> reasonCallback) { … }
ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: … { … }
// Destructor, explicitly defaulted out of line in this file.
// NOTE(review): presumably defined here rather than in the header so that any
// member owning a detail::ConversionPatternRewriterImpl (a type that is only
// complete in this translation unit) is destroyed where that type is complete
// — confirm against the header before moving or inlining it.
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { … }
void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { … }
void ConversionPatternRewriter::eraseOp(Operation *op) { … }
void ConversionPatternRewriter::eraseBlock(Block *block) { … }
Block *ConversionPatternRewriter::applySignatureConversion(
Block *block, TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) { … }
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) { … }
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) { … }
Value ConversionPatternRewriter::getRemappedValue(Value key) { … }
LogicalResult
ConversionPatternRewriter::getRemappedValues(ValueRange keys,
SmallVectorImpl<Value> &results) { … }
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
Block::iterator before,
ValueRange argValues) { … }
void ConversionPatternRewriter::startOpModification(Operation *op) { … }
void ConversionPatternRewriter::finalizeOpModification(Operation *op) { … }
void ConversionPatternRewriter::cancelOpModification(Operation *op) { … }
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { … }
LogicalResult
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const { … }
namespace {
LegalizationPatterns;
class OperationLegalizer { … };
}
OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config)
: … { … }
bool OperationLegalizer::isIllegal(Operation *op) const { … }
LogicalResult
OperationLegalizer::legalize(Operation *op,
ConversionPatternRewriter &rewriter) { … }
LogicalResult
OperationLegalizer::legalizeWithFold(Operation *op,
ConversionPatternRewriter &rewriter) { … }
LogicalResult
OperationLegalizer::legalizeWithPattern(Operation *op,
ConversionPatternRewriter &rewriter) { … }
bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
ConversionPatternRewriter &rewriter) { … }
LogicalResult
OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
ConversionPatternRewriter &rewriter,
RewriterState &curState) { … }
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
Operation *op, ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &impl, RewriterState &state,
RewriterState &newState) { … }
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) { … }
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) { … }
void OperationLegalizer::buildLegalizationGraph(
LegalizationPatterns &anyOpLegalizerPatterns,
DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { … }
void OperationLegalizer::computeLegalizationGraphBenefit(
LegalizationPatterns &anyOpLegalizerPatterns,
DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { … }
unsigned OperationLegalizer::computeOpLegalizationDepth(
OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { … }
unsigned OperationLegalizer::applyCostModelToPatterns(
LegalizationPatterns &patterns,
DenseMap<OperationName, unsigned> &minOpPatternDepth,
DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { … }
namespace {
enum OpConversionMode { … };
}
namespace mlir {
struct OperationConverter { … };
}
LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
Operation *op) { … }
static LogicalResult
legalizeUnresolvedMaterialization(RewriterBase &rewriter,
UnresolvedMaterializationRewrite *rewrite) { … }
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { … }
static Operation *findLiveUserOfReplaced(
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping) { … }
static std::pair<ValueRange, const TypeConverter *>
getReplacedValues(IRRewrite *rewrite) { … }
void OperationConverter::finalize(ConversionPatternRewriter &rewriter) { … }
void mlir::reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) { … }
void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
ArrayRef<Type> types) { … }
void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) { … }
void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
unsigned newInputNo,
unsigned newInputCount) { … }
void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
Value replacementValue) { … }
LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const { … }
Type TypeConverter::convertType(Type t) const { … }
LogicalResult
TypeConverter::convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const { … }
bool TypeConverter::isLegal(Type type) const { … }
bool TypeConverter::isLegal(Operation *op) const { … }
bool TypeConverter::isLegal(Region *region) const { … }
bool TypeConverter::isSignatureLegal(FunctionType ty) const { … }
LogicalResult
TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
SignatureConversion &result) const { … }
LogicalResult
TypeConverter::convertSignatureArgs(TypeRange types,
SignatureConversion &result,
unsigned origInputOffset) const { … }
Value TypeConverter::materializeConversion(
ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
Location loc, Type resultType, ValueRange inputs) const { … }
std::optional<TypeConverter::SignatureConversion>
TypeConverter::convertBlockSignature(Block *block) const { … }
TypeConverter::AttributeConversionResult
TypeConverter::AttributeConversionResult::result(Attribute attr) { … }
TypeConverter::AttributeConversionResult
TypeConverter::AttributeConversionResult::na() { … }
TypeConverter::AttributeConversionResult
TypeConverter::AttributeConversionResult::abort() { … }
bool TypeConverter::AttributeConversionResult::hasResult() const { … }
bool TypeConverter::AttributeConversionResult::isNa() const { … }
bool TypeConverter::AttributeConversionResult::isAbort() const { … }
Attribute TypeConverter::AttributeConversionResult::getResult() const { … }
std::optional<Attribute>
TypeConverter::convertTypeAttribute(Type type, Attribute attr) const { … }
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
const TypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) { … }
namespace {
struct FunctionOpInterfaceSignatureConversion : public ConversionPattern { … };
struct AnyFunctionOpInterfaceSignatureConversion
: public OpInterfaceConversionPattern<FunctionOpInterface> { … };
}
FailureOr<Operation *>
mlir::convertOpResultTypes(Operation *op, ValueRange operands,
const TypeConverter &converter,
ConversionPatternRewriter &rewriter) { … }
void mlir::populateFunctionOpInterfaceTypeConversionPattern(
StringRef functionLikeOpName, RewritePatternSet &patterns,
const TypeConverter &converter) { … }
void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
RewritePatternSet &patterns, const TypeConverter &converter) { … }
void ConversionTarget::setOpAction(OperationName op,
LegalizationAction action) { … }
void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
LegalizationAction action) { … }
auto ConversionTarget::getOpAction(OperationName op) const
-> std::optional<LegalizationAction> { … }
auto ConversionTarget::isLegal(Operation *op) const
-> std::optional<LegalOpDetails> { … }
bool ConversionTarget::isIllegal(Operation *op) const { … }
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
ConversionTarget::DynamicLegalityCallbackFn oldCallback,
ConversionTarget::DynamicLegalityCallbackFn newCallback) { … }
void ConversionTarget::setLegalityCallback(
OperationName name, const DynamicLegalityCallbackFn &callback) { … }
void ConversionTarget::markOpRecursivelyLegal(
OperationName name, const DynamicLegalityCallbackFn &callback) { … }
void ConversionTarget::setLegalityCallback(
ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) { … }
void ConversionTarget::setLegalityCallback(
const DynamicLegalityCallbackFn &callback) { … }
auto ConversionTarget::getOpInfo(OperationName op) const
-> std::optional<LegalizationInfo> { … }
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) { … }
void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) { … }
static FailureOr<SmallVector<Value>>
pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) { … }
void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { … }
#endif
LogicalResult mlir::applyPartialConversion(
ArrayRef<Operation *> ops, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns, ConversionConfig config) { … }
LogicalResult
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config) { … }
LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config) { … }
LogicalResult mlir::applyFullConversion(Operation *op,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config) { … }
LogicalResult mlir::applyAnalysisConversion(
ArrayRef<Operation *> ops, ConversionTarget &target,
const FrozenRewritePatternSet &patterns, ConversionConfig config) { … }
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config) { … }