llvm/flang/include/flang/Lower/IterationSpace.h

//===-- IterationSpace.h ----------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_LOWER_ITERATIONSPACE_H
#define FORTRAN_LOWER_ITERATIONSPACE_H

#include "flang/Evaluate/tools.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include <optional>
#include <utility>

// Forward declaration of the LLVM output stream; only references to it are
// used in the declarations of this header.
namespace llvm {
class raw_ostream;
} // namespace llvm

namespace Fortran {
namespace evaluate {
// Forward declarations of the front-end expression types named by the
// aliases below.
struct SomeType;
template <typename>
class Expr;
} // namespace evaluate

namespace lower {

/// Pointer to a (const) type-erased front-end expression.
using FrontEndExpr = const evaluate::Expr<evaluate::SomeType> *;
/// Pointer to a (const) front-end symbol.
using FrontEndSymbol = const semantics::Symbol *;

class AbstractConverter;

} // namespace lower
} // namespace Fortran

namespace Fortran::lower {

/// Abstraction of the iteration space for building the elemental compute loop
/// of an array(-like) statement.
class IterationSpace {
public:
  IterationSpace() = default;

  template <typename A>
  explicit IterationSpace(mlir::Value inArg, mlir::Value outRes,
                          llvm::iterator_range<A> range)
      : inArg{inArg}, outRes{outRes}, indices{range.begin(), range.end()} {}

  explicit IterationSpace(const IterationSpace &from,
                          llvm::ArrayRef<mlir::Value> idxs)
      : inArg(from.inArg), outRes(from.outRes), element(from.element),
        indices(idxs) {}

  /// Create a copy of the \p from IterationSpace and prepend the \p prefix
  /// values and append the \p suffix values, respectively.
  explicit IterationSpace(const IterationSpace &from,
                          llvm::ArrayRef<mlir::Value> prefix,
                          llvm::ArrayRef<mlir::Value> suffix)
      : inArg(from.inArg), outRes(from.outRes), element(from.element) {
    indices.assign(prefix.begin(), prefix.end());
    indices.append(from.indices.begin(), from.indices.end());
    indices.append(suffix.begin(), suffix.end());
  }

  bool empty() const { return indices.empty(); }

  /// This is the output value as it appears as an argument in the innermost
  /// loop in the nest. The output value is threaded through the loop (and
  /// conditionals) to maintain proper SSA form.
  mlir::Value innerArgument() const { return inArg; }

  /// This is the output value as it appears as an output value from the
  /// outermost loop in the loop nest. The output value is threaded through the
  /// loop (and conditionals) to maintain proper SSA form.
  mlir::Value outerResult() const { return outRes; }

  /// Returns a vector for the iteration space. This vector is used to access
  /// elements of arrays in the compute loop.
  llvm::SmallVector<mlir::Value> iterVec() const { return indices; }

  mlir::Value iterValue(std::size_t i) const {
    assert(i < indices.size());
    return indices[i];
  }

  /// Set (rewrite) the Value at a given index.
  void setIndexValue(std::size_t i, mlir::Value v) {
    assert(i < indices.size());
    indices[i] = v;
  }

  void setIndexValues(llvm::ArrayRef<mlir::Value> vals) {
    indices.assign(vals.begin(), vals.end());
  }

  void insertIndexValue(std::size_t i, mlir::Value av) {
    assert(i <= indices.size());
    indices.insert(indices.begin() + i, av);
  }

  /// Set the `element` value. This is the SSA value that corresponds to an
  /// element of the resultant array value.
  void setElement(fir::ExtendedValue &&ele) {
    assert(!fir::getBase(element) && "result element already set");
    element = ele;
  }

  /// Get the value that will be merged into the resultant array. This is the
  /// computed value that will be stored to the lhs of the assignment.
  mlir::Value getElement() const {
    assert(fir::getBase(element) && "element must be set");
    return fir::getBase(element);
  }

  /// Get the element as an extended value.
  fir::ExtendedValue elementExv() const { return element; }

