//===- OptimizedBufferization.cpp - special cases for bufferization -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// In some special cases we can bufferize hlfir expressions in a more optimal
// way so as to avoid creating temporaries. This pass handles these. It should
// be run before the catch-all bufferization pass.
//
// This requires common subexpression elimination (CSE) to have already been
// run.
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Analysis/AliasAnalysis.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "flang/Optimizer/Transforms/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <iterator>
#include <memory>
#include <mlir/Analysis/AliasAnalysis.h>
#include <optional>
namespace hlfir {
#define GEN_PASS_DEF_OPTIMIZEDBUFFERIZATION
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir
#define DEBUG_TYPE "opt-bufferization"
namespace {
/// Rewrite pattern that evaluates an hlfir.elemental directly into the array
/// it is assigned to, instead of materialising the elemental result in a
/// temporary first.
///
/// The first shape of code it is meant to match is an in-place update:
///   %array = some.operation // array has shape %shape
///   %expr = hlfir.elemental %shape : [...] {
///   bb0(%arg0: index)
///     %0 = hlfir.designate %array(%arg0)
///     [...] // no other reads or writes to %array
///     hlfir.yield_element %element
///   }
///   hlfir.assign %expr to %array
///   hlfir.destroy %expr
///
/// The second is an assignment to a different array of the same shape:
///   %read_array = some.operation // shape %shape
///   %expr = hlfir.elemental %shape : [...] {
///   bb0(%arg0: index)
///     %0 = hlfir.designate %read_array(%arg0)
///     [...]
///     hlfir.yield_element %element
///   }
///   %write_array = some.operation // with shape %shape
///   [...] // operations which don't affect write_array
///   hlfir.assign %expr to %write_array
///   hlfir.destroy %expr
///
/// In both cases the elemental can be turned into a do loop that updates the
/// elements of the assigned array in place. To keep this safe there must be
/// no reads from the array at indices that could conflict with the
/// assignment or with any writes. For now the rule is strict: every read must
/// be at the elemental index (reading from higher indices is probably safe
/// when lowering to an ordered loop, but that is not attempted).
class ElementalAssignBufferization
    : public mlir::OpRewritePattern<hlfir::ElementalOp> {
public:
  using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::ElementalOp elemental,
                  mlir::PatternRewriter &rewriter) const override;

private:
  /// Everything a successful match needs to hand over to the rewrite.
  struct MatchInfo {
    // The variable the elemental result is assigned to.
    mlir::Value array;
    // The hlfir.assign consuming the elemental result.
    hlfir::AssignOp assign;
    // The hlfir.destroy releasing the elemental result.
    hlfir::DestroyOp destroy;
  };

