//===- ConstantArgumentGlobalisation.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace fir {
#define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
#define DEBUG_TYPE "flang-constant-argument-globalisation-opt"
namespace {
/// Counter used to build names for the generated global constants. It is
/// process-wide and only ever incremented.
/// NOTE(review): not synchronised; concurrent runs of this pass in different
/// threads would race on it.
unsigned uniqueLitId = 1;

/// Pattern that globalises constant by-value call arguments.
///
/// A by-value argument of a `fir.call` may be materialised as a temporary
/// `fir.alloca` carrying the adapt-to-by-ref attribute, filled by a
/// `fir.store`, and then passed by reference. When that store writes an
/// `arith.constant`, the argument is replaced by the address of a new
/// internal-linkage global constant holding the same value. Once the call is
/// replaced, the temporary and its store are erased if the store is the only
/// user the temporary has left.
class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
protected:
  /// Dominance information supplied by the owning pass.
  const mlir::DominanceInfo &di;

public:
  using OpRewritePattern::OpRewritePattern;

  CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
      : OpRewritePattern(ctx), di(_di) {}

  /// Replaces \p callOp with an equivalent call in which every globalisable
  /// argument is the address of a global constant.
  /// \returns success if the call was rewritten, failure if no argument
  /// qualified (which simply means "nothing to do").
  llvm::LogicalResult
  matchAndRewrite(fir::CallOp callOp,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
    auto module = callOp->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, module);
    mlir::Location loc = callOp.getLoc();
    bool needUpdate = false;
    llvm::SmallVector<mlir::Value> newOperands;
    // (alloca, store) pairs that may become dead once the call is replaced.
    llvm::SmallVector<std::pair<mlir::Operation *, mlir::Operation *>> allocas;

    for (mlir::Value a : callOp.getArgs()) {
      auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
      // Only allocas carrying the by-value-by-reference attribute are
      // candidates; every other operand is forwarded unchanged.
      if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
        newOperands.push_back(a);
        continue;
      }
      mlir::Type varTy = alloca.getInType();
      assert(!fir::hasDynamicSize(varTy) &&
             "only expect statically sized scalars to be by value");

      mlir::Operation *store = findUniqueDominatingStore(alloca, callOp);
      if (!store) {
        newOperands.push_back(a);
        continue;
      }
      LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");

      // The stored value may be a block argument, in which case there is no
      // defining op; isa_and_nonnull rejects that instead of asserting.
      mlir::Operation *definingOp = store->getOperand(0).getDefiningOp();
      if (!mlir::isa_and_nonnull<mlir::arith::ConstantOp>(definingOp)) {
        newOperands.push_back(a);
        continue;
      }
      LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp << "\n");

      std::string globalName = makeUniqueGlobalName(builder);
      if (llvm::none_of(allocas,
                        [alloca](auto x) { return x.first == alloca; }))
        allocas.push_back(std::make_pair(alloca, store));

      fir::GlobalOp global = builder.createGlobalConstant(
          loc, varTy, globalName,
          [&](fir::FirOpBuilder &globalBuilder) {
            // Re-materialise the constant inside the global's initializer.
            mlir::Operation *cln = definingOp->clone();
            globalBuilder.insert(cln);
            mlir::Value val =
                globalBuilder.createConvert(loc, varTy, cln->getResult(0));
            globalBuilder.create<fir::HasValueOp>(loc, val);
          },
          builder.createInternalLinkage());
      mlir::Value addr = builder.create<fir::AddrOfOp>(loc, global.resultType(),
                                                       global.getSymbol());
      newOperands.push_back(addr);
      needUpdate = true;
    }

    // Failure here just means "we couldn't do the conversion", which is
    // perfectly acceptable to the upper layers of this function.
    if (!needUpdate)
      return mlir::failure();

    llvm::SmallVector<mlir::Type> newResultTypes;
    newResultTypes.append(callOp.getResultTypes().begin(),
                          callOp.getResultTypes().end());
    fir::CallOp newOp = builder.create<fir::CallOp>(
        loc,
        callOp.getCallee().has_value() ? callOp.getCallee().value()
                                       : mlir::SymbolRefAttr{},
        newResultTypes, newOperands);
    // Copy all the attributes from the old to new op.
    newOp->setAttrs(callOp->getAttrs());
    // Print before replaceOp: it erases callOp, so callOp must not be
    // dereferenced afterwards.
    LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as "
                            << newOp << '\n');
    rewriter.replaceOp(callOp, newOp);

    for (auto [allocaOp, storeOp] : allocas) {
      // With the old call gone, a single remaining use is the store itself,
      // so neither the store nor the alloca is needed any more.
      if (allocaOp->hasOneUse()) {
        rewriter.eraseOp(storeOp);
        rewriter.eraseOp(allocaOp);
      }
    }
    return mlir::success();
  }