  void clearIndices() { indices.clear(); }

private:
  mlir::Value inArg;
  mlir::Value outRes;
  fir::ExtendedValue element;
  llvm::SmallVector<mlir::Value> indices;
};

/// Closure type that maps an IterationSpace to an extended value. Closures of
/// this type are bound to front-end expressions by StackableConstructExpr.
using GenerateElementalArrayFunc =
    std::function<fir::ExtendedValue(const IterationSpace &)>;

/// Common state for constructs that nest: a stack of per-level information of
/// type `A`, a map from front-end expressions to the closures bound to them,
/// and a statement context that lives for the whole (outermost) construct.
template <typename A>
class StackableConstructExpr {
public:
  /// Is the construct stack empty?
  bool empty() const { return stack.empty(); }

  /// Push a fresh, default-constructed level onto the stack.
  void growStack() { stack.push_back(A{}); }

  /// Bind a front-end expression to a closure. An existing binding for `e`
  /// is not replaced; use rebind() for that.
  void bind(FrontEndExpr e, GenerateElementalArrayFunc &&fun) {
    vmap.insert({e, std::move(fun)});
  }

  /// Replace the binding of front-end expression `e` with a new closure.
  void rebind(FrontEndExpr e, GenerateElementalArrayFunc &&fun) {
    // Drop any previous binding first so that the insertion always succeeds.
    vmap.erase(e);
    vmap.insert({e, std::move(fun)});
  }

  /// Get the closure bound to the front-end expression, `e`. It is a fatal
  /// error if `e` has no binding.
  GenerateElementalArrayFunc getBoundClosure(FrontEndExpr e) const {
    auto found = vmap.find(e);
    if (found == vmap.end())
      llvm::report_fatal_error(
          "evaluate::Expr is not in the map of lowered mask expressions");
    return found->second;
  }

  /// Has the front-end expression, `e`, been lowered and bound?
  bool isLowered(FrontEndExpr e) const { return vmap.find(e) != vmap.end(); }

  /// Access the statement context of the construct.
  StatementContext &stmtContext() { return stmtCtx; }

protected:
  /// Pop the top level. When the last level is popped, the statement context
  /// is finalized and reset and all closure bindings are dropped.
  void shrinkStack() {
    assert(!empty());
    stack.pop_back();
    if (!stack.empty())
      return;
    stmtCtx.finalizeAndReset();
    vmap.clear();
  }

  // The stack for the construct information.
  llvm::SmallVector<A> stack;

  // Map each mask expression back to the temporary holding the initial
  // evaluation results.
  llvm::DenseMap<FrontEndExpr, GenerateElementalArrayFunc> vmap;

  // Inflate the statement context for the entire construct. We have to cache
  // the mask expression results, which are always evaluated first, across the
  // entire construct.
  StatementContext stmtCtx;
};

class ImplicitIterSpace;
/// Stream output for an ImplicitIterSpace; declared here so that the class
/// can befriend it. The definition is not in this header.
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const ImplicitIterSpace &);