  /// Decides whether the transformation may be applied to \p elemental;
  /// returns std::nullopt when it may not.
  static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental);
};
/// Gathers (recursively) the memory effects of every operation in the
/// half-open range [start, end). \p start must properly dominate \p end and
/// both must live in the same block; an empty list is returned when they are
/// the same operation. Yields std::nullopt as soon as an operation with
/// unknown effects is encountered.
static std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>>
getEffectsBetween(mlir::Operation *start, mlir::Operation *end) {
  using EffectList = mlir::SmallVector<mlir::MemoryEffects::EffectInstance>;

  EffectList collected;
  if (start != end) {
    assert(start->getBlock() && end->getBlock() && "TODO: block arguments");
    assert(start->getBlock() == end->getBlock());
    assert(mlir::DominanceInfo{}.properlyDominates(start, end));

    for (mlir::Operation *op = start; op && op != end;
         op = op->getNextNode()) {
      std::optional<EffectList> opEffects = mlir::getEffectsRecursively(op);
      // One opaque operation makes the whole range unanalysable.
      if (!opEffects)
        return std::nullopt;
      collected.append(*opEffects);
    }
  }
  return collected;
}
/// Classifies how \p effect relates to \p val. Effects that are neither reads
/// nor writes, and effects on the debugging resource, never alias. A read or
/// write without an associated value is reported as MayAlias. Otherwise the
/// answer comes from value identity and fir::AliasAnalysis, both on the
/// accessed value itself and, if it is produced by an hlfir.designate, on the
/// designated base.
static mlir::AliasResult
containsReadOrWriteEffectOn(const mlir::MemoryEffects::EffectInstance &effect,
                            mlir::Value val) {
  const bool isReadOrWrite =
      mlir::isa<mlir::MemoryEffects::Read, mlir::MemoryEffects::Write>(
          effect.getEffect());
  if (!isReadOrWrite)
    return mlir::AliasResult::NoAlias;

  if (mlir::isa<fir::DebuggingResource>(effect.getResource()))
    return mlir::AliasResult::NoAlias;

  mlir::Value target = effect.getValue();
  // Nothing is known about what is being accessed: be conservative.
  if (!target)
    return mlir::AliasResult::MayAlias;
  if (target == val)
    return mlir::AliasResult::MustAlias;

  // Does the accessed value itself possibly alias val?
  fir::AliasAnalysis aliasAnalysis;
  mlir::AliasResult direct = aliasAnalysis.alias(val, target);
  if (!direct.isNo())
    return direct;

  // FIXME: alias analysis of fir.load
  // Look through this common pattern:
  //   %ref = hlfir.designate %array(%index)
  //   %val = fir.load %ref
  auto designate = target.getDefiningOp<hlfir::DesignateOp>();
  if (!designate)
    return mlir::AliasResult::NoAlias;

  mlir::Value base = designate.getMemref();
  if (base == val)
    return mlir::AliasResult::MustAlias;
  // The designate may be into an array that might alias val.
  return aliasAnalysis.alias(val, base);
}
// Returns true if the given array references represent identical
// or completely disjoint array slices. The callers may use this
// method when the alias analysis reports an alias of some kind,
// so that we can run Fortran specific analysis on the array slices
// to see if they are identical or disjoint. Note that the alias
// analysis is not able to give such an answer about the references.
//
// Only pairs of hlfir.designate operations over the same base are analyzed;
// anything else (other than ref1 == ref2) conservatively yields false.
static bool areIdenticalOrDisjointSlices(mlir::Value ref1, mlir::Value ref2) {
  if (ref1 == ref2)
    return true;

  auto des1 = ref1.getDefiningOp<hlfir::DesignateOp>();
  auto des2 = ref2.getDefiningOp<hlfir::DesignateOp>();
  // We only support a pair of designators right now.
  if (!des1 || !des2)
    return false;

  if (des1.getMemref() != des2.getMemref()) {
    // If the bases are different, then there is unknown overlap.
    LLVM_DEBUG(llvm::dbgs() << "No identical base for:\n"
                            << des1 << "and:\n"
                            << des2 << "\n");
    return false;
  }

  // Require all components of the designators to be the same.
  // It might be too strict, e.g. we may probably allow for
  // different type parameters.
  if (des1.getComponent() != des2.getComponent() ||
      des1.getComponentShape() != des2.getComponentShape() ||
      des1.getSubstring() != des2.getSubstring() ||
      des1.getComplexPart() != des2.getComplexPart() ||
      des1.getTypeparams() != des2.getTypeparams()) {
    LLVM_DEBUG(llvm::dbgs() << "Different designator specs for:\n"
                            << des1 << "and:\n"
                            << des2 << "\n");
    return false;
  }

  if (des1.getIsTriplet() != des2.getIsTriplet()) {
    LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n"
                            << des1 << "and:\n"
                            << des2 << "\n");
    return false;
  }

  // Analyze the subscripts.
  // For example:
  //   hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0)  shape %9
  //   hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1)  shape %9
  //
  // If all the triplets (section specifiers) are the same, then
  // we do not care if %0 is equal to %1 - the slices are either
  // identical or completely disjoint.
  auto des1It = des1.getIndices().begin();
  auto des2It = des2.getIndices().begin();
  bool identicalTriplets = true;
  for (bool isTriplet : des1.getIsTriplet()) {
    if (!isTriplet) {
      ++des1It;
      ++des2It;
      continue;
    }
    for (int i = 0; i < 3; ++i)
      if (*des1It++ != *des2It++) {
        identicalTriplets = false;
        break;
      }
    if (!identicalTriplets) {
      LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n"
                              << des1 << "and:\n"
                              << des2 << "\n");
      // The answer is already known; stop here so that the two iterators
      // are never advanced out of step with the triplet layout.
      break;
    }
  }
  if (identicalTriplets)
    return true;

  // See if we can prove that any of the triplets do not overlap.
  // This is mostly a Polyhedron/nf performance hack that looks for
  // particular relations between the lower and upper bounds
  // of the array sections, e.g. for any positive constant C:
  //   X:Y does not overlap with (Y+C):Z
  //   X:Y does not overlap with Z:(X-C)

  // Returns the operation defining v after skipping any chain of
  // fir.convert, or null if the chain ends at a block argument.
  auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
    auto *op = v.getDefiningOp();
    while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
      op = conv.getValue().getDefiningOp();
    return op;
  };
  // True if v is directly defined by a positive integer arith.constant.
  // Value::getDefiningOp<OpTy>() is null-safe, so block arguments (which
  // have no defining operation) are simply reported as "not a constant".
  auto isPositiveConstant = [](mlir::Value v) -> bool {
    if (auto conOp = v.getDefiningOp<mlir::arith::ConstantOp>())
      if (auto iattr = mlir::dyn_cast<mlir::IntegerAttr>(conOp.getValue()))
        return iattr.getInt() > 0;
    return false;
  };
  // Returns +1/-1 if v is (a conversion of) a positive/negative integer
  // constant, and 0 if the sign cannot be determined (or v is zero).
  auto constantSign = [&removeConvert](mlir::Value v) -> int {
    auto conOp =
        mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(removeConvert(v));
    if (!conOp)
      return 0;
    auto iattr = mlir::dyn_cast<mlir::IntegerAttr>(conOp.getValue());
    if (!iattr)
      return 0;
    auto cst = iattr.getInt();
    if (cst > 0)
      return 1;
    return cst < 0 ? -1 : 0;
  };
  // True if v2 is provably v1 plus some positive constant.
  auto displacedByConstant = [&](mlir::Value v1, mlir::Value v2) {
    auto *op1 = removeConvert(v1);
    auto *op2 = removeConvert(v2);
    if (!op1 || !op2)
      return false;
    if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
      if ((addi.getLhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getRhs())) ||
          (addi.getRhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getLhs())))
        return true;
    if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
      if (subi.getLhs().getDefiningOp() == op2 &&
          isPositiveConstant(subi.getRhs()))
        return true;
    return false;
  };

  des1It = des1.getIndices().begin();
  des2It = des2.getIndices().begin();
  for (bool isTriplet : des1.getIsTriplet()) {
    if (!isTriplet) {
      ++des1It;
      ++des2It;
      continue;
    }
    mlir::Value des1Lb = *des1It++;
    mlir::Value des1Ub = *des1It++;
    mlir::Value des1Stride = *des1It++;
    mlir::Value des2Lb = *des2It++;
    mlir::Value des2Ub = *des2It++;
    mlir::Value des2Stride = *des2It++;

    // The bound relations above only describe the covered index range when
    // the direction of each section is known: with a negative stride the
    // section runs from "lb" down to "ub" (e.g. 10:1:-1 covers 1..10), so
    // the roles of the two bounds are swapped. If the sign of a stride
    // cannot be determined, nothing can be proven for this dimension.
    int sign1 = constantSign(des1Stride);
    int sign2 = constantSign(des2Stride);
    if (sign1 == 0 || sign2 == 0)
      continue;
    mlir::Value low1 = sign1 > 0 ? des1Lb : des1Ub;
    mlir::Value high1 = sign1 > 0 ? des1Ub : des1Lb;
    mlir::Value low2 = sign2 > 0 ? des2Lb : des2Ub;
    mlir::Value high2 = sign2 > 0 ? des2Ub : des2Lb;

    if (displacedByConstant(high1, low2) || displacedByConstant(high2, low1))
      return true;
  }

  return false;
}
/// Determines whether \p elemental can be evaluated in place at its
/// hlfir.assign. On success the returned MatchInfo holds the assigned array
/// and the assign/destroy operations using the elemental; std::nullopt means
/// safety of the transformation could not be proven.
///
/// The checks are, in order:
///  - the elemental has exactly two users: one hlfir.assign and one
///    hlfir.destroy, and it is not required to produce a temporary;
///  - the assigned variable is an array of trivial elements whose shape can
///    be found, and a shape mismatch is only tolerated for non-realloc
///    assignments;
///  - inside the elemental body, nothing writes to memory that may alias the
///    array, and aliasing reads are only at the elemental indices;
///  - between the elemental and the assignment, nothing touches the array or
///    anything the elemental body reads or writes.
std::optional<ElementalAssignBufferization::MatchInfo>
ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
  mlir::Operation::user_range users = elemental->getUsers();
  // the only uses of the elemental should be the assignment and the destroy
  if (std::distance(users.begin(), users.end()) != 2) {
    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the elemental\n");
    return std::nullopt;
  }

  // If the ElementalOp must produce a temporary (e.g. for
  // finalization purposes), then we cannot inline it.
  if (hlfir::elementalOpMustProduceTemp(elemental)) {
    LLVM_DEBUG(llvm::dbgs() << "ElementalOp must produce a temp\n");
    return std::nullopt;
  }

  MatchInfo match;
  for (mlir::Operation *user : users)
    mlir::TypeSwitch<mlir::Operation *, void>(user)
        .Case([&](hlfir::AssignOp op) { match.assign = op; })
        .Case([&](hlfir::DestroyOp op) { match.destroy = op; });

  if (!match.assign || !match.destroy) {
    LLVM_DEBUG(llvm::dbgs() << "Couldn't find assign or destroy\n");
    return std::nullopt;
  }

  // the array is what the elemental is assigned into
  // TODO: this could be extended to also allow hlfir.expr by first bufferizing
  // the incoming expression
  match.array = match.assign.getLhs();
  mlir::Type arrayType = mlir::dyn_cast<fir::SequenceType>(
      fir::unwrapPassByRefType(match.array.getType()));
  if (!arrayType)
    return std::nullopt;

  // require that the array elements are trivial
  // TODO: this is just to make the pass easier to think about. Not an inherent
  // limitation
  mlir::Type eleTy = hlfir::getFortranElementType(arrayType);
  if (!fir::isa_trivial(eleTy))
    return std::nullopt;

  // the array must have the same shape as the elemental. CSE should have
  // deduplicated the fir.shape operations where they are provably the same
  // so we just have to check for the same ssa value
  // TODO: add more ways of getting the shape of the array
  mlir::Value arrayShape;
  if (match.array.getDefiningOp())
    arrayShape =
        mlir::TypeSwitch<mlir::Operation *, mlir::Value>(
            match.array.getDefiningOp())
            .Case([](hlfir::DesignateOp designate) {
              return designate.getShape();
            })
            .Case([](hlfir::DeclareOp declare) { return declare.getShape(); })
            .Default([](mlir::Operation *) { return mlir::Value{}; });
  if (!arrayShape) {
    LLVM_DEBUG(llvm::dbgs() << "Can't get shape of " << match.array << " at "
                            << elemental->getLoc() << "\n");
    return std::nullopt;
  }
  if (arrayShape != elemental.getShape()) {
    // f2018 10.2.1.2 (3) requires the lhs and rhs of an assignment to be
    // conformable unless the lhs is an allocatable array. In HLFIR we can
    // see this from the presence or absence of the realloc attribute on
    // hlfir.assign. If it is not a realloc assignment, we can trust that
    // the shapes do conform
    if (match.assign.getRealloc())
      return std::nullopt;
  }

  // the transformation wants to apply the elemental in a do-loop at the
  // hlfir.assign, check there are no effects which make this unsafe

  // keep track of any values written to in the elemental, as these can't be
  // read from between the elemental and the assignment
  // likewise, values read in the elemental cannot be written to between the
  // elemental and the assign
  mlir::SmallVector<mlir::Value, 1> notToBeAccessedBeforeAssign;
  // any accesses to the array between the array and the assignment means it
  // would be unsafe to move the elemental to the assignment
  notToBeAccessedBeforeAssign.push_back(match.array);

  // 1) side effects in the elemental body - it isn't sufficient to just look
  // for ordered elementals because we also cannot support out of order reads
  std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>>
      effects = getEffectsBetween(&elemental.getBody()->front(),
                                  elemental.getBody()->getTerminator());
  if (!effects) {
    LLVM_DEBUG(llvm::dbgs()
               << "operation with unknown effects inside elemental\n");
    return std::nullopt;
  }
  for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
    mlir::AliasResult res = containsReadOrWriteEffectOn(effect, match.array);
    if (res.isNo()) {
      if (mlir::isa<mlir::MemoryEffects::Write, mlir::MemoryEffects::Read>(
              effect.getEffect()))
        if (effect.getValue())
          notToBeAccessedBeforeAssign.push_back(effect.getValue());

      // this is safe in the elemental
      continue;
    }

    // don't allow any aliasing writes in the elemental
    if (mlir::isa<mlir::MemoryEffects::Write>(effect.getEffect())) {
      LLVM_DEBUG(llvm::dbgs() << "write inside the elemental body\n");
      return std::nullopt;
    }

    // allow if and only if the reads are from the elemental indices, in order
    // => each iteration doesn't read values written by other iterations
    // don't allow reads from a different value which may alias: fir alias
    // analysis isn't precise enough to tell us if two aliasing arrays overlap
    // exactly or only partially. If they overlap partially, a designate at the
    // elemental indices could be accessing different elements: e.g. we could
    // designate two slices of the same array at different start indexes. These
    // two MustAlias but index 1 of one array isn't the same element as index 1
    // of the other array.
    //
    // A read effect that is not attached to any value is reported as
    // MayAlias by containsReadOrWriteEffectOn(); such an effect cannot be
    // traced back to a designate, so it must be rejected below rather than
    // dereferenced here.
    mlir::Value accessedVal = effect.getValue();
    if (accessedVal && !res.isPartial()) {
      if (auto designate = accessedVal.getDefiningOp<hlfir::DesignateOp>()) {
        if (!areIdenticalOrDisjointSlices(match.array, designate.getMemref())) {
          LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
                                  << " at " << elemental.getLoc() << "\n");
          return std::nullopt;
        }
        auto indices = designate.getIndices();
        auto elementalIndices = elemental.getIndices();
        if (indices.size() != elementalIndices.size()) {
          LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
                                  << " at " << elemental.getLoc() << "\n");
          return std::nullopt;
        }
        if (std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
                       elementalIndices.end()))
          continue;
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "disallowed side-effect: " << effect.getValue()
                            << " for " << elemental.getLoc() << "\n");
    return std::nullopt;
  }

  // 2) look for conflicting effects between the elemental and the assignment
  effects = getEffectsBetween(elemental->getNextNode(), match.assign);
  if (!effects) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "operation with unknown effects between elemental and assign\n");
    return std::nullopt;
  }
  for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
    // not safe to access anything written in the elemental as this write
    // will be moved to the assignment
    for (mlir::Value val : notToBeAccessedBeforeAssign) {
      mlir::AliasResult res = containsReadOrWriteEffectOn(effect, val);
      if (!res.isNo()) {
        LLVM_DEBUG(llvm::dbgs()
                   << "diasllowed side-effect: " << effect.getValue() << " for "
                   << elemental.getLoc() << "\n");
        return std::nullopt;
      }
    }
  }

  return match;
}
/// Applies the in-place rewrite when findMatch() succeeds: a loop nest over
/// the elemental's shape is emitted at the position of the hlfir.assign, the
/// elemental body is inlined into the innermost loop, and the yielded element
/// is assigned straight to the corresponding element of the target array.
/// The original assign, destroy and elemental are then erased.
llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
    hlfir::ElementalOp elemental, mlir::PatternRewriter &rewriter) const {
  std::optional<MatchInfo> matched = findMatch(elemental);
  if (!matched.has_value())
    return rewriter.notifyMatchFailure(
        elemental, "cannot prove safety of ElementalAssignBufferization");
  MatchInfo &info = *matched;

  mlir::Location loc = elemental->getLoc();
  fir::FirOpBuilder builder(rewriter, elemental.getOperation());
  auto extents = hlfir::getIndexExtents(loc, builder, elemental.getShape());

  // The loop nest replaces the assignment, so it is built where the
  // assignment is.
  builder.setInsertionPoint(info.assign);
  const bool isUnordered = !elemental.isOrdered();
  hlfir::LoopNest nest = hlfir::genLoopNest(loc, builder, extents, isUnordered);

  // Clone the elemental region into the innermost loop; its terminator only
  // serves to tell us which value is the element.
  builder.setInsertionPointToStart(nest.innerLoop.getBody());
  auto yieldOp =
      hlfir::inlineElementalOp(loc, builder, elemental, nest.oneBasedIndices);
  hlfir::Entity element{yieldOp.getElementValue()};
  rewriter.eraseOp(yieldOp);

  // Store the element into the array element of the current iteration.
  auto destination = hlfir::getElementAt(loc, builder, hlfir::Entity{info.array},
                                         nest.oneBasedIndices);
  builder.create<hlfir::AssignOp>(
      loc, element, destination, /*realloc=*/false,
      /*keep_lhs_length_if_realloc=*/false, info.assign.getTemporaryLhs());

  rewriter.eraseOp(info.assign);
  rewriter.eraseOp(info.destroy);
  rewriter.eraseOp(elemental);
  return mlir::success();
}
/// Rewrite pattern turning an hlfir.assign with a scalar right-hand side and
/// an array left-hand side into a loop nest of per-element assignments.
/// For example
///   hlfir.assign %cst to %0 : f32, !fir.ref<!fir.array<6x6xf32>>
/// becomes
///   fir.do_loop %arg0 = %c1 to %c6 step %c1 unordered {
///     fir.do_loop %arg1 = %c1 to %c6 step %c1 unordered {
///       %1 = hlfir.designate %0 (%arg1, %arg0)  :
///       (!fir.ref<!fir.array<6x6xf32>>, index, index) -> !fir.ref<f32>
///       hlfir.assign %cst to %1 : f32, !fir.ref<f32>
///     }
///   }
class BroadcastAssignBufferization
    : public mlir::OpRewritePattern<hlfir::AssignOp> {
public:
  using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::AssignOp assign,
                  mlir::PatternRewriter &rewriter) const override;
};
/// Matches a trivial scalar assigned to an array of trivial elements and
/// replaces the assignment by an unordered loop nest that assigns the scalar
/// to every element.
llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
    hlfir::AssignOp assign, mlir::PatternRewriter &rewriter) const {
  // With a scalar RHS and an array LHS, a conforming Fortran program must
  // already have the LHS allocated, and the assignment cannot reallocate it.
  // isAllocatableAssignment is therefore irrelevant here and the
  // transformation is always applicable once the type checks pass.
  auto reject = [&](const char *reason) -> llvm::LogicalResult {
    return rewriter.notifyMatchFailure(assign, reason);
  };

  mlir::Value scalar = assign.getRhs();
  if (!fir::isa_trivial(scalar.getType()))
    return reject("AssignOp's RHS is not a trivial scalar");

  hlfir::Entity target{assign.getLhs()};
  if (!target.isArray())
    return reject("AssignOp's LHS is not an array");

  if (!fir::isa_trivial(target.getFortranElementType()))
    return reject("AssignOp's LHS data type is not trivial");

  mlir::Location loc = assign->getLoc();
  fir::FirOpBuilder builder(rewriter, assign.getOperation());
  builder.setInsertionPoint(assign);

  // Loop bounds come from the shape of the (dereferenced) left-hand side.
  target = hlfir::derefPointersAndAllocatables(loc, builder, target);
  mlir::Value targetShape = hlfir::genShape(loc, builder, target);
  llvm::SmallVector<mlir::Value> extents =
      hlfir::getIndexExtents(loc, builder, targetShape);
  hlfir::LoopNest nest =
      hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);

  builder.setInsertionPointToStart(nest.innerLoop.getBody());
  auto element =
      hlfir::getElementAt(loc, builder, target, nest.oneBasedIndices);
  builder.create<hlfir::AssignOp>(loc, scalar, element);

  rewriter.eraseOp(assign);
  return mlir::success();
}
/// Rewrite pattern turning an array-to-array hlfir.assign into a loop nest of
/// per-element assignments. For example
///   hlfir.assign %4 to %5 : !fir.ref<!fir.array<3x3xf32>>,
///                           !fir.ref<!fir.array<3x3xf32>>
/// becomes
///   fir.do_loop %arg1 = %c1 to %c3 step %c1 unordered {
///     fir.do_loop %arg2 = %c1 to %c3 step %c1 unordered {
///       %6 = hlfir.designate %4 (%arg2, %arg1)  :
///           (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32>
///       %7 = fir.load %6 : !fir.ref<f32>
///       %8 = hlfir.designate %5 (%arg2, %arg1)  :
///           (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32>
///       hlfir.assign %7 to %8 : f32, !fir.ref<f32>
///     }
///   }
///
/// This is only correct when the two sides do not alias. Runtime checking
/// that the LHS and RHS shapes conform is currently not supported by this
/// transformation.
class VariableAssignBufferization
    : public mlir::OpRewritePattern<hlfir::AssignOp> {
public:
  using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::AssignOp assign,
                  mlir::PatternRewriter &rewriter) const override;
};
/// Matches an assignment between two in-memory arrays of the same trivial
/// element type that provably do not alias, and replaces it by an unordered
/// loop nest that loads each RHS element and assigns it to the matching LHS
/// element. Assignments that may (re)allocate the LHS are left alone.
llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
    hlfir::AssignOp assign, mlir::PatternRewriter &rewriter) const {
  auto reject = [&](const char *reason) -> llvm::LogicalResult {
    return rewriter.notifyMatchFailure(assign, reason);
  };

  if (assign.isAllocatableAssignment())
    return reject("AssignOp may imply allocation");

  hlfir::Entity source{assign.getRhs()};
  // TODO: ExprType check is here to avoid conflicts with
  // ElementalAssignBufferization pattern. We need to combine
  // these matchers into a single one that applies to AssignOp.
  if (mlir::isa<hlfir::ExprType>(source.getType()))
    return reject("RHS is not in memory");
  if (!source.isArray())
    return reject("AssignOp's RHS is not an array");
  mlir::Type sourceEleTy = source.getFortranElementType();
  if (!fir::isa_trivial(sourceEleTy))
    return reject("AssignOp's RHS data type is not trivial");

  hlfir::Entity target{assign.getLhs()};
  if (!target.isArray())
    return reject("AssignOp's LHS is not an array");
  mlir::Type targetEleTy = target.getFortranElementType();
  if (!fir::isa_trivial(targetEleTy))
    return reject("AssignOp's LHS data type is not trivial");

  if (targetEleTy != sourceEleTy)
    return reject("RHS/LHS element types mismatch");

  // Element-wise copying is only valid when the two arrays cannot overlap.
  // TODO: use areIdenticalOrDisjointSlices() to check if
  // we can still do the expansion.
  fir::AliasAnalysis aliasAnalysis;
  mlir::AliasResult aliasRes = aliasAnalysis.alias(target, source);
  if (!aliasRes.isNo()) {
    LLVM_DEBUG(llvm::dbgs() << "VariableAssignBufferization:\n"
                            << "\tLHS: " << target << "\n"
                            << "\tRHS: " << source << "\n"
                            << "\tALIAS: " << aliasRes << "\n");
    return reject("RHS/LHS may alias");
  }

  mlir::Location loc = assign->getLoc();
  fir::FirOpBuilder builder(rewriter, assign.getOperation());
  builder.setInsertionPoint(assign);

  source = hlfir::derefPointersAndAllocatables(loc, builder, source);
  target = hlfir::derefPointersAndAllocatables(loc, builder, target);

  // The loop nest iterates over the shape of the left-hand side.
  mlir::Value targetShape = hlfir::genShape(loc, builder, target);
  llvm::SmallVector<mlir::Value> extents =
      hlfir::getIndexExtents(loc, builder, targetShape);
  hlfir::LoopNest nest =
      hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);

  builder.setInsertionPointToStart(nest.innerLoop.getBody());
  auto sourceElement =
      hlfir::getElementAt(loc, builder, source, nest.oneBasedIndices);
  sourceElement = hlfir::loadTrivialScalar(loc, builder, sourceElement);
  auto targetElement =
      hlfir::getElementAt(loc, builder, target, nest.oneBasedIndices);
  builder.create<hlfir::AssignOp>(loc, sourceElement, targetElement);

  rewriter.eraseOp(assign);
  return mlir::success();
}
/// Callback producing the body of a reduction loop nest. It receives the
/// builder, a location, the current reduction value and the loop indices, and
/// returns the updated reduction value.
using GenBodyFn =
    std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
                              const llvm::SmallVectorImpl<mlir::Value> &)>;

