#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
#include <random>
#endif
usingnamespacemlir;
#define DEBUG_TYPE …
namespace {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
struct ExpensiveChecks : public RewriterBase::ForwardingListener {
ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
: RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
void computeFingerPrints(Operation *topLevel) {
this->topLevel = topLevel;
this->topLevelFingerPrint.emplace(topLevel);
topLevel->walk([&](Operation *op) {
fingerprints.try_emplace(op, op, false);
});
}
void clear() {
topLevel = nullptr;
topLevelFingerPrint.reset();
fingerprints.clear();
}
void notifyRewriteSuccess() {
if (!topLevel)
return;
if (failed(verify(topLevel)))
llvm::report_fatal_error("IR failed to verify after pattern application");
OperationFingerPrint afterFingerPrint(topLevel);
if (*topLevelFingerPrint == afterFingerPrint) {
llvm::report_fatal_error(
"pattern returned success but IR did not change");
}
for (const auto &it : fingerprints) {
if (it.first == topLevel)
continue;
if (it.second !=
OperationFingerPrint(it.first, false)) {
llvm::report_fatal_error("operation finger print changed");
}
}
}
void notifyRewriteFailure() {
if (!topLevel)
return;
OperationFingerPrint afterFingerPrint(topLevel);
if (*topLevelFingerPrint != afterFingerPrint) {
llvm::report_fatal_error("pattern returned failure but IR did change");
}
}
void notifyFoldingSuccess() {
if (!topLevel)
return;
if (failed(verify(topLevel)))
llvm::report_fatal_error("IR failed to verify after folding");
}
protected:
void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
void notifyBlockErased(Block *block) override {
RewriterBase::ForwardingListener::notifyBlockErased(block);
invalidateFingerPrint(block->getParentOp());
}
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
invalidateFingerPrint(op->getParentOp());
}
void notifyOperationModified(Operation *op) override {
RewriterBase::ForwardingListener::notifyOperationModified(op);
invalidateFingerPrint(op);
}
void notifyOperationErased(Operation *op) override {
RewriterBase::ForwardingListener::notifyOperationErased(op);
op->walk([this](Operation *op) { invalidateFingerPrint(op); });
}
DenseMap<Operation *, OperationFingerPrint> fingerprints;
Operation *topLevel = nullptr;
std::optional<OperationFingerPrint> topLevelFingerPrint;
};
#endif
#ifndef NDEBUG
/// Return the operation to dump when logging a change to `op`. This is the
/// parent op of `op` if it has one (presumably so that the surrounding IR,
/// e.g. constants materialized next to `op`, is visible), and `op` itself
/// otherwise.
static Operation *getDumpRootOp(Operation *op) {
  Operation *parentOp = op->getParentOp();
  return parentOp ? parentOp : op;
}
/// Debug-log a successful fold of `op`. This writes a header line to
/// llvm::dbgs(), then the IR of `op` via Operation::dump(), then two newlines
/// to llvm::dbgs().
static void logSuccessfulFolding(Operation *op) {
  auto &log = llvm::dbgs();
  log << "// *** IR Dump After Successful Folding ***\n";
  op->dump();
  log << "\n\n";
}
#endif
class Worklist { … };
Worklist::Worklist() { … }
void Worklist::clear() { … }
bool Worklist::empty() const { … }
void Worklist::push(Operation *op) { … }
Operation *Worklist::pop() { … }
void Worklist::remove(Operation *op) { … }
void Worklist::reverse() { … }
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
/// A worklist that pops operations at pseudo-random positions. It is only
/// compiled when MLIR_GREEDY_REWRITE_RANDOMIZER_SEED is defined. That value
/// seeds the generator, so a given seed reproduces the same pop order.
class RandomizedWorklist : public Worklist {
public:
  RandomizedWorklist() : Worklist() {
    generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
  }

  /// Pop a randomly chosen non-null operation from the worklist. A null entry
  /// that is picked is discarded, and another position is drawn. The caller
  /// must ensure that at least one non-null entry remains; otherwise the
  /// "cannot pop from empty worklist" assertion fires.
  Operation *pop() {
    Operation *op = nullptr;
    do {
      assert(!list.empty() && "cannot pop from empty worklist");
      int64_t pos = generator() % list.size();
      op = list[pos];
      list.erase(list.begin() + pos);
      // Every entry after `pos` moved down by one, so refresh its recorded
      // position. Null entries are never popped as results (see the loop
      // condition), so do not give them a position in `map`. Doing so would
      // insert a bogus `nullptr` key.
      for (int64_t i = pos, e = list.size(); i < e; ++i)
        if (Operation *shifted = list[i])
          map[shifted] = i;
      map.erase(op);
    } while (!op);
    return op;
  }

private:
  /// Pseudo-random generator used to pick the position to pop.
  std::minstd_rand0 generator;
};
#endif
class GreedyPatternRewriteDriver : public RewriterBase::Listener { … };
}
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
: … { … }
bool GreedyPatternRewriteDriver::processWorklist() { … }
void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { … }
void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) { … }
void GreedyPatternRewriteDriver::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) { … }
void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) { … }
void GreedyPatternRewriteDriver::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) { … }
void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) { … }
void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) { … }
void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) { … }
void GreedyPatternRewriteDriver::notifyOperationReplaced(
Operation *op, ValueRange replacement) { … }
void GreedyPatternRewriteDriver::notifyMatchFailure(
Location loc, function_ref<void(Diagnostic &)> reasonCallback) { … }
namespace {
class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver { … };
}
RegionPatternRewriteDriver::RegionPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config, Region ®ion)
: … { … }
namespace {
class GreedyPatternRewriteIteration
: public tracing::ActionImpl<GreedyPatternRewriteIteration> { … };
}
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { … }
LogicalResult
mlir::applyPatternsAndFoldGreedily(Region ®ion,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config, bool *changed) { … }
namespace {
class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { … };
}
MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
llvm::SmallDenseSet<Operation *, 4> *survivingOps)
: … { … }
LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
bool *changed) && { … }
static Region *findCommonAncestor(ArrayRef<Operation *> ops) { … }
LogicalResult mlir::applyOpPatternsAndFold(
ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config, bool *changed, bool *allErased) { … }