/// All array expressions have an implicit iteration space, which is isomorphic
/// to the shape of the base array that facilitates the expression having a
/// non-zero rank. This implied iteration space may be conditionalized
/// (disjunctively) with an if-elseif-else like structure, specifically
/// Fortran's WHERE construct.
///
/// This class is used in the bridge to collect the expressions from the
/// front end (the WHERE construct mask expressions), forward them for lowering
/// as array expressions in an "evaluate once" (copy-in, copy-out) semantics.
/// See 10.2.3.2p3, 10.2.3.2p13, etc.
class ImplicitIterSpace
    : public StackableConstructExpr<llvm::SmallVector<FrontEndExpr>> {
public:
  using Base = StackableConstructExpr<llvm::SmallVector<FrontEndExpr>>;
  using FrontEndMaskExpr = FrontEndExpr;

  friend llvm::raw_ostream &operator<<(llvm::raw_ostream &,
                                       const ImplicitIterSpace &);

  LLVM_DUMP_METHOD void dump() const;

  /// Append the mask expression `e` to the list on top of the stack. The
  /// stack must not be empty.
  void append(FrontEndMaskExpr e) {
    assert(!empty());
    getMasks().back().push_back(e);
  }

  /// Return all the mask expressions on the stack as a single flat list,
  /// ordered from the bottom of the stack to the top. Returns an empty list
  /// when the stack is empty.
  llvm::SmallVector<FrontEndMaskExpr> getExprs() const {
    // Start from an empty list rather than from `getMasks()[0]` so that an
    // empty stack is never indexed.
    llvm::SmallVector<FrontEndMaskExpr> maskList;
    for (const auto &masks : getMasks())
      maskList.append(masks.begin(), masks.end());
    return maskList;
  }

  /// Add a variable binding, `var`, along with its shape for the mask
  /// expression `exp`. The `header` value is recorded with them. A binding
  /// that already exists for `exp` is left unchanged.
  void addMaskVariable(FrontEndExpr exp, mlir::Value var, mlir::Value shape,
                       mlir::Value header) {
    maskVarMap.try_emplace(exp, std::make_tuple(var, shape, header));
  }

  /// Lookup the variable corresponding to the temporary buffer that contains
  /// the mask array expression results.
  mlir::Value lookupMaskVariable(FrontEndExpr exp) {
    return std::get<0>(maskVarMap.lookup(exp));
  }

  /// Lookup the variable containing the shape vector for the mask array
  /// expression results.
  mlir::Value lookupMaskShapeBuffer(FrontEndExpr exp) {
    return std::get<1>(maskVarMap.lookup(exp));
  }

  /// Lookup the `header` value recorded for the mask expression `exp`.
  mlir::Value lookupMaskHeader(FrontEndExpr exp) {
    return std::get<2>(maskVarMap.lookup(exp));
  }

  // Stack of WHERE constructs, each building a list of mask expressions.
  llvm::SmallVector<llvm::SmallVector<FrontEndMaskExpr>> &getMasks() {
    return stack;
  }
  const llvm::SmallVector<llvm::SmallVector<FrontEndMaskExpr>> &
  getMasks() const {
    return stack;
  }

  // Cleanup at the end of a WHERE statement or construct. Once the stack is
  // empty again, the mask variable bindings are dropped as well.
  void shrinkStack() {
    Base::shrinkStack();
    if (stack.empty())
      maskVarMap.clear();
  }

private:
  // Map from a mask expression to its (variable, shape, header) values.
  llvm::DenseMap<FrontEndExpr,
                 std::tuple<mlir::Value, mlir::Value, mlir::Value>>
      maskVarMap;
};

class ExplicitIterSpace;
/// Stream output for an ExplicitIterSpace; declared here so that the class
/// can befriend it. The definition is not in this header.
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const ExplicitIterSpace &);

/// Create all the array_load ops for the explicit iteration space context. The
/// nest of FORALLs must have been analyzed a priori.
void createArrayLoads(AbstractConverter &converter, ExplicitIterSpace &esp,
                      SymMap &symMap);

/// Create the array_merge_store ops after the explicit iteration space context
/// is completed.
void createArrayMergeStores(AbstractConverter &converter,
                            ExplicitIterSpace &esp);

/// The kinds of array base tracked for an explicit iteration space: a pointer
/// to a symbol, to a component, or to an array reference.
using ExplicitSpaceArrayBases =
    std::variant<FrontEndSymbol, const evaluate::Component *,
                 const evaluate::ArrayRef *>;

/// Hash and equality for ExplicitSpaceArrayBases. The DenseMapInfo
/// specialization for this type forwards to these functions.
unsigned getHashValue(const ExplicitSpaceArrayBases &x);
bool isEqual(const ExplicitSpaceArrayBases &x,
             const ExplicitSpaceArrayBases &y);

} // namespace Fortran::lower