/// Emits a nest of fir.do_loop operations over \p shape that threads a single
/// reduction value, starting from \p init, through every iteration.
/// \p genBody is invoked once, inside the innermost loop, to generate the
/// per-element update. Returns the reduction value produced by the outermost
/// loop (or by \p genBody directly when the shape has no extents), and leaves
/// the builder positioned after the loop nest.
static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
                                         mlir::Location loc, mlir::Value init,
                                         mlir::Value shape, GenBodyFn genBody) {
  auto extents = hlfir::getIndexExtents(loc, builder, shape);
  const unsigned rank = extents.size();
  mlir::Value one =
      builder.createIntegerConstant(loc, builder.getIndexType(), 1);

  // Loops run over one-based indices so these can be handed to the elemental
  // as-is. The last dimension becomes the outermost loop, which makes the
  // innermost loop walk the first dimension (column-major order) for better
  // performance.
  llvm::SmallVector<mlir::Value> indices(rank, mlir::Value{});
  mlir::Value acc = init;
  for (unsigned dim = rank; dim-- > 0;) {
    auto loop = builder.create<fir::DoLoopOp>(loc, one, extents[dim], one,
                                              false,
                                              /*finalCountValue=*/false, acc);
    acc = loop.getRegionIterArgs()[0];
    indices[dim] = loop.getInductionVar();
    // Descend: the next loop (or the body) goes inside this one.
    builder.setInsertionPointToStart(loop.getBody());
  }

  // Innermost point of the nest: emit the element update.
  acc = genBody(builder, loc, acc, indices);

  // Walk back out, terminating each loop with the value computed inside it
  // and continuing after the loop that was just closed.
  for (unsigned level = 0; level != rank; ++level) {
    auto result = builder.create<fir::ResultOp>(loc, acc);
    auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
    acc = loop.getResult(0);
    builder.setInsertionPointAfter(loop.getOperation());
  }
  return acc;
}
/// Returns a callable that creates the initial value of a min/max reduction
/// over elements of the given type: the identity of the reduction, i.e. the
/// smallest representable value when \p isMax is true and the largest one
/// otherwise.
///  - Floating point types start from -infinity (max) or +infinity (min).
///  - Any other type is treated as a signed integer of its bit width and
///    starts from the signed minimum (max) or signed maximum (min).
auto makeMinMaxInitValGenerator(bool isMax) {
  return [isMax](fir::FirOpBuilder builder, mlir::Location loc,
                 mlir::Type elementType) -> mlir::Value {
    if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
      const llvm::fltSemantics &sem = ty.getFloatSemantics();
      llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
      return builder.createRealConstant(loc, elementType, limit);
    }
    unsigned bits = elementType.getIntOrFloatBitWidth();
    llvm::APInt limit = isMax ? llvm::APInt::getSignedMinValue(bits)
                              : llvm::APInt::getSignedMaxValue(bits);
    if (bits <= 64)
      return builder.createIntegerConstant(loc, elementType,
                                           limit.getSExtValue());
    // The limits of integers wider than 64 bits (e.g. i128) do not fit in an
    // int64_t: APInt::getSExtValue() would assert, or silently drop the high
    // bits when assertions are disabled. Build the constant from the
    // full-width APInt instead.
    return builder.create<mlir::arith::ConstantOp>(
        loc, elementType, builder.getIntegerAttr(elementType, limit));
  };
}
/// Emits the i1 condition telling whether \p elem should replace the current
/// \p reduction value of a min (isMax == false) or max (isMax == true)
/// reduction. Floating point and integer reductions are supported; any other
/// reduction type is unreachable.
mlir::Value generateMinMaxComparison(fir::FirOpBuilder builder,
                                     mlir::Location loc, mlir::Value elem,
                                     mlir::Value reduction, bool isMax) {
  mlir::Type reductionType = reduction.getType();

  if (mlir::isa<mlir::FloatType>(reductionType)) {
    // For FP reductions the first smallest (or largest) value that is not a
    // NaN must win. An ordered OGT/OLT comparison usually achieves that,
    // except when the values are all NaN or Inf. This mirrors the logic of
    // NumericCompare for Minloc/Maxloc in extrema.cpp.
    mlir::arith::CmpFPredicate ordering = mlir::arith::CmpFPredicate::OLT;
    if (isMax)
      ordering = mlir::arith::CmpFPredicate::OGT;
    mlir::Value isBetter =
        builder.create<mlir::arith::CmpFOp>(loc, ordering, elem, reduction);
    // Additionally take elem when the reduction is NaN and elem is not.
    mlir::Value reductionIsNan = builder.create<mlir::arith::CmpFOp>(
        loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
    mlir::Value elemIsNotNan = builder.create<mlir::arith::CmpFOp>(
        loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
    mlir::Value replacesNan = builder.create<mlir::arith::AndIOp>(
        loc, reductionIsNan, elemIsNotNan);
    return builder.create<mlir::arith::OrIOp>(loc, isBetter, replacesNan);
  }

  if (mlir::isa<mlir::IntegerType>(reductionType)) {
    mlir::arith::CmpIPredicate ordering = mlir::arith::CmpIPredicate::slt;
    if (isMax)
      ordering = mlir::arith::CmpIPredicate::sgt;
    return builder.create<mlir::arith::CmpIOp>(loc, ordering, elem, reduction);
  }

  llvm_unreachable("unsupported type");
}
/// Given a reduction operation with an elemental/designate source, attempt to
/// generate a do-loop to perform the operation inline.
/// %e = hlfir.elemental %shape unordered
/// %r = hlfir.count %e
/// =>
/// %r = fir.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init)
/// %i = <inline elemental>
/// %c = <reduce count> %i
/// fir.result %c
///
/// The pattern is instantiated per reduction operation type \p Op. Any, all
/// and count reduce their mask operand; maxval and minval reduce their array
/// operand. Minloc and maxloc are recognized but always rejected here (see
/// the TODO below). Reductions with a dim operand are never handled, and
/// maxval/minval with a mask are not handled either.
template <typename Op>
class ReductionConversion : public mlir::OpRewritePattern<Op> {
public:
using mlir::OpRewritePattern<Op>::OpRewritePattern;
/// Replaces \p op with the result of an inline reduction loop nest, or
/// reports a match failure if the operation uses unsupported optional
/// arguments or its source is neither an hlfir.elemental nor an
/// hlfir.designate.
llvm::LogicalResult
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = op.getLoc();
// Select source and validate its arguments.
mlir::Value source;
bool valid = false;
if constexpr (std::is_same_v<Op, hlfir::AnyOp> ||
std::is_same_v<Op, hlfir::AllOp> ||
std::is_same_v<Op, hlfir::CountOp>) {
source = op.getMask();
valid = !op.getDim();
} else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
std::is_same_v<Op, hlfir::MinvalOp>) {
source = op.getArray();
valid = !op.getDim() && !op.getMask();
} else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
std::is_same_v<Op, hlfir::MinlocOp>) {
source = op.getArray();
valid = !op.getDim() && !op.getMask() && !op.getBack();
}
if (!valid)
return rewriter.notifyMatchFailure(
op, "Currently does not accept optional arguments");
// Exactly one of `elemental` / `designate` is set below, depending on which
// kind of operation defines the source; the loop bounds come from its shape.
hlfir::ElementalOp elemental;
hlfir::DesignateOp designate;
mlir::Value shape;
if ((elemental = source.template getDefiningOp<hlfir::ElementalOp>())) {
shape = elemental.getOperand(0);
} else if ((designate =
source.template getDefiningOp<hlfir::DesignateOp>())) {
shape = designate.getShape();
} else {
return rewriter.notifyMatchFailure(op, "Did not find valid argument");
}
// Produces the source element at the given indices: either by inlining the
// elemental body, or by designating and loading the element of the
// designated variable. Note that `designate` is captured by reference; the
// lambda (and its copies held by genBodyFn) is only invoked from within this
// function, through generateReductionLoop() below.
auto inlineSource =
[elemental, &designate](
fir::FirOpBuilder builder, mlir::Location loc,
const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
if (elemental) {
// Inline the elemental and get the value from it.
auto yield = inlineElementalOp(loc, builder, elemental, indices);
auto tmp = yield.getElementValue();
yield->erase();
return tmp;
}
if (designate) {
// Create a designator over designator, then load the reference.
auto resEntity = hlfir::Entity{designate.getResult()};
auto tmp = builder.create<hlfir::DesignateOp>(
loc, getVariableElementType(resEntity), designate, indices);
return builder.create<fir::LoadOp>(loc, tmp);
}
llvm_unreachable("unsupported type");
};
// NOTE(review): this builder is constructed from the operation rather than
// from `rewriter`, unlike the other patterns in this file - confirm that
// this is intended.
fir::KindMapping kindMap =
fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>());
fir::FirOpBuilder builder{op, kindMap};
// `init` is the starting value of the reduction and `genBodyFn` generates
// the update performed for each element; both depend on the kind of
// reduction.
mlir::Value init;
GenBodyFn genBodyFn;
if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
// any: start from false and OR in every element.
init = builder.createIntegerConstant(loc, builder.getI1Type(), 0);
genBodyFn =
[inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &indices)
-> mlir::Value {
// Conditionally set the reduction variable.
mlir::Value cond = builder.create<fir::ConvertOp>(
loc, builder.getI1Type(), inlineSource(builder, loc, indices));
return builder.create<mlir::arith::OrIOp>(loc, reduction, cond);
};
} else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
// all: start from true and AND in every element.
init = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
genBodyFn =
[inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &indices)
-> mlir::Value {
// Conditionally set the reduction variable.
mlir::Value cond = builder.create<fir::ConvertOp>(
loc, builder.getI1Type(), inlineSource(builder, loc, indices));
return builder.create<mlir::arith::AndIOp>(loc, reduction, cond);
};
} else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
// count: start from zero (in the result type of the operation) and add one
// for every element that is true.
init = builder.createIntegerConstant(loc, op.getType(), 0);
genBodyFn =
[inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &indices)
-> mlir::Value {
// Conditionally add one to the current value
mlir::Value cond = builder.create<fir::ConvertOp>(
loc, builder.getI1Type(), inlineSource(builder, loc, indices));
mlir::Value one =
builder.createIntegerConstant(loc, reduction.getType(), 1);
mlir::Value add1 =
builder.create<mlir::arith::AddIOp>(loc, reduction, one);
return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
reduction);
};
} else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
std::is_same_v<Op, hlfir::MinlocOp>) {
// TODO: implement minloc/maxloc conversion.
return rewriter.notifyMatchFailure(
op, "Currently minloc/maxloc is not handled");
} else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
std::is_same_v<Op, hlfir::MinvalOp>) {
// maxval/minval: start from the value produced by
// makeMinMaxInitValGenerator() and keep whichever of the element and the
// current reduction value generateMinMaxComparison() selects.
bool isMax = std::is_same_v<Op, hlfir::MaxvalOp>;
init = makeMinMaxInitValGenerator(isMax)(builder, loc, op.getType());
genBodyFn = [inlineSource,
isMax](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &indices)
-> mlir::Value {
mlir::Value val = inlineSource(builder, loc, indices);
mlir::Value cmp =
generateMinMaxComparison(builder, loc, val, reduction, isMax);
return builder.create<mlir::arith::SelectOp>(loc, cmp, val, reduction);
};
} else {
llvm_unreachable("unsupported type");
}
mlir::Value res =
generateReductionLoop(builder, loc, init, shape, genBodyFn);
// The loop may reduce in a different type than the operation's result type
// (any/all reduce in i1); convert back if needed.
if (res.getType() != op.getType())
res = builder.create<fir::ConvertOp>(loc, op.getType(), res);
// Check if the op was the only user of the source (apart from a destroy),
// and remove it if so.
mlir::Operation *sourceOp = source.getDefiningOp();
mlir::Operation::user_range srcUsers = sourceOp->getUsers();
hlfir::DestroyOp srcDestroy;
if (std::distance(srcUsers.begin(), srcUsers.end()) == 2) {
// With exactly two users, one of them is `op`; look for a destroy in
// either position.
srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*srcUsers.begin());
if (!srcDestroy)
srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++srcUsers.begin());
}
rewriter.replaceOp(op, res);
// The source is only erased after `op`, its other user, has been replaced.
if (srcDestroy) {
rewriter.eraseOp(srcDestroy);
rewriter.eraseOp(sourceOp);
}
return mlir::success();
}
};
/// Look for minloc/maxloc with an elemental mask and generate the
/// minloc/maxloc loop with the elemental inlined into it:
///   %e = hlfir.elemental %shape ({ ... })
///   %m = hlfir.minloc %array mask %e
///
/// The pattern reports a match failure unless:
///  - the op has a mask and has neither a dim nor a back operand;
///  - the mask is defined by an hlfir.elemental for which
///    hlfir::elementalOpMustProduceTemp returns false;
///  - the array is of fir::BoxType and its Fortran element type is trivial.
///
/// On success the result is computed into a temporary array, users of the
/// original op are redirected to it, and the elemental (with its destroy) is
/// erased when the minloc/maxloc was its only other user.
template <typename Op>
class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override {
    if (!mloc.getMask() || mloc.getDim() || mloc.getBack())
      return rewriter.notifyMatchFailure(mloc,
                                         "Did not find valid minloc/maxloc");

