llvm/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp

//===- LowerHLFIROrderedAssignments.cpp - Lower HLFIR ordered assignments -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file defines a pass to lower HLFIR ordered assignments.
// Ordered assignments are all the operations with the
// OrderedAssignmentTreeOpInterface that implements user defined assignments,
// assignment to vector subscripted entities, and assignments inside forall and
// where.
// The pass lowers these operations to regular hlfir.assign, loops and, if
// needed, introduces temporary storage to fulfill Fortran semantics.
//
// For each rewrite, an analysis builds an evaluation schedule, and then the
// new code is generated by following the evaluation schedule.
//===----------------------------------------------------------------------===//

#include "ScheduleOrderedAssignments.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/TemporaryStorage.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

namespace hlfir {
#define GEN_PASS_DEF_LOWERHLFIRORDEREDASSIGNMENTS
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir

#define DEBUG_TYPE "flang-ordered-assignment"

// Test-only option: run just the scheduling part of the pass (operations are
// erased without codegen). Its only goal is to allow printing and testing the
// scheduling debug output.
static llvm::cl::opt<bool> dbgScheduleOnly(
    "flang-dbg-order-assignment-schedule-only",
    llvm::cl::desc("Only run ordered assignment scheduling with no codegen"),
    llvm::cl::init(false));

namespace {

/// Structure that represents a masked expression being lowered. Masked
/// expressions are any expressions inside an hlfir.where. As described in
/// Fortran 2018 section 10.2.3.2, the evaluation of the elemental parts of such
/// expressions must be masked, while the evaluation of the non-elemental parts
/// must not be masked. This structure analyzes the region evaluating the
/// expression and allows splitting the generation of the non-elemental part
/// from the generation of the elemental part.
struct MaskedArrayExpr {
  /// \p loc is the location used for the generated code, \p region is the
  /// region evaluating the masked expression, and \p isOuterMaskExpr tells
  /// whether the expression is the mask of the outer where statement.
  MaskedArrayExpr(mlir::Location loc, mlir::Region &region,
                  bool isOuterMaskExpr);

  /// Generate the non-elemental part. Must be called outside of the
  /// loops created for the WHERE construct.
  void generateNoneElementalPart(fir::FirOpBuilder &builder,
                                 mlir::IRMapping &mapper);

  /// Methods below can only be called once generateNoneElementalPart has been
  /// called.

  /// Return the shape of the expression.
  mlir::Value generateShape(fir::FirOpBuilder &builder,
                            mlir::IRMapping &mapper);
  /// Return the value of one element of this expression given the current
  /// where loop indices.
  mlir::Value generateElementalParts(fir::FirOpBuilder &builder,
                                     mlir::ValueRange oneBasedIndices,
                                     mlir::IRMapping &mapper);
  /// Generate the cleanup for the non-elemental parts, if any. This must be
  /// called after the loops created for the WHERE construct.
  void generateNoneElementalCleanupIfAny(fir::FirOpBuilder &builder,
                                         mlir::IRMapping &mapper);

  /// Helper to clone the clean-ups of the masked expr region terminator.
  /// This is called outside of the loops for the initial mask, and inside
  /// the loops for the other masked expressions.
  mlir::Operation *generateMaskedExprCleanUps(fir::FirOpBuilder &builder,
                                              mlir::IRMapping &mapper);

  /// Location given at construction.
  mlir::Location loc;
  /// Region evaluating the masked expression.
  mlir::Region &region;
  /// Set of operations that form the elemental parts of the
  /// expression evaluation. These are the hlfir.elemental and
  /// hlfir.elemental_addr that form the elemental tree producing
  /// the expression value. hlfir.elemental that produce values
  /// used inside transformational operations are not part of this set.
  llvm::SmallSet<mlir::Operation *, 4> elementalParts{};
  /// Was generateNoneElementalPart called?
  bool noneElementalPartWasGenerated = false;
  /// Is this expression the mask expression of the outer where statement?
  /// It is special because its evaluation is not masked by anything yet.
  bool isOuterMaskExpr = false;
};
} // namespace

namespace {
/// Structure that visits an ordered assignment tree and generates code for
/// it according to a schedule.
class OrderedAssignmentRewriter {
public:
  /// \p builder is used to emit all the new code, \p root is the root of the
  /// ordered assignment tree that is lowered.
  OrderedAssignmentRewriter(fir::FirOpBuilder &builder,
                            hlfir::OrderedAssignmentTreeOpInterface root)
      : builder{builder}, root{root} {}

  /// Generate code for the current run of the schedule.
  void lowerRun(hlfir::Run &run) {
    currentRun = &run;
    walk(root);
    currentRun = nullptr;
    assert(constructStack.empty() && "must exit constructs after a run");
    // Per-run state must not leak into the next run.
    mapper.clear();
    savedInCurrentRunBeforeUse.clear();
  }

  /// After all runs have been lowered, clean-up all the temporary
  /// storage that were created (do not call final routines).
  void cleanupSavedEntities() {
    for (auto &temp : savedEntities)
      temp.second.destroy(root.getLoc(), builder);
  }

  /// Lowered value for an expression, and the original hlfir.yield if any
  /// clean-up needs to be cloned after usage.
  using ValueAndCleanUp = std::pair<mlir::Value, std::optional<hlfir::YieldOp>>;

private:
  /// Walk the part of an ordered assignment tree node that needs
  /// to be evaluated in the current run.
  void walk(hlfir::OrderedAssignmentTreeOpInterface node);

  /// Generate code when entering a given ordered assignment node.
  void pre(hlfir::ForallOp forallOp);
  void pre(hlfir::ForallIndexOp);
  void pre(hlfir::ForallMaskOp);
  void pre(hlfir::WhereOp whereOp);
  void pre(hlfir::ElseWhereOp elseWhereOp);
  void pre(hlfir::RegionAssignOp);

  /// Generate code when leaving a given ordered assignment node.
  void post(hlfir::ForallOp);
  void post(hlfir::ForallMaskOp);
  void post(hlfir::WhereOp);
  void post(hlfir::ElseWhereOp);
  /// Enter (and maybe create) the fir.if else block of an ElseWhereOp,
  /// but do not generate the elsewhere mask or the new fir.if.
  void enterElsewhere(hlfir::ElseWhereOp);