namespace llvm {
/// DenseMapInfo specialization that lets ExplicitSpaceArrayBases be used as a
/// DenseMap key. The empty and tombstone keys are symbol pointers built from
/// the integers `~0` and `~0 - 1`; hashing and equality are delegated to the
/// functions of the same names in Fortran::lower.
template <>
struct DenseMapInfo<Fortran::lower::ExplicitSpaceArrayBases> {
  static Fortran::lower::ExplicitSpaceArrayBases getEmptyKey() {
    Fortran::lower::FrontEndSymbol emptySentinel =
        reinterpret_cast<Fortran::lower::FrontEndSymbol>(~0);
    return emptySentinel;
  }
  static Fortran::lower::ExplicitSpaceArrayBases getTombstoneKey() {
    Fortran::lower::FrontEndSymbol tombstoneSentinel =
        reinterpret_cast<Fortran::lower::FrontEndSymbol>(~0 - 1);
    return tombstoneSentinel;
  }
  static unsigned
  getHashValue(const Fortran::lower::ExplicitSpaceArrayBases &key) {
    return Fortran::lower::getHashValue(key);
  }
  static bool isEqual(const Fortran::lower::ExplicitSpaceArrayBases &a,
                      const Fortran::lower::ExplicitSpaceArrayBases &b) {
    return Fortran::lower::isEqual(a, b);
  }
};
} // namespace llvm

namespace Fortran::lower {
/// Fortran also allows arrays to be evaluated under constructs which allow the
/// user to explicitly specify the iteration space using concurrent-control
/// expressions. These constructs allow the user to define both an iteration
/// space and explicit access vectors on arrays. These need not be isomorphic.
/// The explicit iteration spaces may be conditionalized (conjunctively) with an
/// "and" structure and may be found in FORALL (and DO CONCURRENT) constructs.
///
/// This class is used in the bridge to collect a stack of lists of
/// concurrent-control expressions to be used to generate the iteration space
/// and associated masks (if any) for a set of nested FORALL constructs around
/// assignment and WHERE constructs.
class ExplicitIterSpace {
public:
  // One dimension: a symbol and three expressions.
  // NOTE(review): presumably the control variable with its lower bound, upper
  // bound and step; confirm against the users of this alias.
  using IterSpaceDim =
      std::tuple<FrontEndSymbol, FrontEndExpr, FrontEndExpr, FrontEndExpr>;
  // A list of dimensions paired with one more expression.
  // NOTE(review): presumably the (optional) mask expression; confirm.
  using ConcurrentSpec =
      std::pair<llvm::SmallVector<IterSpaceDim>, FrontEndExpr>;
  using ArrayBases = ExplicitSpaceArrayBases;

  friend void createArrayLoads(AbstractConverter &converter,
                               ExplicitIterSpace &esp, SymMap &symMap);
  friend void createArrayMergeStores(AbstractConverter &converter,
                                     ExplicitIterSpace &esp);

  /// Is a FORALL context presently active?
  /// If we are lowering constructs/statements nested within a FORALL, then a
  /// FORALL context is active.
  bool isActive() const { return forallContextOpen != 0; }

  /// Get the statement context.
  StatementContext &stmtContext() { return stmtCtx; }

  //===--------------------------------------------------------------------===//
  // Analysis support
  //===--------------------------------------------------------------------===//

  /// Open a new construct. The analysis phase starts here.
  void pushLevel();

  /// Close the construct.
  void popLevel();

  /// Add new concurrent header control variable symbol.
  void addSymbol(FrontEndSymbol sym);

  /// Collect array bases from the expression, `x`.
  void exprBase(FrontEndExpr x, bool lhs);

  /// Called at the end of a assignment statement.
  void endAssign();

  /// Return all the active control variables on the stack.
  llvm::SmallVector<FrontEndSymbol> collectAllSymbols();

  //===--------------------------------------------------------------------===//
  // Code gen support
  //===--------------------------------------------------------------------===//

  /// Enter a FORALL context.
  void enter() { forallContextOpen++; }

  /// Leave a FORALL context.
  void leave();

  /// Append \p lambda to the list of closures run by genLoopNest().
  void pushLoopNest(std::function<void()> lambda) {
    // `lambda` is already a by-value copy; move it into the list instead of
    // copying it a second time.
    ccLoopNest.push_back(std::move(lambda));
  }

  /// Get the inner arguments that correspond to the output arrays.
  mlir::ValueRange getInnerArgs() const { return innerArgs; }