    bool isMax = std::is_same_v<Op, hlfir::MaxlocOp>;

    auto elemental =
        mloc.getMask().template getDefiningOp<hlfir::ElementalOp>();
    if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
      return rewriter.notifyMatchFailure(mloc, "Did not find elemental");

    mlir::Value array = mloc.getArray();

    // The extent of the result expression is used as the number of
    // coordinates to produce.
    unsigned rank = mlir::cast<hlfir::ExprType>(mloc.getType()).getShape()[0];
    mlir::Type arrayType = array.getType();
    if (!mlir::isa<fir::BoxType>(arrayType))
      return rewriter.notifyMatchFailure(
          mloc, "Currently requires a boxed type input");
    mlir::Type elementType = hlfir::getFortranElementType(arrayType);
    if (!fir::isa_trivial(elementType))
      return rewriter.notifyMatchFailure(
          mloc, "Character arrays are currently not handled");

    mlir::Location loc = mloc.getLoc();
    fir::FirOpBuilder builder{rewriter, mloc.getOperation()};

    // Temporary that receives the coordinates of the selected element.
    mlir::Value resultArr = builder.createTemporary(
        loc, fir::SequenceType::get(
                 rank, hlfir::getFortranElementType(mloc.getType())));

    auto init = makeMinMaxInitValGenerator(isMax);