  /// Are there any leaf regions in the node that must be saved in the current
  /// run?
  bool mustSaveRegionIn(
      hlfir::OrderedAssignmentTreeOpInterface node,
      llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const;
  /// Should this node be evaluated in the current run? Saving a region in a
  /// node does not imply the node needs to be evaluated.
  bool
  isRequiredInCurrentRun(hlfir::OrderedAssignmentTreeOpInterface node) const;

  /// Generate a scalar value yielded by an ordered assignment tree region.
  /// If the value was not saved in a previous run, this clones the region
  /// code, except the final yield, at the current execution point.
  /// If the value was saved in a previous run, this fetches the saved value
  /// from the temporary storage and returns the value.
  /// Inside Forall, the value will be hoisted outside of the forall loops if
  /// it does not depend on the forall indices.
  /// An optional type can be provided to get a value from a specific type
  /// (the cast will be hoisted if the computation is hoisted).
  mlir::Value generateYieldedScalarValue(
      mlir::Region &region,
      std::optional<mlir::Type> castToType = std::nullopt);

  /// Generate an entity yielded by an ordered assignment tree region, and
  /// optionally return the (uncloned) yield if there is any clean-up that
  /// should be done after using the entity. Like generateYieldedScalarValue,
  /// this will return the saved value if the region was saved in a previous
  /// run.
  ValueAndCleanUp
  generateYieldedEntity(mlir::Region &region,
                        std::optional<mlir::Type> castToType = std::nullopt);

  /// Lowered left-hand side of an assignment along with the clean-ups and the
  /// vector subscript loop nest (and shape), if any, that were created for it.
  struct LhsValueAndCleanUp {
    mlir::Value lhs;
    std::optional<hlfir::YieldOp> elementalCleanup;
    mlir::Region *nonElementalCleanup = nullptr;
    std::optional<hlfir::LoopNest> vectorSubscriptLoopNest;
    std::optional<mlir::Value> vectorSubscriptShape;
  };

  /// Generate the left-hand side. If the left-hand side is vector
  /// subscripted (hlfir.elemental_addr), this will create a loop nest
  /// (unless it was already created by a WHERE mask) and return the
  /// element address.
  LhsValueAndCleanUp
  generateYieldedLHS(mlir::Location loc, mlir::Region &lhsRegion,
                     std::optional<hlfir::Entity> loweredRhs = std::nullopt);

  /// If \p maybeYield is present and has a clean-up, generate the clean-up
  /// at the current insertion point (by cloning).
  void generateCleanupIfAny(std::optional<hlfir::YieldOp> maybeYield);
  void generateCleanupIfAny(mlir::Region *cleanupRegion);

  /// Generate a masked entity. This can only be called when whereLoopNest was
  /// set (When an hlfir.where is being visited).
  /// This method returns the scalar element (that may have been previously
  /// saved) for the current indices inside the where loop.
  mlir::Value generateMaskedEntity(mlir::Location loc, mlir::Region &region) {
    MaskedArrayExpr maskedExpr(loc, region, /*isOuterMaskExpr=*/!whereLoopNest);
    return generateMaskedEntity(maskedExpr);
  }
  mlir::Value generateMaskedEntity(MaskedArrayExpr &maskedExpr);

  /// Create a fir.if at the current position inside the where loop nest
  /// given the element value of a mask.
  void generateMaskIfOp(mlir::Value cdt);

  /// Save a value for subsequent runs.
  void generateSaveEntity(hlfir::SaveEntity savedEntity,
                          bool willUseSavedEntityInSameRun);
  void saveLeftHandSide(hlfir::SaveEntity savedEntity,
                        hlfir::RegionAssignOp regionAssignOp);

  /// Get a value if it was saved in this run or a previous run. Returns
  /// nullopt if it has not been saved.
  std::optional<ValueAndCleanUp> getIfSaved(mlir::Region &region);

  /// Generate code before the loop nest for the current run, if any.
  void doBeforeLoopNest(const std::function<void()> &callback) {
    // No construct was entered yet: the current insertion point is used as is.
    if (constructStack.empty()) {
      callback();
      return;
    }
    // Temporarily move the insertion point right before the first construct
    // that was pushed on the stack in this run, then restore it.
    auto insertionPoint = builder.saveInsertionPoint();
    builder.setInsertionPoint(constructStack[0]);
    callback();
    builder.restoreInsertionPoint(insertionPoint);
  }

  /// Can the current loop nest iteration number be computed? For simplicity,
  /// this is true if and only if all the bounds and steps of the fir.do_loop
  /// nest dominate the outer loop. The argument is filled with the current
  /// loop nest on success.
  bool currentLoopNestIterationNumberCanBeComputed(
      llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest);

  /// Register \p temp as the temporary storage holding the saved value of
  /// \p region. The region must not already have a registered storage.
  /// Returns a pointer to the storage kept in savedEntities.
  template <typename T>
  fir::factory::TemporaryStorage *insertSavedEntity(mlir::Region &region,
                                                    T &&temp) {
    auto inserted =
        savedEntities.insert(std::make_pair(&region, std::forward<T>(temp)));
    assert(inserted.second && "temp must have been emplaced");
    return &inserted.first->second;
  }

  fir::FirOpBuilder &builder;

  /// Map containing the mapping between the original ordered assignment tree
  /// operations and the operations that have been cloned in the current run.
  /// It is reset between two runs.
  mlir::IRMapping mapper;
  /// Dominance info is used to determine if inner loop bounds are all computed
  /// before outer loop for the current loop. It does not need to be reset
  /// between runs.
  mlir::DominanceInfo dominanceInfo;
  /// Construct stack in the current run. This allows setting back the insertion
  /// point correctly when leaving a node that requires a fir.do_loop or fir.if
  /// operation.
  llvm::SmallVector<mlir::Operation *> constructStack;
  /// Current where loop nest, if any.
  std::optional<hlfir::LoopNest> whereLoopNest;

  /// Map of temporary storage to keep track of saved entity once the run
  /// that saves them has been lowered. It is kept in-between runs.
  /// llvm::MapVector is used to guarantee deterministic order
  /// of iterating through savedEntities (e.g. for generating
  /// destruction code for the temporary storages).
  llvm::MapVector<mlir::Region *, fir::factory::TemporaryStorage> savedEntities;
  /// Map holding the values that were saved in the current run and that also
  /// need to be used (because their construct will be visited). It is reset
  /// after each run. It avoids having to store and fetch in the temporary
  /// during the same run, which would require the temporary to have different
  /// fetching and storing counters.
  llvm::DenseMap<mlir::Region *, ValueAndCleanUp> savedInCurrentRunBeforeUse;

  /// Root of the ordered assignment tree being lowered.
  hlfir::OrderedAssignmentTreeOpInterface root;
  /// Pointer to the current run of the schedule being lowered.
  hlfir::Run *currentRun = nullptr;

  /// When allocating temporary storage inlined, indicate if the storage should
  /// be heap or stack allocated. Temporary allocated with the runtime are heap
  /// allocated by the runtime.
  bool allocateOnHeap = true;
};
} // namespace

/// Visit \p node for the current run: first save the regions of the node that
/// the schedule requires to be saved, then, if the node must be evaluated,
/// generate its entry code, recurse into its sub-tree, and generate its exit
/// code.
void OrderedAssignmentRewriter::walk(
    hlfir::OrderedAssignmentTreeOpInterface node) {
  mlir::Operation *op = node.getOperation();
  // hlfir.forall_index nodes are always visited.
  const bool visitNode =
      isRequiredInCurrentRun(node) || mlir::isa<hlfir::ForallIndexOp>(node);

  llvm::SmallVector<hlfir::SaveEntity> toSave;
  if (mustSaveRegionIn(node, toSave)) {
    mlir::IRRewriter::InsertPoint resumePoint;
    if (auto elseWhere = mlir::dyn_cast<hlfir::ElseWhereOp>(op)) {
      // An ElseWhere mask to save must be evaluated inside the fir.if else
      // of the previous where/elsewhere (its evaluation must be masked by
      // the "pending control mask").
      resumePoint = builder.saveInsertionPoint();
      enterElsewhere(elseWhere);
    }
    for (hlfir::SaveEntity entity : toSave)
      generateSaveEntity(entity, visitNode);
    // Only set when the else block was entered above.
    if (resumePoint.isSet())
      builder.restoreInsertionPoint(resumePoint);
  }

  if (!visitNode)
    return;

  // Entry code of the node.
  llvm::TypeSwitch<mlir::Operation *, void>(op)
      .Case<hlfir::ForallOp, hlfir::ForallIndexOp, hlfir::ForallMaskOp,
            hlfir::RegionAssignOp, hlfir::WhereOp, hlfir::ElseWhereOp>(
          [&](auto concreteOp) { pre(concreteOp); })
      .Default([](auto) {});

  auto *subTree = node.getSubTreeRegion();
  if (!subTree)
    return;

  // Recurse into the nested ordered assignment nodes, in order.
  for (mlir::Operation &child : subTree->getOps())
    if (auto childNode =
            mlir::dyn_cast<hlfir::OrderedAssignmentTreeOpInterface>(child))
      walk(childNode);

  // Exit code of the node (only for nodes that opened a construct).
  llvm::TypeSwitch<mlir::Operation *, void>(op)
      .Case<hlfir::ForallOp, hlfir::ForallMaskOp, hlfir::WhereOp,
            hlfir::ElseWhereOp>([&](auto concreteOp) { post(concreteOp); })
      .Default([](auto) {});
}

/// Enter an hlfir.forall: create a fir.do_loop from the forall control values,
/// map the forall index to the loop induction variable, and push the loop on
/// the construct stack.
void OrderedAssignmentRewriter::pre(hlfir::ForallOp forallOp) {
  mlir::Location loc = forallOp.getLoc();
  mlir::Type indexType = builder.getIndexType();

  mlir::Value lower =
      generateYieldedScalarValue(forallOp.getLbRegion(), indexType);
  mlir::Value upper =
      generateYieldedScalarValue(forallOp.getUbRegion(), indexType);
  mlir::Value stride;
  if (!forallOp.getStepRegion().empty()) {
    stride = generateYieldedScalarValue(forallOp.getStepRegion(), indexType);
  } else {
    // No step was given: use a constant one, created before the loop nest of
    // the current run, if any.
    doBeforeLoopNest([&]() {
      stride = builder.createIntegerConstant(loc, indexType, 1);
    });
  }

  auto loop = builder.create<fir::DoLoopOp>(loc, lower, upper, stride);
  builder.setInsertionPointToStart(loop.getBody());

  // Uses of the forall index in cloned code must see the induction variable
  // converted to the original index type.
  mlir::Value forallIndex = forallOp.getForallIndexValue();
  mlir::Value inductionAsIndex =
      builder.createConvert(loc, forallIndex.getType(), loop.getInductionVar());
  mapper.map(forallIndex, inductionAsIndex);
  constructStack.push_back(loop);
}

/// Leave an hlfir.forall: pop its loop from the construct stack and continue
/// generating code after it.
void OrderedAssignmentRewriter::post(hlfir::ForallOp) {
  assert(!constructStack.empty() && "must contain a loop");
  mlir::Operation *forallLoop = constructStack.pop_back_val();
  builder.setInsertionPointAfter(forallLoop);
}

/// Lower an hlfir.forall_index: create a temporary variable, store the
/// (mapped) index value into it, and map the operation to that variable.
void OrderedAssignmentRewriter::pre(hlfir::ForallIndexOp forallIndexOp) {
  mlir::Location loc = forallIndexOp.getLoc();
  // Value of the index in the code cloned for the current run.
  mlir::Value currentIndex = mapper.lookupOrDefault(forallIndexOp.getIndex());
  mlir::Type valueType = fir::unwrapRefType(forallIndexOp.getType());
  mlir::Value storage =
      builder.createTemporary(loc, valueType, forallIndexOp.getName());
  builder.createStoreWithConvert(loc, currentIndex, storage);
  mapper.map(forallIndexOp, storage);
}

/// Enter an hlfir.forall_mask: evaluate the scalar mask as an i1 and open a
/// fir.if (without else region) whose "then" block receives the masked code.
void OrderedAssignmentRewriter::pre(hlfir::ForallMaskOp forallMaskOp) {
  mlir::Location loc = forallMaskOp.getLoc();
  mlir::Type conditionType = builder.getI1Type();
  mlir::Value condition =
      generateYieldedScalarValue(forallMaskOp.getMaskRegion(), conditionType);
  auto maskIf = builder.create<fir::IfOp>(loc, std::nullopt, condition,
                                          /*withElseRegion=*/false);
  builder.setInsertionPointToStart(&maskIf.getThenRegion().front());
  constructStack.push_back(maskIf);
}

/// Leave an hlfir.forall_mask: pop its fir.if from the construct stack and
/// continue generating code after it.
void OrderedAssignmentRewriter::post(hlfir::ForallMaskOp forallMaskOp) {
  assert(!constructStack.empty() && "must contain an ifop");
  mlir::Operation *maskIf = constructStack.pop_back_val();
  builder.setInsertionPointAfter(maskIf);
}

/// Convert \p input so that it has the type of \p mold.
/// This helps when an hlfir entity is a value while it must be used as a
/// variable, or vice-versa. Such mismatches may occur between the type of the
/// user defined assignment block arguments and the actual argument that was
/// lowered for them (the actual may be an in-memory copy while the block
/// argument expects an hlfir.expr).
/// Any clean-up required after the converted entity has been used is appended
/// to \p cleanups; it is up to the caller to run them.
static hlfir::Entity
convertToMoldType(mlir::Location loc, fir::FirOpBuilder &builder,
                  hlfir::Entity input, hlfir::Entity mold,
                  llvm::SmallVectorImpl<hlfir::CleanupFunction> &cleanups) {
  mlir::Type moldType = mold.getType();
  if (input.getType() == moldType)
    return input;

  // The clean-up callbacks run later; they keep a pointer to the builder.
  fir::FirOpBuilder *builderPtr = &builder;

  if (input.isVariable() && mold.isValue()) {
    if (!fir::isa_trivial(moldType)) {
      // fir.ref<T> to hlfir.expr<T>.
      mlir::Value expr = builder.create<hlfir::AsExprOp>(loc, input);
      if (expr.getType() != moldType)
        TODO(loc, "hlfir.expr conversion");
      cleanups.emplace_back([builderPtr, loc, expr]() {
        builderPtr->create<hlfir::DestroyOp>(loc, expr);
      });
      return hlfir::Entity{expr};
    }
    // fir.ref<T> to T.
    mlir::Value loaded = builder.create<fir::LoadOp>(loc, input);
    return hlfir::Entity{builder.createConvert(loc, moldType, loaded)};
  }

  if (input.isValue() && mold.isVariable()) {
    // T to fir.ref<T>, or hlfir.expr<T> to fir.ref<T>.
    hlfir::AssociateOp association = hlfir::genAssociateExpr(
        loc, builder, input, mold.getFortranElementType(), ".tmp.val2ref");
    cleanups.emplace_back([builderPtr, loc, association]() {
      builderPtr->create<hlfir::EndAssociateOp>(loc, association);
    });
    return hlfir::Entity{association.getBase()};
  }

  // Variable to variable mismatch (e.g., fir.heap<T> vs fir.ref<T>), or value
  // to value mismatch (e.g. i1 vs fir.logical<4>).
  const bool moldIsBox = mlir::isa<fir::BaseBoxType>(moldType);
  const bool inputIsBox = mlir::isa<fir::BaseBoxType>(input.getType());
  if (moldIsBox && !inputIsBox) {
    // An entity may have been saved without a descriptor while the original
    // value had one (e.g., it was not contiguous).
    auto boxedAndCleanup = hlfir::convertToBox(loc, builder, input, moldType);
    assert(!boxedAndCleanup.second && "temp should already be in memory");
    input = hlfir::Entity{fir::getBase(boxedAndCleanup.first)};
  }
  return hlfir::Entity{builder.createConvert(loc, moldType, input)};
}

/// Lower an hlfir.region_assign: generate the right-hand side, then the
/// left-hand side, and either clone the user defined assignment body (inside
/// an elemental loop nest when an array is assigned with a scalar user
/// assignment) or create a regular hlfir.assign. The clean-ups of both sides
/// are generated after the assignment.
void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
  mlir::Location loc = regionAssignOp.getLoc();
  auto [rhsValue, oldRhsYield] =
      generateYieldedEntity(regionAssignOp.getRhsRegion());
  hlfir::Entity rhsEntity{rhsValue};
  LhsValueAndCleanUp loweredLhs =
      generateYieldedLHS(loc, regionAssignOp.getLhsRegion(), rhsEntity);
  hlfir::Entity lhsEntity{loweredLhs.lhs};
  // When a loop nest was created for a vector subscripted LHS, the assignment
  // is done element by element inside that nest.
  if (loweredLhs.vectorSubscriptLoopNest)
    rhsEntity = hlfir::getElementAt(
        loc, builder, rhsEntity,
        loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
  if (!regionAssignOp.getUserDefinedAssignment().empty()) {
    hlfir::Entity userAssignLhs{regionAssignOp.getUserAssignmentLhs()};
    hlfir::Entity userAssignRhs{regionAssignOp.getUserAssignmentRhs()};
    std::optional<hlfir::LoopNest> elementalLoopNest;
    if (lhsEntity.isArray() && userAssignLhs.isScalar()) {
      // Elemental assignment with array argument (the RHS cannot be an array
      // if the LHS is not).
      mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
      elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
      builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
      lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
                                      elementalLoopNest->oneBasedIndices);
      rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
                                      elementalLoopNest->oneBasedIndices);
    }