private:
  /// Returns the only `fir.store` into \p alloca if it dominates \p callOp,
  /// or nullptr otherwise.
  ///
  /// Every store is counted, not only dominating ones. A store that does not
  /// dominate the call can still reach it, for example along a loop back-edge
  /// or from one arm of a branch. So one dominating store alone does not prove
  /// which value the call observes.
  mlir::Operation *findUniqueDominatingStore(fir::AllocaOp alloca,
                                             fir::CallOp callOp) const {
    fir::StoreOp unique;
    for (mlir::Operation *user : alloca->getUsers()) {
      auto store = mlir::dyn_cast<fir::StoreOp>(user);
      if (!store)
        continue;
      // Give up if the alloca's address is itself stored somewhere (it
      // escapes), or if this is a second store into it.
      if (store.getMemref() != alloca.getResult() || unique)
        return nullptr;
      unique = store;
    }
    if (!unique || !di.dominates(unique.getOperation(), callOp.getOperation()))
      return nullptr;
    return unique.getOperation();
  }

  /// Produces a `_global_const_.N` name that no global in the module
  /// already uses. It skips names that already exist, for example in input
  /// produced by an earlier run, rather than relying on an assert that
  /// disappears in release builds.
  static std::string makeUniqueGlobalName(fir::FirOpBuilder &builder) {
    std::string name;
    do {
      name = "_global_const_." + std::to_string(uniqueLitId++);
    } while (builder.getNamedGlobal(name));
    return name;
  }
};
/// Module-level pass that turns immediate scalar literals passed to
/// `fir.call` operations into references to global constants. The goal is to
/// enable later transformations such as dead argument elimination.
class ConstantArgumentGlobalisationOpt
    : public fir::impl::ConstantArgumentGlobalisationOptBase<
          ConstantArgumentGlobalisationOpt> {
public:
  ConstantArgumentGlobalisationOpt() = default;

  /// Runs CallOpRewriter greedily over the module. Emits an error and marks
  /// the pass as failed if the greedy driver reports failure.
  void runOnOperation() override {
    mlir::ModuleOp moduleOp = getOperation();
    // The analysis is computed once here and handed to the pattern by
    // reference.
    mlir::DominanceInfo &domInfo = getAnalysis<mlir::DominanceInfo>();
    mlir::MLIRContext *ctx = &getContext();

    mlir::GreedyRewriteConfig rewriteConfig;
    // Restrict the driver to the ops that exist when it starts, and leave
    // region structure untouched.
    rewriteConfig.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
    rewriteConfig.enableRegionSimplification =
        mlir::GreedySimplifyRegionLevel::Disabled;

    mlir::RewritePatternSet patternSet(ctx);
    patternSet.insert<CallOpRewriter>(ctx, domInfo);

    llvm::LogicalResult rewriteResult = mlir::applyPatternsAndFoldGreedily(
        moduleOp, std::move(patternSet), rewriteConfig);
    if (mlir::succeeded(rewriteResult))
      return;

    mlir::emitError(moduleOp.getLoc(),
                    "error in constant globalisation optimization\n");
    signalPassFailure();
  }
};
} // namespace