#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE …
#define DEBUG_TYPE_FULL …
#define DEBUG_PRINT_AFTER_ALL …
#define DBGS() …
#define LDBG(X) …
#define FULL_LDBG(X) …
usingnamespacemlir;
static bool happensBefore(Operation *a, Operation *b) { … }
constexpr const Value transform::TransformState::kTopLevelValue;
transform::TransformState::TransformState(
Region *region, Operation *payloadRoot,
const RaggedArray<MappedValue> &extraMappings,
const TransformOptions &options)
: … { … }
Operation *transform::TransformState::getTopLevel() const { … }
ArrayRef<Operation *>
transform::TransformState::getPayloadOpsView(Value value) const { … }
ArrayRef<Attribute> transform::TransformState::getParams(Value value) const { … }
ArrayRef<Value>
transform::TransformState::getPayloadValuesView(Value handleValue) const { … }
LogicalResult transform::TransformState::getHandlesForPayloadOp(
Operation *op, SmallVectorImpl<Value> &handles,
bool includeOutOfScope) const { … }
LogicalResult transform::TransformState::getHandlesForPayloadValue(
Value payloadValue, SmallVectorImpl<Value> &handles,
bool includeOutOfScope) const { … }
static DiagnosedSilenceableFailure dispatchMappedValues(
Value handle, ArrayRef<transform::MappedValue> values,
function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
function_ref<LogicalResult(ValueRange)> valuesFn) { … }
LogicalResult
transform::TransformState::mapBlockArgument(BlockArgument argument,
ArrayRef<MappedValue> values) { … }
LogicalResult transform::TransformState::mapBlockArguments(
Block::BlockArgListType arguments,
ArrayRef<SmallVector<MappedValue>> mapping) { … }
LogicalResult
transform::TransformState::setPayloadOps(Value value,
ArrayRef<Operation *> targets) { … }
LogicalResult
transform::TransformState::setPayloadValues(Value handle,
ValueRange payloadValues) { … }
LogicalResult transform::TransformState::setParams(Value value,
ArrayRef<Param> params) { … }
template <typename Mapping, typename Key, typename Mapped>
void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) { … }
void transform::TransformState::forgetMapping(Value opHandle,
ValueRange origOpFlatResults,
bool allowOutOfScope) { … }
void transform::TransformState::forgetValueMapping(
Value valueHandle, ArrayRef<Operation *> payloadOperations) { … }
LogicalResult
transform::TransformState::replacePayloadOp(Operation *op,
Operation *replacement) { … }
LogicalResult
transform::TransformState::replacePayloadValue(Value value, Value replacement) { … }
void transform::TransformState::recordOpHandleInvalidationOne(
OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
Operation *payloadOp, Value otherHandle, Value throughValue,
transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { … }
void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
Value payloadValue, Value valueHandle,
transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { … }
void transform::TransformState::recordOpHandleInvalidation(
OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
Value throughValue,
transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { … }
void transform::TransformState::recordValueHandleInvalidation(
OpOperand &valueHandle,
transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { … }
LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
transform::TransformOpInterface transform,
transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { … }
LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
transform::TransformOpInterface transform) { … }
template <typename T>
DiagnosedSilenceableFailure
checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
transform::TransformOpInterface transform,
unsigned operandNumber) { … }
void transform::TransformState::compactOpHandles() { … }
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) { … }
LogicalResult transform::TransformState::updateStateFromResults(
const TransformResults &results, ResultRange opResults) { … }
transform::TransformState::Extension::~Extension() = default;
LogicalResult
transform::TransformState::Extension::replacePayloadOp(Operation *op,
Operation *replacement) { … }
LogicalResult
transform::TransformState::Extension::replacePayloadValue(Value value,
Value replacement) { … }
transform::TransformState::RegionScope::~RegionScope() { … }
transform::TransformResults::TransformResults(unsigned numSegments) { … }
void transform::TransformResults::setParams(
OpResult value, ArrayRef<transform::TransformState::Param> params) { … }
void transform::TransformResults::setMappedValues(
OpResult handle, ArrayRef<MappedValue> values) { … }
void transform::TransformResults::setRemainingToEmpty(
transform::TransformOpInterface transform) { … }
ArrayRef<Operation *>
transform::TransformResults::get(unsigned resultNumber) const { … }
ArrayRef<transform::TransformState::Param>
transform::TransformResults::getParams(unsigned resultNumber) const { … }
ArrayRef<Value>
transform::TransformResults::getValues(unsigned resultNumber) const { … }
bool transform::TransformResults::isParam(unsigned resultNumber) const { … }
bool transform::TransformResults::isValue(unsigned resultNumber) const { … }
bool transform::TransformResults::isSet(unsigned resultNumber) const { … }
transform::TrackingListener::TrackingListener(TransformState &state,
TransformOpInterface op,
TrackingListenerConfig config)
: … { … }
Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) { … }
DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
Operation *&result, Operation *op, ValueRange newValues) const { … }
void transform::TrackingListener::notifyMatchFailure(
Location loc, function_ref<void(Diagnostic &)> reasonCallback) { … }
void transform::TrackingListener::notifyOperationErased(Operation *op) { … }
void transform::TrackingListener::notifyOperationReplaced(
Operation *op, ValueRange newValues) { … }
transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() { … }
DiagnosedSilenceableFailure
transform::ErrorCheckingTrackingListener::checkAndResetError() { … }
bool transform::ErrorCheckingTrackingListener::failed() const { … }
void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) { … }
transform::TransformRewriter::TransformRewriter(
MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
: … { … }
bool transform::TransformRewriter::hasTrackingFailures() const { … }
void transform::TransformRewriter::silenceTrackingFailure() { … }
LogicalResult transform::TransformRewriter::notifyPayloadOperationReplaced(
Operation *op, Operation *replacement) { … }
LogicalResult
transform::detail::checkNestedConsumption(Location loc,
ArrayRef<Operation *> targets) { … }
LogicalResult
transform::detail::checkApplyToOne(Operation *transformOp,
Location payloadOpLoc,
const ApplyToEachResultList &partialResult) { … }
template <typename T>
static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) { … }
void transform::detail::setApplyToOneResults(
Operation *transformOp, TransformResults &transformResults,
ArrayRef<ApplyToEachResultList> results) { … }
LogicalResult transform::detail::appendValueMappings(
MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
ValueRange values, const transform::TransformState &state, bool flatten) { … }
void transform::detail::prepareValueMappings(
SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
ValueRange values, const transform::TransformState &state) { … }
void transform::detail::forwardTerminatorOperands(
Block *block, transform::TransformState &state,
transform::TransformResults &results) { … }
transform::TransformState
transform::detail::makeTransformStateForTesting(Region *region,
Operation *payloadRoot) { … }
static void
remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
OpOperand *target,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
static void
remapArgumentEffects(Block &block, MutableArrayRef<OpOperand> operands,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
void transform::detail::getPotentialTopLevelEffects(
Operation *operation, Value root, Block &body,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
TransformState &state, Operation *op, Region ®ion) { … }
LogicalResult
transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { … }
void transform::detail::getParamProducerTransformOpTraitEffects(
Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
LogicalResult
transform::detail::verifyParamProducerTransformOpTrait(Operation *op) { … }
void transform::consumesHandle(
MutableArrayRef<OpOperand> handles,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
template <typename EffectTy, typename ResourceTy, typename Range>
static bool hasEffect(Range &&effects) { … }
bool transform::isHandleConsumed(Value handle,
transform::TransformOpInterface transform) { … }
void transform::producesHandle(
ResultRange handles,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
void transform::producesHandle(
MutableArrayRef<BlockArgument> handles,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
void transform::onlyReadsHandle(
MutableArrayRef<OpOperand> handles,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
void transform::modifiesPayload(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
void transform::onlyReadsPayload(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
bool transform::doesModifyPayload(transform::TransformOpInterface transform) { … }
bool transform::doesReadPayload(transform::TransformOpInterface transform) { … }
void transform::getConsumedBlockArguments(
Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) { … }
SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
TransformOpInterface transformOp) { … }
LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) { … }
LogicalResult transform::applyTransforms(
Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping,
const TransformOptions &options, bool enforceToplevelTransformOp,
function_ref<void(TransformState &)> stateInitializer,
function_ref<LogicalResult(TransformState &)> stateExporter) { … }
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"