    // The lowered entities may not have the exact type of the user assignment
    // block arguments: convert them, and keep track of the related clean-ups.
    llvm::SmallVector<hlfir::CleanupFunction, 2> argConversionCleanups;
    lhsEntity = convertToMoldType(loc, builder, lhsEntity, userAssignLhs,
                                  argConversionCleanups);
    rhsEntity = convertToMoldType(loc, builder, rhsEntity, userAssignRhs,
                                  argConversionCleanups);
    mapper.map(userAssignLhs, lhsEntity);
    mapper.map(userAssignRhs, rhsEntity);
    // Clone the user assignment body (without its terminator) at the current
    // insertion point.
    for (auto &op :
         regionAssignOp.getUserDefinedAssignment().front().without_terminator())
      (void)builder.clone(op, mapper);
    for (auto &cleanupConversion : argConversionCleanups)
      cleanupConversion();
    if (elementalLoopNest)
      builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
  } else {
    // TODO: preserve allocatable assignment aspects for forall once
    // they are conveyed in hlfir.region_assign.
    builder.create<hlfir::AssignOp>(loc, rhsEntity, lhsEntity);
  }
  // The elemental LHS clean-up is generated before leaving the vector
  // subscript loop nest; the other clean-ups are generated after it.
  generateCleanupIfAny(loweredLhs.elementalCleanup);
  if (loweredLhs.vectorSubscriptLoopNest)
    builder.setInsertionPointAfter(
        loweredLhs.vectorSubscriptLoopNest->outerLoop);
  generateCleanupIfAny(oldRhsYield);
  generateCleanupIfAny(loweredLhs.nonElementalCleanup);
}

/// Open a fir.if (without else region) on the mask element \p cdt at the
/// current insertion point, push it on the construct stack, and move the
/// insertion point to the start of its "then" block.
void OrderedAssignmentRewriter::generateMaskIfOp(mlir::Value cdt) {
  mlir::Location loc = cdt.getLoc();
  // Get the mask element as an i1 value.
  mlir::Value scalarMask =
      hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{cdt});
  mlir::Value condition =
      builder.createConvert(loc, builder.getI1Type(), scalarMask);
  auto maskIf = builder.create<fir::IfOp>(condition.getLoc(), std::nullopt,
                                          condition,
                                          /*withElseRegion=*/false);
  constructStack.push_back(maskIf.getOperation());
  builder.setInsertionPointToStart(&maskIf.getThenRegion().front());
}

/// Enter an hlfir.where. For the top-level WHERE, this creates the loop nest
/// iterating over the shape of the mask; in all cases, a fir.if testing the
/// current mask element is opened inside the where loops.
void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
  mlir::Location loc = whereOp.getLoc();
  mlir::Region &maskRegion = whereOp.getMaskRegion();

  if (whereLoopNest) {
    // The where loops have already been created by a parent WHERE: only
    // generate a fir.if with the value of the current element of the mask.
    // The case where the mask was saved is handled in the
    // generateYieldedScalarValue call.
    generateMaskIfOp(generateYieldedScalarValue(maskRegion));
    return;
  }

  // This is the top-level WHERE. Start a loop nest iterating on the shape of
  // the where mask, and continue generating code in the innermost loop.
  auto enterWhereLoops = [&](mlir::Value shape) {
    whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
    constructStack.push_back(whereLoopNest->outerLoop.getOperation());
    builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
  };

  std::optional<ValueAndCleanUp> saved = getIfSaved(maskRegion);
  if (!saved) {
    // The mask was not evaluated yet or can be safely re-evaluated.
    MaskedArrayExpr mask(loc, maskRegion, /*isOuterMaskExpr=*/true);
    mask.generateNoneElementalPart(builder, mapper);
    enterWhereLoops(mask.generateShape(builder, mapper));
    generateMaskIfOp(generateMaskedEntity(mask));
    return;
  }

  // Use the saved value to get the shape and condition element.
  hlfir::Entity savedMask{saved->first};
  enterWhereLoops(hlfir::genShape(loc, builder, savedMask));
  mlir::Value maskElement = hlfir::getElementAt(loc, builder, savedMask,
                                                whereLoopNest->oneBasedIndices);
  generateMaskIfOp(maskElement);
  if (saved->second) {
    // If this is the same run as the one that saved the value, the clean-up
    // was left-over to be done now, after the where loops.
    auto resumePoint = builder.saveInsertionPoint();
    builder.setInsertionPointAfter(whereLoopNest->outerLoop);
    generateCleanupIfAny(saved->second);
    builder.restoreInsertionPoint(resumePoint);
  }
}