    // Generates the body of the innermost reduction loop. `indices` are turned
    // into one-based indices by adding one; `flagRef` is loaded to tell
    // whether a location has already been recorded; `reduction` is the
    // current reduction value. Returns the new reduction value.
    auto genBodyOp =
        [&rank, &resultArr, &elemental, isMax](
            fir::FirOpBuilder builder, mlir::Location loc,
            mlir::Type elementType, mlir::Value array, mlir::Value flagRef,
            mlir::Value reduction,
            const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
      // We are in the innermost loop: generate the elemental inline
      mlir::Value oneIdx =
          builder.createIntegerConstant(loc, builder.getIndexType(), 1);
      llvm::SmallVector<mlir::Value> oneBasedIndices;
      llvm::transform(
          indices, std::back_inserter(oneBasedIndices), [&](mlir::Value V) {
            return builder.create<mlir::arith::AddIOp>(loc, V, oneIdx);
          });
      hlfir::YieldElementOp yield =
          hlfir::inlineElementalOp(loc, builder, elemental, oneBasedIndices);
      mlir::Value maskElem = yield.getElementValue();
      // Only the yielded value is needed; the yield op itself is dropped.
      yield->erase();

      // Convert the mask element to i1 so it can be used as the if condition.
      mlir::Type ifCompatType = builder.getI1Type();
      mlir::Value ifCompatElem =
          builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);