  /// Set the inner arguments for the next loop level.
  void setInnerArgs(llvm::ArrayRef<mlir::BlockArgument> args) {
    innerArgs.clear();
    for (auto &arg : args)
      innerArgs.push_back(arg);
  }

  /// Reset the outermost `array_load` arguments to the loop nest.
  void resetInnerArgs() { innerArgs = initialArgs; }

  /// Capture the current outermost loop. Any previously recorded loops are
  /// discarded.
  void setOuterLoop(fir::DoLoopOp loop) {
    clearLoops();
    outerLoop = loop;
  }

  /// Sets the inner loop argument at position \p offset to \p val.
  void setInnerArg(size_t offset, mlir::Value val) {
    assert(offset < innerArgs.size());
    innerArgs[offset] = val;
  }

  /// Get the types of the output arrays.
  llvm::SmallVector<mlir::Type> innerArgTypes() const {
    llvm::SmallVector<mlir::Type> result;
    result.reserve(innerArgs.size());
    for (auto &arg : innerArgs)
      result.push_back(arg.getType());
    return result;
  }

  /// Create a binding between an Ev::Expr node pointer and a fir::array_load
  /// op. This bindings will be used when generating the IR.
  void bindLoad(ArrayBases base, fir::ArrayLoadOp load) {
    loadBindings.try_emplace(std::move(base), load);
  }

  /// Return the array_load bound to \p base by bindLoad().
  fir::ArrayLoadOp findBinding(const ArrayBases &base) {
    return loadBindings.lookup(base);
  }

  /// `load` must be a LHS array_load. Returns `std::nullopt` on error.
  std::optional<size_t> findArgPosition(fir::ArrayLoadOp load);

  /// Is \p load one of the LHS array_loads, i.e. does it have an argument
  /// position?
  bool isLHS(fir::ArrayLoadOp load) {
    return findArgPosition(load).has_value();
  }

  /// `load` must be a LHS array_load. Determine the threaded inner argument
  /// corresponding to this load.
  mlir::Value findArgumentOfLoad(fir::ArrayLoadOp load) {
    if (auto opt = findArgPosition(load))
      return innerArgs[*opt];
    llvm_unreachable("array load argument not found");
  }

  /// Return the position of \p arg in the list of inner arguments. \p arg
  /// must be one of the inner arguments.
  size_t argPosition(mlir::Value arg) {
    for (auto i : llvm::enumerate(innerArgs))
      if (arg == i.value())
        return i.index();
    llvm_unreachable("inner argument value was not found");
  }

  /// Return the array_load bound to the LHS base recorded at the current
  /// counter position, or `std::nullopt` if no base is recorded there.
  /// NOTE(review): \p i is only bounds-checked; the entry that is read is the
  /// one selected by `counter`, not by \p i. Confirm with the callers whether
  /// this is intended before changing the index.
  std::optional<fir::ArrayLoadOp> getLhsLoad(size_t i) {
    assert(i < lhsBases.size());
    // The subscript actually used below is `counter`, so check it as well.
    assert(counter < lhsBases.size() && "counter is out of range of lhsBases");
    if (lhsBases[counter])
      return findBinding(*lhsBases[counter]);
    return std::nullopt;
  }

  /// Return the outermost loop in this FORALL nest.
  fir::DoLoopOp getOuterLoop() {
    assert(outerLoop.has_value());
    return *outerLoop;
  }

  /// Return the statement context for the entire, outermost FORALL construct.
  StatementContext &outermostContext() { return outerContext; }

  /// Generate the explicit loop nest.
  void genLoopNest() {
    for (auto &lambda : ccLoopNest)
      lambda();
  }

  /// Clear the array_load bindings.
  void resetBindings() { loadBindings.clear(); }

  /// Get the current counter value.
  std::size_t getCounter() const { return counter; }

  /// Increment the counter value to the next assignment statement.
  void incrementCounter() { counter++; }

  /// Is exactly one FORALL context open? A FORALL context must be active.
  bool isOutermostForall() const {
    assert(forallContextOpen);
    return forallContextOpen == 1;
  }