/// Leave an hlfir.where: close its fir.if and, when the where loop is then on
/// top of the construct stack, close the where loop nest as well.
void OrderedAssignmentRewriter::post(hlfir::WhereOp whereOp) {
  assert(!constructStack.empty() && "must contain a fir.if");
  mlir::Operation *maskIf = constructStack.pop_back_val();
  builder.setInsertionPointAfter(maskIf);

  assert(!constructStack.empty() && "must contain a  fir.do_loop or fir.if");
  // A fir.if on top of the stack means an enclosing where/elsewhere is still
  // open: the where loop must only be exited by the outer whereOp.
  if (!mlir::isa<fir::DoLoopOp>(constructStack.back()))
    return;

  // All where/elsewhere fir.if have been popped: this is the outer whereOp.
  mlir::Operation *outerWhereLoop = constructStack.pop_back_val();
  builder.setInsertionPointAfter(outerWhereLoop);
  whereLoopNest.reset();
}

/// Move the insertion point into the "else" region of the fir.if on top of the
/// construct stack, creating that region (and its terminator) if it does not
/// exist yet.
void OrderedAssignmentRewriter::enterElsewhere(hlfir::ElseWhereOp elseWhereOp) {
  auto currentIf = mlir::dyn_cast<fir::IfOp>(constructStack.back());
  assert(currentIf && "must be an if");
  mlir::Region &elseRegion = currentIf.getElseRegion();

  if (!elseRegion.empty()) {
    // The else block already exists: insert before its last operation.
    builder.setInsertionPoint(&elseRegion.back().back());
    return;
  }

  // Create the else block with its terminator, and insert before it.
  mlir::Location loc = elseWhereOp.getLoc();
  builder.createBlock(&elseRegion);
  auto terminator = builder.create<fir::ResultOp>(loc);
  builder.setInsertionPoint(terminator);
}

/// Enter an hlfir.elsewhere: move into the else block of the current fir.if
/// and, if the elsewhere has a mask, open a nested fir.if on its current
/// element.
void OrderedAssignmentRewriter::pre(hlfir::ElseWhereOp elseWhereOp) {
  enterElsewhere(elseWhereOp);
  mlir::Region &maskRegion = elseWhereOp.getMaskRegion();
  if (!maskRegion.empty()) {
    // Create new nested fir.if with the elsewhere mask.
    generateMaskIfOp(generateYieldedScalarValue(maskRegion));
  }
}

/// Leave an hlfir.elsewhere: close the fir.if that was created for its mask,
/// if it has one.
void OrderedAssignmentRewriter::post(hlfir::ElseWhereOp elseWhereOp) {
  const bool hasMask = !elseWhereOp.getMaskRegion().empty();
  if (hasMask) {
    assert(!constructStack.empty() && "must contain a fir.if");
    mlir::Operation *maskIf = constructStack.pop_back_val();
    builder.setInsertionPointAfter(maskIf);
  }
}

/// Is \p value a Forall index?
/// Forall indices are either entry block arguments of an hlfir.forall, or the
/// result of an hlfir.forall_index.
static bool isForallIndex(mlir::Value value) {
  auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value);
  if (!blockArg)
    return mlir::isa_and_nonnull<hlfir::ForallIndexOp>(value.getDefiningOp());

  mlir::Block *owner = blockArg.getOwner();
  if (!owner)
    return false;
  return owner->isEntryBlock() &&
         mlir::isa_and_nonnull<hlfir::ForallOp>(owner->getParentOp());
}

/// Convert the value of \p valueAndCleanUp to \p castToType when a type is
/// provided; the clean-up is passed through unchanged.
static OrderedAssignmentRewriter::ValueAndCleanUp
castIfNeeded(mlir::Location loc, fir::FirOpBuilder &builder,
             OrderedAssignmentRewriter::ValueAndCleanUp valueAndCleanUp,
             std::optional<mlir::Type> castToType) {
  if (castToType.has_value()) {
    mlir::Value converted =
        builder.createConvert(loc, *castToType, valueAndCleanUp.first);
    return {converted, valueAndCleanUp.second};
  }
  return valueAndCleanUp;
}

/// Return the value saved for \p region, either earlier in the current run or
/// in a previous run, or std::nullopt if the region was not saved.
std::optional<OrderedAssignmentRewriter::ValueAndCleanUp>
OrderedAssignmentRewriter::getIfSaved(mlir::Region &region) {
  mlir::Location loc = region.getParentOp()->getLoc();

  // If the region was saved in the same run, use the value that was evaluated
  // instead of fetching the temp, and hand back the clean-up, if any, that was
  // delayed. This avoids requiring the temporary stack to have different
  // fetching and storing counters, and it also produces slightly better code.
  auto sameRunEntry = savedInCurrentRunBeforeUse.find(&region);
  if (sameRunEntry != savedInCurrentRunBeforeUse.end())
    return sameRunEntry->second;

  auto previousRunEntry = savedEntities.find(&region);
  if (previousRunEntry == savedEntities.end())
    return std::nullopt;

  // The region was saved in a previous run: fetch the saved value. The fetch
  // position is reset before the loop nest of the current run, if any.
  fir::factory::TemporaryStorage &storage = previousRunEntry->second;
  doBeforeLoopNest([&]() { storage.resetFetchPosition(loc, builder); });
  return ValueAndCleanUp{storage.fetch(loc, builder), std::nullopt};
}

/// Return the hlfir.yield that terminates the last block of \p region.
/// Asserts if the last operation is not an hlfir.yield.
static hlfir::YieldOp getYield(mlir::Region &region) {
  mlir::Operation &lastOp = region.back().back();
  auto yieldOp = mlir::dyn_cast<hlfir::YieldOp>(lastOp);
  assert(yieldOp && "region computing entities must end with a YieldOp");
  return yieldOp;
}

/// Produce the value of the entity yielded by \p region, optionally converted
/// to \p castToType.
///
/// The value is obtained, in order of preference:
///  - from a previous save of the region (see getIfSaved),
///  - from the masked expression code generation when a WHERE loop nest is
///    active,
///  - by cloning the operations of the region (all but the final hlfir.yield)
///    with the builder and the rewriter mapper.
///
/// The second member of the returned pair is the original hlfir.yield when
/// its clean-up region is not empty and was not generated here; it is
/// std::nullopt otherwise.
OrderedAssignmentRewriter::ValueAndCleanUp
OrderedAssignmentRewriter::generateYieldedEntity(
    mlir::Region &region, std::optional<mlir::Type> castToType) {
  mlir::Location loc = region.getParentOp()->getLoc();
  if (auto maybeValueAndCleanUp = getIfSaved(region))
    return castIfNeeded(loc, builder, *maybeValueAndCleanUp, castToType);
  // Otherwise, evaluate the region now.

  // Masked expression must not evaluate the elemental parts that are masked,
  // they have custom code generation.
  if (whereLoopNest.has_value()) {
    mlir::Value maskedValue = generateMaskedEntity(loc, region);
    return castIfNeeded(loc, builder, {maskedValue, std::nullopt}, castToType);
  }

  assert(region.hasOneBlock() && "region must contain one block");
  auto oldYield = getYield(region);
  mlir::Block::OpListType &ops = region.back().getOperations();

  // Inside Forall, scalars that do not depend on forall indices can be hoisted
  // here because their evaluation is required to only call pure procedures, and
  // if they depend on a variable previously assigned to in a forall assignment,
  // this assignment must have been scheduled in a previous run. Hoisting of
  // scalars is done here to help creating simple temporary storage if needed.
  // Inner forall bounds can often be hoisted, and this allows computing the
  // total number of iterations to create temporary storages.
  bool hoistComputation = false;
  if (fir::isa_trivial(oldYield.getEntity().getType()) &&
      !constructStack.empty()) {
    hoistComputation = true;
    // Only the direct operands of the operations in the region block are
    // inspected here.
    for (mlir::Operation &op : ops)
      if (llvm::any_of(op.getOperands(), [](mlir::Value value) {
            return isForallIndex(value);
          })) {
        hoistComputation = false;
        break;
      }
  }
  // Remember where code was being generated so that it can be resumed there
  // after a hoisted evaluation.
  auto insertionPoint = builder.saveInsertionPoint();
  if (hoistComputation)
    builder.setInsertionPoint(constructStack[0]);

  // Clone all operations except the final hlfir.yield.
  assert(!ops.empty() && "yield block cannot be empty");
  auto end = ops.end();
  for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt)
    (void)builder.clone(*opIt, mapper);
  // Get the value for the yielded entity, it may be the result of an operation
  // that was cloned, or it may be the same as the previous value if the yield
  // operand was created before the ordered assignment tree.
  mlir::Value newEntity = mapper.lookupOrDefault(oldYield.getEntity());
  if (castToType.has_value())
    newEntity =
        builder.createConvert(newEntity.getLoc(), *castToType, newEntity);

  if (hoistComputation) {
    // Hoisted trivial scalars clean-up can be done right away, the value is
    // in registers.
    generateCleanupIfAny(oldYield);
    builder.restoreInsertionPoint(insertionPoint);
    return {newEntity, std::nullopt};
  }
  // Only hand the yield back to the caller when there is a clean-up left to
  // generate.
  if (oldYield.getCleanup().empty())
    return {newEntity, std::nullopt};
  return {newEntity, oldYield};
}

