// llvm/flang/lib/Semantics/rewrite-directives.cpp

//===-- lib/Semantics/rewrite-directives.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "rewrite-directives.h"
#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include <list>

namespace Fortran::semantics {

using namespace parser::literals;

// Base class for parse tree mutators that rewrite directives. It stores the
// semantics context for use by derived classes and supplies catch-all Pre/Post
// hooks for node types a derived mutator does not handle itself.
class DirectiveRewriteMutator {
public:
  explicit DirectiveRewriteMutator(SemanticsContext &context)
      : context_{context} {}

  // Default action for a parse tree node is to visit children: the generic
  // Pre answers true and the generic Post does nothing.
  template <typename Node> bool Pre(Node &) { return true; }
  template <typename Node> void Post(Node &) {}

protected:
  // Semantics context handed to the constructor; reachable from subclasses.
  SemanticsContext &context_;
};

// Rewrite atomic constructs to add an explicit memory ordering to all that do
// not specify it, honoring in this way the `atomic_default_mem_order` clause of
// the REQUIRES directive.
class OmpRewriteMutator : public DirectiveRewriteMutator {
public:
  explicit OmpRewriteMutator(SemanticsContext &context)
      : DirectiveRewriteMutator(context) {}

  template <typename T> bool Pre(T &) { return true; }
  template <typename T> void Post(T &) {}

  bool Pre(parser::OpenMPAtomicConstruct &);
  bool Pre(parser::OpenMPRequiresConstruct &);

private:
  bool atomicDirectiveDefaultOrderFound_{false};
};

// Appends an explicit memory order clause to an atomic construct that does not
// have one, using the `atomic_default_mem_order` value recorded on the
// outermost symbol that encloses the construct. Nothing is changed when no
// default order is recorded or when the construct already names a memory
// order. Always returns false.
bool OmpRewriteMutator::Pre(parser::OpenMPAtomicConstruct &x) {
  using AtomicClauses = std::list<parser::OmpAtomicClause>;

  // Climb from the scope holding the directive up to (but excluding) the
  // global scope, keeping the last scope symbol met on the way.
  Symbol *outermostOwner{common::visit(
      [&](auto &construct) -> Symbol * {
        Scope *current{&context_.FindScope(
            std::get<parser::Verbatim>(construct.t).source)};
        Symbol *owner{nullptr};
        do {
          if (Symbol *candidate{current->symbol()}) {
            owner = candidate;
          }
          current = &current->parent();
        } while (!current->IsGlobal());

        assert(owner &&
            "Atomic construct must be within a scope associated with a symbol");
        return owner;
      },
      x.u)};

  // Only details types that derive from WithOmpDeclarative can carry an
  // `atomic_default_mem_order` value; every other kind is skipped.
  std::optional<common::OmpAtomicDefaultMemOrderType> requiredOrder;
  common::visit(
      [&](auto &details) {
        using DetailsPtr = decltype(&details);
        if constexpr (std::is_convertible_v<DetailsPtr,
                          WithOmpDeclarative *>) {
          if (details.has_ompAtomicDefaultMemOrder()) {
            requiredOrder = *details.ompAtomicDefaultMemOrder();
          }
        }
      },
      outermostOwner->details());

  if (!requiredOrder.has_value()) {
    return false;
  }

  // Tells whether a clause list already contains a memory order clause.
  const auto hasMemOrder = [](const AtomicClauses &clauses) {
    for (const parser::OmpAtomicClause &clause : clauses) {
      if (std::holds_alternative<parser::OmpMemoryOrderClause>(clause.u)) {
        return true;
      }
    }
    return false;
  };

  // Pick the list that receives the new clause, or null when the directive
  // already specifies a memory order somewhere.
  AtomicClauses *target{common::visit(
      common::visitors{
          [&](parser::OmpAtomic &construct) -> AtomicClauses * {
            // OmpAtomic only has a single list of clauses.
            auto &only{std::get<parser::OmpAtomicClauseList>(construct.t)};
            if (hasMemOrder(only.v)) {
              return nullptr;
            }
            return &only.v;
          },
          [&](auto &construct) -> AtomicClauses * {
            // All other atomic constructs have two lists of clauses; the
            // second one is the one that gets extended.
            auto &leading{std::get<0>(construct.t)};
            auto &trailing{std::get<2>(construct.t)};
            if (hasMemOrder(leading.v) || hasMemOrder(trailing.v)) {
              return nullptr;
            }
            return &trailing.v;
          }},
      x.u)};

  if (!target) {
    return false;
  }

  // Record that a default order was applied; the REQUIRES hook reads this
  // flag to diagnose a clause that shows up afterwards.
  atomicDirectiveDefaultOrderFound_ = true;

  const auto append = [target](parser::OmpClause &&order) {
    target->emplace_back<parser::OmpMemoryOrderClause>(std::move(order));
  };

  switch (*requiredOrder) {
  case common::OmpAtomicDefaultMemOrderType::SeqCst:
    append(parser::OmpClause{parser::OmpClause::SeqCst{}});
    break;
  case common::OmpAtomicDefaultMemOrderType::Relaxed:
    append(parser::OmpClause{parser::OmpClause::Relaxed{}});
    break;
  case common::OmpAtomicDefaultMemOrderType::AcqRel:
    // The clause chosen for acq_rel depends on the kind of atomic operation.
    append(common::visit(
        common::visitors{
            [](parser::OmpAtomicRead &) -> parser::OmpClause {
              return parser::OmpClause{parser::OmpClause::Acquire{}};
            },
            [](parser::OmpAtomicCapture &) -> parser::OmpClause {
              return parser::OmpClause{parser::OmpClause::AcqRel{}};
            },
            [](auto &) -> parser::OmpClause {
              // parser::{OmpAtomic, OmpAtomicUpdate, OmpAtomicWrite}
              return parser::OmpClause{parser::OmpClause::Release{}};
            }},
        x.u));
    break;
  }

  return false;
}

// Reports an error for every `atomic_default_mem_order` clause of a REQUIRES
// construct that is visited after this mutator already added a default memory
// order to some atomic construct. Always returns false.
bool OmpRewriteMutator::Pre(parser::OpenMPRequiresConstruct &x) {
  // No atomic construct has been given a default order so far, hence there is
  // nothing this REQUIRES construct could conflict with.
  if (!atomicDirectiveDefaultOrderFound_) {
    return false;
  }

  auto &requiresClauses{std::get<parser::OmpClauseList>(x.t).v};
  for (parser::OmpClause &clause : requiresClauses) {
    const bool isDefaultMemOrder{
        std::holds_alternative<parser::OmpClause::AtomicDefaultMemOrder>(
            clause.u)};
    if (!isDefaultMemOrder) {
      continue;
    }
    // The clause name is upper-cased before being placed in the message.
    context_.Say(clause.source,
        "REQUIRES directive with '%s' clause found lexically after atomic "
        "operation without a memory order clause"_err_en_US,
        parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
            llvm::omp::OMPC_atomic_default_mem_order)
                                       .str()));
  }
  return false;
}

// Entry point of the OpenMP rewrite. When the OpenMP language feature is
// enabled in `context`, walks `program` with an OmpRewriteMutator and returns
// true only if the context holds no fatal error afterwards. When OpenMP is not
// enabled the program is left alone and the result is true.
bool RewriteOmpParts(SemanticsContext &context, parser::Program &program) {
  if (context.IsEnabled(common::LanguageFeature::OpenMP)) {
    OmpRewriteMutator rewriter{context};
    parser::Walk(program, rewriter);
    return !context.AnyFatalError();
  }
  return true;
}

} // namespace Fortran::semantics