  /// Register \p fn as a loop cleanup. If a cleanup is already registered,
  /// the new cleanup runs the previously registered one first and then \p fn.
  void attachLoopCleanup(std::function<void(fir::FirOpBuilder &builder)> fn) {
    if (!loopCleanup) {
      loopCleanup = std::move(fn);
      return;
    }
    // Move both callables into the combined closure instead of copying them.
    loopCleanup = [oldFn = std::move(*loopCleanup),
                   fn = std::move(fn)](fir::FirOpBuilder &builder) {
      oldFn(builder);
      fn(builder);
    };
  }

  // LLVM standard dump method.
  LLVM_DUMP_METHOD void dump() const;

  // Pretty-print.
  friend llvm::raw_ostream &operator<<(llvm::raw_ostream &,
                                       const ExplicitIterSpace &);

  /// Finalize the current body statement context.
  void finalizeContext() { stmtCtx.finalizeAndReset(); }

  /// Push the list \p loops onto the stack of loops.
  void appendLoops(const llvm::SmallVector<fir::DoLoopOp> &loops) {
    loopStack.push_back(loops);
  }

  /// Drop all the recorded loops.
  void clearLoops() { loopStack.clear(); }

  /// Return a copy of the stack of loops.
  llvm::SmallVector<llvm::SmallVector<fir::DoLoopOp>> getLoopStack() const {
    return loopStack;
  }

private:
  /// Cleanup the analysis results.
  void conditionalCleanup();

  StatementContext outerContext;

  // A stack of lists of front-end symbols.
  llvm::SmallVector<llvm::SmallVector<FrontEndSymbol>> symbolStack;
  llvm::SmallVector<std::optional<ArrayBases>> lhsBases;
  llvm::SmallVector<llvm::SmallVector<ArrayBases>> rhsBases;
  llvm::DenseMap<ArrayBases, fir::ArrayLoadOp> loadBindings;

  // Stack of lambdas to create the loop nest.
  llvm::SmallVector<std::function<void()>> ccLoopNest;

  // Assignment statement context (inside the loop nest).
  StatementContext stmtCtx;
  llvm::SmallVector<mlir::Value> innerArgs;
  llvm::SmallVector<mlir::Value> initialArgs;
  std::optional<fir::DoLoopOp> outerLoop;
  llvm::SmallVector<llvm::SmallVector<fir::DoLoopOp>> loopStack;
  std::optional<std::function<void(fir::FirOpBuilder &)>> loopCleanup;
  // Number of FORALL contexts currently open; see enter().
  std::size_t forallContextOpen = 0;
  // Index of the current assignment statement; see incrementCounter().
  std::size_t counter = 0;
};

/// Returns true if at least one symbol referenced by an element of
/// `exprSyms` is also a member of the concurrent header set, `ctrlSet`.
/// Each element of `exprSyms` must provide `get()` returning a reference to
/// the symbol.
template <typename A>
bool symbolSetsIntersect(llvm::ArrayRef<FrontEndSymbol> ctrlSet,
                         const A &exprSyms) {
  for (const auto &symRef : exprSyms) {
    const auto *candidate = &symRef.get();
    if (llvm::is_contained(ctrlSet, candidate))
      return true;
  }
  return false;
}

/// Returns true if any symbol collected from a subscript of an Ev::ArrayRef
/// is also in the set of concurrent control symbols, `ctrlSet`. Only
/// subscripts holding an `IndirectSubscriptIntegerExpr` are inspected;
/// subscripts holding any other alternative are skipped.
template <typename A>
bool symbolsIntersectSubscripts(llvm::ArrayRef<FrontEndSymbol> ctrlSet,
                                const A &subscripts) {
  for (auto &subscript : subscripts) {
    const auto *intExpr =
        std::get_if<evaluate::IndirectSubscriptIntegerExpr>(&subscript.u);
    if (!intExpr)
      continue;
    if (symbolSetsIntersect(ctrlSet,
                            evaluate::CollectSymbols(intExpr->value())))
      return true;
  }
  return false;
}

} // namespace Fortran::lower

#endif // FORTRAN_LOWER_ITERATIONSPACE_H