/// Generate the entity yielded by \p region (optionally converted to
/// \p castToType), load it if needed so that a trivial scalar value is
/// returned, and generate its clean-up, if any, right after the load.
mlir::Value OrderedAssignmentRewriter::generateYieldedScalarValue(
    mlir::Region &region, std::optional<mlir::Type> castToType) {
  mlir::Location loc = region.getParentOp()->getLoc();
  ValueAndCleanUp yielded = generateYieldedEntity(region, castToType);
  mlir::Value scalar =
      hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{yielded.first});
  assert(fir::isa_trivial(scalar.getType()) && "not a trivial scalar value");
  // The value was loaded above, so the clean-up can be emitted now.
  generateCleanupIfAny(yielded.second);
  return scalar;
}

/// Generate the left-hand side designated by \p lhsRegion and return it along
/// with the information needed to finish the assignment:
///  - `lhs`: the LHS value (fetched from a temporary if the region was saved,
///    otherwise obtained by cloning the region),
///  - `vectorSubscriptShape`/`vectorSubscriptLoopNest`: set when the region
///    ends with an hlfir.elemental_addr and no WHERE loop nest is active, in
///    which case a new loop nest is created here and the builder insertion
///    point is left at the start of its inner loop body,
///  - `elementalCleanup`/`nonElementalCleanup`: clean-ups that the caller
///    still has to generate.
///
/// \p loweredRhs, when provided and an array, is used to get the shape of
/// the elemental loops for a saved vector subscripted LHS.
OrderedAssignmentRewriter::LhsValueAndCleanUp
OrderedAssignmentRewriter::generateYieldedLHS(
    mlir::Location loc, mlir::Region &lhsRegion,
    std::optional<hlfir::Entity> loweredRhs) {
  LhsValueAndCleanUp loweredLhs;
  // Null unless the region terminator is an hlfir.elemental_addr (vector
  // subscripted designator).
  hlfir::ElementalAddrOp elementalAddrLhs =
      mlir::dyn_cast<hlfir::ElementalAddrOp>(lhsRegion.back().back());
  if (auto temp = savedEntities.find(&lhsRegion); temp != savedEntities.end()) {
    // The LHS address was computed and saved in a previous run. Fetch it.
    doBeforeLoopNest([&]() { temp->second.resetFetchPosition(loc, builder); });
    if (elementalAddrLhs && !whereLoopNest) {
      // Vector subscripted designator address are saved element by element.
      // If no "elemental" loops have been created yet, the shape of the
      // RHS, if it is an array can be used, or the shape of the vector
      // subscripted designator must be retrieved to generate the "elemental"
      // loop nest.
      if (loweredRhs && loweredRhs->isArray()) {
        // The RHS shape can be used to create the elemental loops and avoid
        // saving the LHS shape.
        loweredLhs.vectorSubscriptShape =
            hlfir::genShape(loc, builder, *loweredRhs);
      } else {
        // If the shape cannot be retrieved from the RHS, it must have been
        // saved. Get it from the temporary.
        auto &vectorTmp =
            temp->second.cast<fir::factory::AnyVectorSubscriptStack>();
        loweredLhs.vectorSubscriptShape = vectorTmp.fetchShape(loc, builder);
      }
      loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
          loc, builder, loweredLhs.vectorSubscriptShape.value());
      // The element addresses are fetched inside the new loop nest.
      builder.setInsertionPointToStart(
          loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
    }
    loweredLhs.lhs = temp->second.fetch(loc, builder);
    return loweredLhs;
  }
  // The LHS has not yet been evaluated and saved. Evaluate it now.
  if (elementalAddrLhs && !whereLoopNest) {
    // This is a vector subscripted entity. The address of elements must
    // be returned. If no "elemental" loops have been created for a WHERE,
    // create them now based on the vector subscripted designator shape.
    // First clone what precedes the hlfir.elemental_addr in the region.
    for (auto &op : lhsRegion.front().without_terminator())
      (void)builder.clone(op, mapper);
    loweredLhs.vectorSubscriptShape =
        mapper.lookupOrDefault(elementalAddrLhs.getShape());
    loweredLhs.vectorSubscriptLoopNest =
        hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
                           !elementalAddrLhs.isOrdered());
    builder.setInsertionPointToStart(
        loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
    // Map the hlfir.elemental_addr indices to the new loop indices before
    // cloning its body inside the loops.
    mapper.map(elementalAddrLhs.getIndices(),
               loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
    for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
      (void)builder.clone(op, mapper);
    loweredLhs.elementalCleanup = elementalAddrLhs.getYieldOp();
    loweredLhs.lhs =
        mapper.lookupOrDefault(loweredLhs.elementalCleanup->getEntity());
  } else {
    // This is a designator without vector subscripts. Generate it as
    // it is done for other entities.
    auto [lhs, yield] = generateYieldedEntity(lhsRegion);
    loweredLhs.lhs = lhs;
    if (yield && !yield->getCleanup().empty())
      loweredLhs.nonElementalCleanup = &yield->getCleanup();
  }
  return loweredLhs;
}

/// Generate the value of the current element of \p maskedExpr inside the
/// active WHERE loop nest and return it.
///
/// The non elemental part of the expression is generated before the outer
/// WHERE loop the first time the expression is visited, and the related
/// clean-up generation is requested after that loop. On return, the builder
/// insertion point is inside the loops; when clean-ups were generated there
/// for a non outer mask expression, it is set before the first of them.
mlir::Value
OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
  assert(whereLoopNest.has_value() && "must be inside WHERE loop nest");
  auto insertionPoint = builder.saveInsertionPoint();
  if (!maskedExpr.noneElementalPartWasGenerated) {
    // Generate none elemental part before the where loops (but inside the
    // current forall loops if any).
    builder.setInsertionPoint(whereLoopNest->outerLoop);
    maskedExpr.generateNoneElementalPart(builder, mapper);
  }
  // Generate the none elemental part cleanup after the where loops.
  builder.setInsertionPointAfter(whereLoopNest->outerLoop);
  maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
  // Generate the value of the current element for the masked expression
  // at the current insertion point (inside the where loops, and any fir.if
  // generated for previous masks).
  builder.restoreInsertionPoint(insertionPoint);
  mlir::Value scalar = maskedExpr.generateElementalParts(
      builder, whereLoopNest->oneBasedIndices, mapper);
  // Generate cleanups for the elemental parts inside the loops (setting the
  // location so that the assignment will be generated before the cleanups).
  if (!maskedExpr.isOuterMaskExpr)
    if (mlir::Operation *firstCleanup =
            maskedExpr.generateMaskedExprCleanUps(builder, mapper))
      builder.setInsertionPoint(firstCleanup);
  return scalar;
}

/// Generate the clean-up region of \p maybeYield, if a yield is provided.
void OrderedAssignmentRewriter::generateCleanupIfAny(
    std::optional<hlfir::YieldOp> maybeYield) {
  if (!maybeYield)
    return;
  generateCleanupIfAny(&maybeYield->getCleanup());
}
/// Clone the operations of \p cleanupRegion (except its terminator) at the
/// current insertion point. Nothing is done for a null or empty region.
void OrderedAssignmentRewriter::generateCleanupIfAny(
    mlir::Region *cleanupRegion) {
  if (!cleanupRegion || cleanupRegion->empty())
    return;
  assert(cleanupRegion->hasOneBlock() && "region must contain one block");
  mlir::Block &cleanupBlock = cleanupRegion->back();
  for (mlir::Operation &cleanupOp : cleanupBlock.without_terminator())
    builder.clone(cleanupOp, mapper);
}

/// Append to \p saveEntities the SaveEntity actions of the current run whose
/// region belongs to \p node. Return true if \p saveEntities is not empty
/// afterwards.
bool OrderedAssignmentRewriter::mustSaveRegionIn(
    hlfir::OrderedAssignmentTreeOpInterface node,
    llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const {
  mlir::Operation *nodeOp = node.getOperation();
  for (auto &action : currentRun->actions) {
    auto *candidate = std::get_if<hlfir::SaveEntity>(&action);
    if (candidate && nodeOp == candidate->yieldRegion->getParentOp())
      saveEntities.push_back(*candidate);
  }
  return !saveEntities.empty();
}

/// Tell whether \p node must be visited to generate the actions of the
/// current run.
bool OrderedAssignmentRewriter::isRequiredInCurrentRun(
    hlfir::OrderedAssignmentTreeOpInterface node) const {
  // hlfir.forall_index do not contain saved regions/assignments, but they
  // are required whenever their hlfir.forall parent is (the forall indices
  // need to be mapped).
  if (mlir::isa<hlfir::ForallIndexOp>(node))
    return true;
  auto actionRequiresNode = [&](auto &action) -> bool {
    if (auto *savedEntity = std::get_if<hlfir::SaveEntity>(&action)) {
      // Saving an entity does not require evaluating the node that holds the
      // saved region, only the parents of that node. For instance, saving a
      // bound of hlfir.forall B does not require creating the loops for B,
      // but it requires creating the loops of any forall parent A of B.
      return node->isProperAncestor(savedEntity->yieldRegion->getParentOp());
    }
    // Assignments require every node that contains them, including
    // themselves.
    auto assign = std::get<hlfir::RegionAssignOp>(action);
    return node->isAncestor(assign.getOperation());
  };
  return llvm::any_of(currentRun->actions, actionRequiresNode);
}

/// Return true if \p apply is indexed with exactly the indices of
/// \p elemental, in the same order.
static bool isInOrderApply(hlfir::ApplyOp apply,
                           hlfir::ElementalOpInterface elemental) {
  mlir::Region::BlockArgListType elementalIndices = elemental.getIndices();
  auto applyIndices = apply.getIndices();
  if (elementalIndices.size() != applyIndices.size())
    return false;
  return llvm::all_of(llvm::zip(elementalIndices, applyIndices),
                      [](auto indexPair) {
                        return std::get<0>(indexPair) == std::get<1>(indexPair);
                      });
}

/// Collect in \p elementalOps the hlfir::ElementalOpInterface operations
/// reachable from \p elemental through the hlfir.apply found directly in the
/// elemental regions. \p elemental may be null, in which case nothing is
/// done. \p isOutOfOrder tells that \p elemental is applied with indices that
/// are not the enclosing elemental indices in order.
static void
gatherElementalTree(hlfir::ElementalOpInterface elemental,
                    llvm::SmallPtrSetImpl<mlir::Operation *> &elementalOps,
                    bool isOutOfOrder) {
  if (!elemental)
    return;
  // An applied elemental that must be executed in order can only be inlined
  // if the applying indices are in order. An hlfir::Elemental may have been
  // created for a transformational like transpose, and Fortran 2018 standard
  // section 10.2.3.2, point 10 implies that impure elemental sub-expression
  // evaluations should not be masked if they are the arguments of
  // transformational expressions.
  if (isOutOfOrder && elemental.isOrdered())
    return;
  elementalOps.insert(elemental.getOperation());
  for (mlir::Operation &nestedOp : elemental.getElementalRegion().getOps()) {
    auto apply = mlir::dyn_cast<hlfir::ApplyOp>(nestedOp);
    if (!apply)
      continue;
    const bool appliedOutOfOrder =
        isOutOfOrder || !isInOrderApply(apply, elemental);
    auto appliedElemental =
        mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
            apply.getExpr().getDefiningOp());
    gatherElementalTree(appliedElemental, elementalOps, appliedOutOfOrder);
  }
}