      fir::IfOp maskIfOp =
          builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
                                    /*withElseRegion=*/true);
      builder.setInsertionPointToStart(&maskIfOp.getThenRegion().front());

      // Set flag that mask was true at some point
      mlir::Value flagSet = builder.createIntegerConstant(
          loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
      mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
      mlir::Value addr = hlfir::getElementAt(loc, builder, hlfir::Entity{array},
                                             oneBasedIndices);
      mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);

      // Compare with the max reduction value
      mlir::Value cmp =
          generateMinMaxComparison(builder, loc, elem, reduction, isMax);

      // The condition used for the loop is isFirst || <the condition above>.
      // The loaded flag is inverted (xor with 1) to obtain isFirst.
      isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
      isFirst = builder.create<mlir::arith::XOrIOp>(
          loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
      cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);

      // Set the new coordinate to the result
      fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
                                                 /*withElseRegion*/ true);

      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
      builder.create<fir::StoreOp>(loc, flagSet, flagRef);
      mlir::Type resultElemTy =
          hlfir::getFortranElementType(resultArr.getType());
      mlir::Type returnRefTy = builder.getRefType(resultElemTy);
      mlir::IndexType idxTy = builder.getIndexType();

      for (unsigned int i = 0; i < rank; ++i) {
        mlir::Value index = builder.createIntegerConstant(loc, idxTy, i + 1);
        mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
            loc, returnRefTy, resultArr, index);
        mlir::Value fortranIndex = builder.create<fir::ConvertOp>(
            loc, resultElemTy, oneBasedIndices[i]);
        builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
      }
      // The element becomes the new reduction value when it was selected,
      // otherwise the previous reduction value is kept.
      builder.create<fir::ResultOp>(loc, elem);
      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
      builder.create<fir::ResultOp>(loc, reduction);
      builder.setInsertionPointAfter(ifOp);

      // Close the mask if
      builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
      builder.setInsertionPointToStart(&maskIfOp.getElseRegion().front());
      builder.create<fir::ResultOp>(loc, reduction);
      builder.setInsertionPointAfter(maskIfOp);

      return maskIfOp.getResult(0);
    };

    // Returns a designator for the element of `resultArr` at position
    // `index + 1`.
    auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
                        const mlir::Type &resultElemType, mlir::Value resultArr,
                        mlir::Value index) {
      mlir::Type resultRefTy = builder.getRefType(resultElemType);
      mlir::Value oneIdx =
          builder.createIntegerConstant(loc, builder.getIndexType(), 1);
      index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx);
      return builder.create<hlfir::DesignateOp>(loc, resultRefTy, resultArr,
                                                index);
    };

    // Initialize the result: every coordinate starts at zero.
    mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
    mlir::Type resultRefTy = builder.getRefType(resultElemTy);
    mlir::Value returnValue =
        builder.createIntegerConstant(loc, resultElemTy, 0);
    for (unsigned int i = 0; i < rank; ++i) {
      mlir::Value index =
          builder.createIntegerConstant(loc, builder.getIndexType(), i + 1);
      mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
          loc, resultRefTy, resultArr, index);
      builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
    }

    fir::genMinMaxlocReductionLoop(builder, array, init, genBodyOp, getAddrFn,
                                   rank, elementType, loc, builder.getI1Type(),
                                   resultArr, false);

    mlir::Value asExpr = builder.create<hlfir::AsExprOp>(
        loc, resultArr, builder.createBool(loc, false));

    // Check all the users - the destroy is no longer required, and any assign
    // can use resultArr directly so that VariableAssignBufferization in this
    // pass can optimize the results. Other operations are replaced with an
    // AsExpr for the temporary resultArr.
    llvm::SmallVector<hlfir::DestroyOp> destroys;
    llvm::SmallVector<hlfir::AssignOp> assigns;
    for (auto user : mloc->getUsers()) {
      if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
        destroys.push_back(destroy);
      else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user))
        assigns.push_back(assign);
    }

    // Check if the minloc/maxloc was the only user of the elemental (apart from
    // a destroy), and remove it if so.
    mlir::Operation::user_range elemUsers = elemental->getUsers();
    hlfir::DestroyOp elemDestroy;
    if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
      elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
      if (!elemDestroy)
        elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
    }

    for (auto d : destroys)
      rewriter.eraseOp(d);
    // Update the assigns through the rewriter so that the pattern driver is
    // notified of the in-place modification instead of mutating the IR behind
    // its back.
    for (hlfir::AssignOp assign : assigns)
      rewriter.modifyOpInPlace(assign,
                               [&]() { assign.setOperand(0, resultArr); });
    rewriter.replaceOp(mloc, asExpr);
    if (elemDestroy) {
      // Erase the destroy first so the elemental has no remaining users.
      rewriter.eraseOp(elemDestroy);
      rewriter.eraseOp(elemental);
    }
    return mlir::success();
  }
};
/// Pass entry point: registers the special-case bufferization patterns defined
/// in this file and applies them greedily to the operation the pass runs on.
class OptimizedBufferizationPass
    : public hlfir::impl::OptimizedBufferizationBase<
          OptimizedBufferizationPass> {
public:
  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();

    // TODO: right now the patterns are non-conflicting,
    // but it might be better to run this pass on hlfir.assign
    // operations and decide which transformation to apply
    // at one place (e.g. we may use some heuristics and
    // choose different optimization strategies).
    // This requires small code reordering in ElementalAssignBufferization.
    mlir::RewritePatternSet patternSet(ctx);
    patternSet.insert<
        // Assignment patterns.
        ElementalAssignBufferization, BroadcastAssignBufferization,
        VariableAssignBufferization,
        // Reductions over an elemental source.
        ReductionConversion<hlfir::CountOp>, ReductionConversion<hlfir::AnyOp>,
        ReductionConversion<hlfir::AllOp>,
        // TODO: implement basic minloc/maxloc conversion.
        // ReductionConversion<hlfir::MaxlocOp>,
        // ReductionConversion<hlfir::MinlocOp>,
        ReductionConversion<hlfir::MaxvalOp>,
        ReductionConversion<hlfir::MinvalOp>,
        // Reductions with an elemental mask.
        ReductionMaskConversion<hlfir::MinlocOp>,
        ReductionMaskConversion<hlfir::MaxlocOp>
        // TODO: implement masked minval/maxval conversion.
        // ReductionMaskConversion<hlfir::MaxvalOp>,
        // ReductionMaskConversion<hlfir::MinvalOp>
        >(ctx);

    // Region simplification is turned off to prevent the pattern driver from
    // merging blocks.
    mlir::GreedyRewriteConfig greedyConfig;
    greedyConfig.enableRegionSimplification =
        mlir::GreedySimplifyRegionLevel::Disabled;

    auto root = getOperation();
    const bool rewriteFailed = mlir::failed(mlir::applyPatternsAndFoldGreedily(
        root, std::move(patternSet), greedyConfig));
    if (!rewriteFailed)
      return;

    // The driver failed: report it and mark the pass as failed.
    mlir::emitError(root->getLoc(), "failure in HLFIR optimized bufferization");
    signalPassFailure();
  }
};
} // namespace