/// Build a masked expression helper for \p region and gather the elemental
/// operations that can be inlined inside the WHERE loops.
MaskedArrayExpr::MaskedArrayExpr(mlir::Location loc, mlir::Region &region,
                                 bool isOuterMaskExpr)
    : loc{loc}, region{region}, isOuterMaskExpr{isOuterMaskExpr} {
  mlir::Operation &terminator = region.back().back();
  // The root of the elemental tree is the terminator itself for a vector
  // subscripted designator (hlfir.elemental_addr terminator)...
  auto rootElemental = mlir::dyn_cast<hlfir::ElementalOpInterface>(terminator);
  if (!rootElemental) {
    // ... otherwise, it is the operation defining the yielded entity, if it
    // is an elemental expression.
    mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
    rootElemental = mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
        entity.getDefiningOp());
  }
  gatherElementalTree(rootElemental, elementalParts, /*isOutOfOrder=*/false);
}

/// Generate, at the builder insertion point, the part of the expression that
/// is evaluated outside of the WHERE loops. Must be called only once.
void MaskedArrayExpr::generateNoneElementalPart(fir::FirOpBuilder &builder,
                                                mlir::IRMapping &mapper) {
  assert(!noneElementalPartWasGenerated &&
         "none elemental parts already generated");
  mlir::Block &exprBlock = region.back();
  for (mlir::Operation &op : exprBlock.without_terminator()) {
    if (isOuterMaskExpr) {
      // The outer mask expression is not actually masked; it is handled as a
      // masked expression only so that its elemental part, if any, can be
      // inlined in the WHERE loops. Everything outside of hlfir.elemental/
      // hlfir.elemental_addr must be emitted now because the values may be
      // needed to deduce the mask shape and the WHERE loop bounds.
      if (!elementalParts.contains(&op))
        (void)builder.clone(op, mapper);
      continue;
    }
    // For actual masked expressions, Fortran requires elemental expressions,
    // even the scalar ones that are not encoded with hlfir.elemental, to be
    // evaluated only when the mask is true. Blindly hoisting all the scalar
    // SSA tree could be wrong if the scalar computation has side effects and
    // would never have been evaluated (e.g. division by zero) if the mask is
    // fully false. See F'2023 10.2.3.2 point 10.
    // Only the bodies of hlfir.exactly_once operations are cloned here. They
    // contain the evaluation of sub-expression trees whose root was a non
    // elemental function call at the Fortran level (the call itself may have
    // been inlined since). These must be evaluated only once as per F'2023
    // 10.2.3.2 point 9.
    auto exactlyOnce = mlir::dyn_cast<hlfir::ExactlyOnceOp>(op);
    if (!exactlyOnce)
      continue;
    mlir::Region &onceBody = exactlyOnce.getBody();
    for (mlir::Operation &onceOp : onceBody.back().without_terminator())
      (void)builder.clone(onceOp, mapper);
    // Later uses of the hlfir.exactly_once result must see the cloned value.
    mlir::Value yieldedOnce = getYield(onceBody).getEntity();
    mapper.map(exactlyOnce.getResult(), mapper.lookupOrDefault(yieldedOnce));
  }
  noneElementalPartWasGenerated = true;
}

/// Return the shape of the masked expression. The non elemental part must
/// have been generated before.
mlir::Value MaskedArrayExpr::generateShape(fir::FirOpBuilder &builder,
                                           mlir::IRMapping &mapper) {
  assert(noneElementalPartWasGenerated &&
         "non elemental part must have been generated");
  mlir::Operation &terminator = region.back().back();
  // Elemental operations are not cloned, but their shape operand is. Return
  // the clone of that shape when the entity is produced by one of them.
  if (auto elementalAddr = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator))
    return mapper.lookupOrDefault(elementalAddr.getShape());
  mlir::Value yielded = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
  if (auto elementalOp = yielded.getDefiningOp<hlfir::ElementalOp>())
    return mapper.lookupOrDefault(elementalOp.getShape());
  // In the other cases, the whole entity was cloned and the shape is
  // generated from the clone.
  hlfir::Entity clone{mapper.lookupOrDefault(yielded)};
  return hlfir::genShape(loc, builder, clone);
}

/// Generate, at the builder insertion point, the evaluation of the element
/// of the masked expression designated by \p oneBasedIndices, and return it.
/// The non elemental part must have been generated before.
mlir::Value
MaskedArrayExpr::generateElementalParts(fir::FirOpBuilder &builder,
                                        mlir::ValueRange oneBasedIndices,
                                        mlir::IRMapping &mapper) {
  assert(noneElementalPartWasGenerated &&
         "non elemental part must have been generated");
  mlir::Block &exprBlock = region.back();
  // For the outer mask, the operations below were already cloned outside of
  // the loops. Otherwise, clone everything but the hlfir.exactly_once and
  // the hlfir.elemental/hlfir.elemental_addr operations.
  if (!isOuterMaskExpr)
    for (mlir::Operation &op : exprBlock.without_terminator()) {
      if (mlir::isa<hlfir::ExactlyOnceOp>(op) || elementalParts.contains(&op))
        continue;
      (void)builder.clone(op, mapper);
    }

  // Find the elemental operation to "index", if any: either the
  // hlfir.elemental_addr terminator, or the hlfir.elemental producing the
  // yielded entity.
  mlir::Operation &terminator = exprBlock.back();
  hlfir::ElementalOpInterface elemental =
      mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator);
  if (!elemental) {
    mlir::Value yielded = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
    elemental = yielded.getDefiningOp<hlfir::ElementalOp>();
    if (!elemental) {
      // No elemental operation is involved: address the clone of the entity
      // made while evaluating the non elemental part.
      hlfir::Entity clone{mapper.lookupOrDefault(yielded)};
      return hlfir::getElementAt(loc, builder, clone, oneBasedIndices);
    }
  }

  // Applied elementals are inlined recursively only if they were gathered
  // in the elemental tree of this expression.
  auto isGatheredElemental = [&](hlfir::ElementalOp applied) -> bool {
    return elementalParts.contains(applied.getOperation());
  };
  return inlineElementalOp(loc, builder, elemental, oneBasedIndices, mapper,
                           isGatheredElemental);
}

/// Clone the clean-up operations of the region terminator at the builder
/// insertion point, skipping the hlfir.destroy of the elementals that are
/// part of the inlined elemental tree. Return the first cloned operation, or
/// nullptr if nothing was cloned.
mlir::Operation *
MaskedArrayExpr::generateMaskedExprCleanUps(fir::FirOpBuilder &builder,
                                            mlir::IRMapping &mapper) {
  mlir::Operation &terminator = region.back().back();
  auto elementalAddr = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator);
  mlir::Region &cleanupRegion =
      elementalAddr ? elementalAddr.getCleanup()
                    : mlir::cast<hlfir::YieldOp>(terminator).getCleanup();
  if (cleanupRegion.empty())
    return nullptr;

  auto destroysInlinedElemental = [&](mlir::Operation &cleanupOp) -> bool {
    auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(cleanupOp);
    return destroy &&
           elementalParts.contains(destroy.getExpr().getDefiningOp());
  };

  mlir::Operation *firstClone = nullptr;
  for (mlir::Operation &cleanupOp :
       cleanupRegion.front().without_terminator()) {
    if (destroysInlinedElemental(cleanupOp))
      continue;
    mlir::Operation *clone = builder.clone(cleanupOp, mapper);
    if (!firstClone)
      firstClone = clone;
  }
  return firstClone;
}

/// Generate, at the builder insertion point, the clean-ups related to the
/// part of the expression that was evaluated outside of the WHERE loops.
void MaskedArrayExpr::generateNoneElementalCleanupIfAny(
    fir::FirOpBuilder &builder, mlir::IRMapping &mapper) {
  if (isOuterMaskExpr) {
    // The non hlfir.elemental part of the outer mask is generated before
    // the loops, so the region clean-ups must be generated outside of the
    // loops too.
    generateMaskedExprCleanUps(builder, mapper);
    return;
  }
  // Clone the clean-ups of the hlfir.exactly_once operations, visiting them
  // in reverse order to properly deal with stack restores.
  for (mlir::Operation &op :
       llvm::reverse(region.back().without_terminator())) {
    auto exactlyOnce = mlir::dyn_cast<hlfir::ExactlyOnceOp>(op);
    if (!exactlyOnce)
      continue;
    mlir::Region &onceCleanup = getYield(exactlyOnce.getBody()).getCleanup();
    if (onceCleanup.empty())
      continue;
    for (mlir::Operation &cleanupOp :
         onceCleanup.front().without_terminator())
      (void)builder.clone(cleanupOp, mapper);
  }
}

/// Return the hlfir.region_assign owning \p region if \p region is its
/// left-hand side region, and a null operation otherwise.
static hlfir::RegionAssignOp
getAssignIfLeftHandSideRegion(mlir::Region &region) {
  auto parentAssign =
      mlir::dyn_cast<hlfir::RegionAssignOp>(region.getParentOp());
  if (!parentAssign || &parentAssign.getLhsRegion() != &region)
    return nullptr;
  return parentAssign;
}

/// Tell whether the operands of all the fir.do_loop between the innermost
/// and the outermost current constructs properly dominate the outermost
/// construct. The visited loops are appended to \p loopNest, innermost
/// first. When false is returned, \p loopNest may be partially filled.
bool OrderedAssignmentRewriter::currentLoopNestIterationNumberCanBeComputed(
    llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest) {
  if (constructStack.empty())
    return true;
  mlir::Operation *outermost = constructStack[0];
  auto dominatesOutermost = [&](mlir::Value value) {
    return dominanceInfo.properlyDominates(value, outermost);
  };
  // Walk up the parents from the innermost construct until the outermost
  // one has been visited.
  for (mlir::Operation *construct = constructStack.back(); construct;
       construct = construct->getParentOp()) {
    if (auto doLoop = mlir::dyn_cast<fir::DoLoopOp>(construct)) {
      if (!llvm::all_of(doLoop->getOperands(), dominatesOutermost))
        return false;
      loopNest.push_back(doLoop);
    }
    if (construct == outermost)
      break;
  }
  return true;
}

/// Generate the product of the extents of all the loops in \p loopNest, as
/// an index value. \p loopNest must not be empty.
static mlir::Value
computeLoopNestIterationNumber(mlir::Location loc, fir::FirOpBuilder &builder,
                               llvm::ArrayRef<fir::DoLoopOp> loopNest) {
  mlir::Value totalExtent;
  for (fir::DoLoopOp loop : loopNest) {
    mlir::Value product = builder.genExtentFromTriplet(
        loc, loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
        builder.getIndexType());
    // Accumulate with the extents of the previously visited loops, if any.
    if (totalExtent)
      product = builder.create<mlir::arith::MulIOp>(loc, totalExtent, product);
    totalExtent = product;
  }
  assert(totalExtent && "loopNest must not be empty");
  return totalExtent;
}

/// Select the name given to temporary storage according to the kind of the
/// ordered assignment tree \p root in which the storage is created.
static llvm::StringRef
getTempName(hlfir::OrderedAssignmentTreeOpInterface root) {
  return llvm::TypeSwitch<mlir::Operation *, llvm::StringRef>(
             root.getOperation())
      .Case<hlfir::ForallOp>(
          [](auto) -> llvm::StringRef { return ".tmp.forall"; })
      .Case<hlfir::WhereOp>(
          [](auto) -> llvm::StringRef { return ".tmp.where"; })
      .Default(
          [](mlir::Operation *) -> llvm::StringRef { return ".tmp.assign"; });
}

/// Evaluate the region of \p savedEntity and save the result in a temporary
/// storage registered with insertSavedEntity.
///
/// Left-hand side regions of hlfir.region_assign are delegated to
/// saveLeftHandSide. For other regions, the kind of temporary depends on the
/// context: a simple copy when no construct is active, a rank-1 array when
/// the value is a trivial scalar and the loop nest iteration count can be
/// computed before the loops, and a runtime managed stack otherwise.
///
/// \p willUseSavedEntityInSameRun tells that the saved value will be used
/// later in the current run, in which case the evaluated value and its
/// clean-up may be recorded in savedInCurrentRunBeforeUse instead of
/// generating the clean-up here.
void OrderedAssignmentRewriter::generateSaveEntity(
    hlfir::SaveEntity savedEntity, bool willUseSavedEntityInSameRun) {
  mlir::Region &region = *savedEntity.yieldRegion;

  if (hlfir::RegionAssignOp regionAssignOp =
          getAssignIfLeftHandSideRegion(region)) {
    // Need to save the address, not the values.
    assert(!willUseSavedEntityInSameRun &&
           "lhs cannot be used in the loop nest where it is saved");
    return saveLeftHandSide(savedEntity, regionAssignOp);
  }

  mlir::Location loc = region.getParentOp()->getLoc();
  // Evaluate the region inside the loop nest (if any).
  auto [clonedValue, oldYield] = generateYieldedEntity(region);
  hlfir::Entity entity{clonedValue};
  entity = hlfir::loadTrivialScalar(loc, builder, entity);
  mlir::Type entityType = entity.getType();

  llvm::StringRef tempName = getTempName(root);
  fir::factory::TemporaryStorage *temp = nullptr;
  if (constructStack.empty()) {
    // Value evaluated outside of any loops (this may be the first MASK of a
    // WHERE construct, or an LHS/RHS temp of hlfir.region_assign outside of
    // WHERE/FORALL).
    temp = insertSavedEntity(
        region, fir::factory::SimpleCopy(loc, builder, entity, tempName));
  } else {
    // Need to create a temporary for values computed inside loops.
    // Create temporary storage outside of the loop nest given the entity
    // type (and the loop context).
    llvm::SmallVector<fir::DoLoopOp> loopNest;
    bool loopShapeCanBePreComputed =
        currentLoopNestIterationNumberCanBeComputed(loopNest);
    // `temp` is set by the callback below and used right after it.
    doBeforeLoopNest([&] {
      // For simple scalars inside loops whose total iteration number can be
      // pre-computed, create a rank-1 array outside of the loops. It will be
      // assigned/fetched inside the loops like a normal Fortran array given
      // the iteration count.
      if (loopShapeCanBePreComputed && fir::isa_trivial(entityType)) {
        mlir::Value loopExtent =
            computeLoopNestIterationNumber(loc, builder, loopNest);
        auto sequenceType =
            mlir::cast<fir::SequenceType>(builder.getVarLenSeqTy(entityType));
        temp = insertSavedEntity(region,
                                 fir::factory::HomogeneousScalarStack{
                                     loc, builder, sequenceType, loopExtent,
                                     /*lenParams=*/{}, allocateOnHeap,
                                     /*stackThroughLoops=*/true, tempName});

      } else {
        // If the number of iteration is not known, or if the values at each
        // iterations are values that may have different shape, type parameters
        // or dynamic type, use the runtime to create and manage a stack-like
        // temporary.
        temp = insertSavedEntity(
            region, fir::factory::AnyValueStack{loc, builder, entityType});
      }
    });
    // Inside the loop nest (and any fir.if if there are active masks), copy
    // the value to the temp and do clean-ups for the value if any.
    temp->pushValue(loc, builder, entity);
  }

  // Delay the clean-up if the entity will be used in the same run (i.e., the
  // parent construct will be visited and needs to be lowered). When possible,
  // this is not done for hlfir.expr because this use would prevent the
  // hlfir.expr storage from being moved when creating the temporary in
  // bufferization, and that would lead to an extra copy.
  if (willUseSavedEntityInSameRun &&
      (!temp->canBeFetchedAfterPush() ||
       !mlir::isa<hlfir::ExprType>(entity.getType()))) {
    auto inserted =
        savedInCurrentRunBeforeUse.try_emplace(&region, entity, oldYield);
    assert(inserted.second && "entity must have been emplaced");
    // Avoid an unused variable warning when assertions are disabled.
    (void)inserted;
  } else {
    if (constructStack.empty() &&
        mlir::isa<hlfir::RegionAssignOp>(region.getParentOp())) {
      // Here the clean-up code is inserted after the original
      // RegionAssignOp, so that the assignment code happens
      // before the cleanup. We do this only for standalone
      // operations, because the clean-up is handled specially
      // during lowering of the parent constructs if any
      // (e.g. see generateNoneElementalCleanupIfAny for
      // WhereOp).
      auto insertionPoint = builder.saveInsertionPoint();
      builder.setInsertionPointAfter(region.getParentOp());
      generateCleanupIfAny(oldYield);
      builder.restoreInsertionPoint(insertionPoint);
    } else {
      generateCleanupIfAny(oldYield);
    }
  }
}

/// Tell whether the right-hand side region of \p regionAssignOp ends with an
/// hlfir.yield of an array entity.
static bool rhsIsArray(hlfir::RegionAssignOp regionAssignOp) {
  mlir::Operation &rhsTerminator = regionAssignOp.getRhsRegion().back().back();
  if (auto yieldOp = mlir::dyn_cast<hlfir::YieldOp>(rhsTerminator))
    return hlfir::Entity{yieldOp.getEntity()}.isArray();
  return false;
}

/// Evaluate the left-hand side region of \p regionAssignOp designated by
/// \p savedEntity and save its address (not its value) in a temporary
/// registered with insertSavedEntity.
///
/// When the LHS is a vector subscripted designator and the RHS is not an
/// array, the designator shape is saved too, since it is needed to create
/// the elemental loops when the LHS is fetched. In the other cases, only the
/// address is saved: in an SSA register if it dominates the outermost
/// construct (or if there is none), and in a runtime managed stack otherwise.
void OrderedAssignmentRewriter::saveLeftHandSide(
    hlfir::SaveEntity savedEntity, hlfir::RegionAssignOp regionAssignOp) {
  mlir::Region &region = *savedEntity.yieldRegion;
  mlir::Location loc = region.getParentOp()->getLoc();
  LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
  fir::factory::TemporaryStorage *temp = nullptr;
  // Register the elemental loop nest created for a vector subscripted LHS as
  // a construct until the end of this function (it is popped below).
  if (loweredLhs.vectorSubscriptLoopNest)
    constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
  if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
    // Vector subscripted entity for which the shape must also be saved on top
    // of the element addresses (e.g. the shape may change in each forall
    // iteration and is needed to create the elemental loops).
    mlir::Value shape = loweredLhs.vectorSubscriptShape.value();
    int rank = mlir::cast<fir::ShapeType>(shape.getType()).getRank();
    const bool shapeIsInvariant =
        constructStack.empty() ||
        dominanceInfo.properlyDominates(shape, constructStack[0]);
    doBeforeLoopNest([&] {
      // Outside of any forall/where/elemental loops, create a temporary that
      // will both be able to save the vector subscripted designator shape(s)
      // and element addresses.
      temp =
          insertSavedEntity(region, fir::factory::AnyVectorSubscriptStack{
                                        loc, builder, loweredLhs.lhs.getType(),
                                        shapeIsInvariant, rank});
    });
    // Save shape before the elemental loop nest created by the vector
    // subscripted LHS.
    auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
    auto insertionPoint = builder.saveInsertionPoint();
    builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
    vectorTmp.pushShape(loc, builder, shape);
    builder.restoreInsertionPoint(insertionPoint);
  } else {
    // Otherwise, only save the LHS address.
    // If the LHS address dominates the constructs, its SSA value can
    // simply be tracked and there is no need to save the address in memory.
    // Otherwise, the addresses are stored at each iteration in memory with
    // a descriptor stack.
    if (constructStack.empty() ||
        dominanceInfo.properlyDominates(loweredLhs.lhs, constructStack[0]))
      doBeforeLoopNest([&] {
        temp = insertSavedEntity(region, fir::factory::SSARegister{});
      });
    else
      doBeforeLoopNest([&] {
        temp = insertSavedEntity(
            region, fir::factory::AnyVariableStack{loc, builder,
                                                   loweredLhs.lhs.getType()});
      });
  }
  // Save the address at the current insertion point, and generate the
  // clean-up of the vector subscripted designator, if any, right after.
  temp->pushValue(loc, builder, loweredLhs.lhs);
  generateCleanupIfAny(loweredLhs.elementalCleanup);
  if (loweredLhs.vectorSubscriptLoopNest) {
    // Leave the elemental loop nest: undo the push made above and continue
    // generating code after the loops.
    constructStack.pop_back();
    builder.setInsertionPointAfter(
        loweredLhs.vectorSubscriptLoopNest->outerLoop);
  }
}

/// Generate the fir.do_loop and hlfir.assign implementing the ordered
/// assignment tree \p root by lowering each run of \p schedule in order, and
/// then clean up the saved entities.
static void lower(hlfir::OrderedAssignmentTreeOpInterface root,
                  mlir::PatternRewriter &rewriter, hlfir::Schedule &schedule) {
  fir::FirOpBuilder builder(rewriter, root->getParentOfType<mlir::ModuleOp>());
  OrderedAssignmentRewriter treeRewriter(builder, root);
  for (auto &scheduledRun : schedule) {
    treeRewriter.lowerRun(scheduledRun);
  }
  treeRewriter.cleanupSavedEntities();
}

/// Common rewrite entry point of all the operations that can be the root of
/// an ordered assignment tree. It builds the evaluation schedule of \p root,
/// applies it, and erases \p root. It always returns success.
static llvm::LogicalResult rewrite(hlfir::OrderedAssignmentTreeOpInterface root,
                                   bool tryFusingAssignments,
                                   mlir::PatternRewriter &rewriter) {
  hlfir::Schedule schedule =
      hlfir::buildEvaluationSchedule(root, tryFusingAssignments);

  // Debug option to print the scheduling debug info without doing any code
  // generation. The operations are simply erased to avoid failing and calling
  // the rewrite patterns on nested operations. The only purpose of this is to
  // help testing scheduling without having to test generated code.
  bool scheduleOnly = false;
  LLVM_DEBUG(scheduleOnly = dbgScheduleOnly);
  if (!scheduleOnly)
    lower(root, rewriter, schedule);
  rewriter.eraseOp(root);
  return mlir::success();
}

namespace {

/// Pattern lowering the ordered assignment trees rooted at an hlfir.forall.
class ForallOpConversion : public mlir::OpRewritePattern<hlfir::ForallOp> {
public:
  explicit ForallOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments)
      : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {}

  llvm::LogicalResult
  matchAndRewrite(hlfir::ForallOp forallOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto treeRoot = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
        forallOp.getOperation());
    llvm::LogicalResult status =
        ::rewrite(treeRoot, tryFusingAssignments, rewriter);
    // A failure to rewrite is reported as a not yet implemented feature.
    if (mlir::failed(status))
      TODO(forallOp.getLoc(), "FORALL construct or statement in HLFIR");
    return mlir::success();
  }

  /// Value forwarded to the shared rewrite entry point.
  const bool tryFusingAssignments;
};

/// Rewrite pattern for hlfir::WhereOp: the matched operation is used as the
/// root of an ordered assignment tree and lowered by the shared ::rewrite
/// entry point.
class WhereOpConversion : public mlir::OpRewritePattern<hlfir::WhereOp> {
public:
  /// \p tryFusingAssignments is forwarded to ::rewrite on each match.
  explicit WhereOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments)
      : OpRewritePattern(ctx), tryFusingAssignments(tryFusingAssignments) {}

  llvm::LogicalResult
  matchAndRewrite(hlfir::WhereOp whereOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Operation *whereOperation = whereOp.getOperation();
    return ::rewrite(
        mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(whereOperation),
        tryFusingAssignments, rewriter);
  }

  /// Value forwarded to the scheduler via ::rewrite.
  const bool tryFusingAssignments;
};

/// Rewrite pattern for hlfir::RegionAssignOp: the matched operation is used
/// as the root of an ordered assignment tree and lowered by the shared
/// ::rewrite entry point, with assignment fusing always disabled.
class RegionAssignConversion
    : public mlir::OpRewritePattern<hlfir::RegionAssignOp> {
public:
  explicit RegionAssignConversion(mlir::MLIRContext *ctx)
      : OpRewritePattern(ctx) {}

  llvm::LogicalResult
  matchAndRewrite(hlfir::RegionAssignOp regionAssignOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Operation *assignOperation = regionAssignOp.getOperation();
    return ::rewrite(
        mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(assignOperation),
        /*tryFusingAssignments=*/false, rewriter);
  }
};

/// Pass driver: registers the rewrite patterns for the operations that can be
/// the root of an ordered assignment tree and applies them as a partial
/// conversion over the operation the pass runs on.
class LowerHLFIROrderedAssignments
    : public hlfir::impl::LowerHLFIROrderedAssignmentsBase<
          LowerHLFIROrderedAssignments> {
public:
  using LowerHLFIROrderedAssignmentsBase<
      LowerHLFIROrderedAssignments>::LowerHLFIROrderedAssignmentsBase;

  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();
    // The pass is anchored on a ModuleOp because the lowering may generate
    // FuncOp declarations for runtime calls. It could be a FuncOp pass
    // otherwise.
    auto moduleOp = this->getOperation();

    // Only the OrderedAssignmentTreeOpInterface operations that can be the
    // root of ordered assignments get a pattern. The remaining ones are dealt
    // with while rewriting the trees, since their verifiers/traits prevent
    // them from existing outside of these roots.
    mlir::RewritePatternSet treeRootPatterns(ctx);
    const bool fuseAssignments = this->tryFusingAssignments.getValue();
    treeRootPatterns.insert<ForallOpConversion, WhereOpConversion>(
        ctx, fuseAssignments);
    treeRootPatterns.insert<RegionAssignConversion>(ctx);

    // An operation is legal unless it implements
    // OrderedAssignmentTreeOpInterface.
    auto isNotOrderedAssignmentTreeOp = [](mlir::Operation *op) {
      return !mlir::isa<hlfir::OrderedAssignmentTreeOpInterface>(op);
    };
    mlir::ConversionTarget target(*ctx);
    target.markUnknownOpDynamicallyLegal(isNotOrderedAssignmentTreeOp);

    const llvm::LogicalResult conversionStatus = mlir::applyPartialConversion(
        moduleOp, target, std::move(treeRootPatterns));
    if (mlir::failed(conversionStatus)) {
      mlir::emitError(mlir::UnknownLoc::get(ctx),
                      "failure in HLFIR ordered assignments lowering pass");
      signalPassFailure();
    }
  }
};
} // namespace