llvm/flang/lib/Semantics/check-do-forall.cpp

//===-- lib/Semantics/check-do-forall.cpp ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "check-do-forall.h"
#include "definable.h"
#include "flang/Common/template.h"
#include "flang/Evaluate/call.h"
#include "flang/Evaluate/expression.h"
#include "flang/Evaluate/tools.h"
#include "flang/Evaluate/traverse.h"
#include "flang/Parser/message.h"
#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/attr.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "flang/Semantics/type.h"
#include <functional>

namespace Fortran::evaluate {
using ActualArgumentRef = common::Reference<const ActualArgument>;

// Orders two ActualArgumentRef values by the addresses of the objects they
// refer to.  std::less is used instead of the built-in '<' because the
// built-in relational operators give an unspecified result for pointers to
// unrelated objects, whereas std::less guarantees a strict total order over
// all pointer values.
inline bool operator<(ActualArgumentRef x, ActualArgumentRef y) {
  return std::less<>{}(&*x, &*y);
}
} // namespace Fortran::evaluate

namespace Fortran::semantics {

using namespace parser::literals;

using Bounds = parser::LoopControl::Bounds;
using IndexVarKind = SemanticsContext::IndexVarKind;

// Returns the concurrent-header stored in a loop control.  The loop control
// must hold the Concurrent alternative; std::get is applied without a check.
static const parser::ConcurrentHeader &GetConcurrentHeader(
    const parser::LoopControl &loopControl) {
  return std::get<parser::ConcurrentHeader>(
      std::get<parser::LoopControl::Concurrent>(loopControl.u).t);
}
// Returns the concurrent-header of the FORALL construct statement that opens
// a FORALL construct.
static const parser::ConcurrentHeader &GetConcurrentHeader(
    const parser::ForallConstruct &construct) {
  using HeaderIndirection = common::Indirection<parser::ConcurrentHeader>;
  const auto &forallStmt{
      std::get<parser::Statement<parser::ForallConstructStmt>>(construct.t)
          .statement};
  const HeaderIndirection &header{std::get<HeaderIndirection>(forallStmt.t)};
  return header.value();
}
// Returns the concurrent-header of a FORALL statement.
static const parser::ConcurrentHeader &GetConcurrentHeader(
    const parser::ForallStmt &stmt) {
  const auto &header{
      std::get<common::Indirection<parser::ConcurrentHeader>>(stmt.t)};
  return header.value();
}
template <typename T>
static const std::list<parser::ConcurrentControl> &GetControls(const T &x) {
  return std::get<std::list<parser::ConcurrentControl>>(
      GetConcurrentHeader(x).t);
}

// Returns the Bounds alternative of a DO construct's loop control.  The
// construct must have a loop control that holds Bounds; no check is made.
static const Bounds &GetBounds(const parser::DoConstruct &doConstruct) {
  return std::get<Bounds>(doConstruct.GetLoopControl().value().u);
}

// Returns the name of the DO variable from the bounds of a DO construct.
static const parser::Name &GetDoVariable(
    const parser::DoConstruct &doConstruct) {
  return GetBounds(doConstruct).name.thing;
}

// Returns the fixed text attached to diagnostics to point at the enclosing
// DO CONCURRENT statement.
static parser::MessageFixedText GetEnclosingDoMsg() {
  return "Enclosing DO CONCURRENT statement"_en_US;
}

// Emits `message` at `stmtLocation` and attaches a note at `doLocation` that
// identifies the enclosing DO CONCURRENT statement.
static void SayWithDo(SemanticsContext &context, parser::CharBlock stmtLocation,
    parser::MessageFixedText &&message, parser::CharBlock doLocation) {
  auto &said{context.Say(stmtLocation, message)};
  said.Attach(doLocation, GetEnclosingDoMsg());
}

// 11.1.7.5 - enforce semantics constraints on a DO CONCURRENT loop body
class DoConcurrentBodyEnforce {
public:
  // Records the semantics context used for diagnostics and the source
  // position of the enclosing DO CONCURRENT statement, which is attached to
  // several of the messages emitted by this visitor.
  DoConcurrentBodyEnforce(
      SemanticsContext &context, parser::CharBlock doConcurrentSourcePosition)
      : context_{context},
        doConcurrentSourcePosition_{doConcurrentSourcePosition} {}
  // Returns a copy of the labels collected from the labeled statements
  // visited so far.
  std::set<parser::Label> labels() { return labels_; }
  // Generic pre-visit hook: if the node carries a typed expression, report
  // any impure procedure it references.  Always continues the traversal.
  template <typename T> bool Pre(const T &x) {
    const auto *typedExpr{GetExpr(context_, x)};
    if (typedExpr) {
      auto impure{FindImpureCall(context_.foldingContext(), *typedExpr)};
      if (impure) {
        context_.Say(currentStatementSourcePosition_,
            "Impure procedure '%s' may not be referenced in DO CONCURRENT"_err_en_US,
            *impure);
      }
    }
    return true;
  }
  // Remembers the source position of the statement being entered and records
  // its label, if it has one.  Always continues the traversal.
  template <typename T> bool Pre(const parser::Statement<T> &statement) {
    currentStatementSourcePosition_ = statement.source;
    if (const auto &label{statement.label}) {
      labels_.insert(*label);
    }
    return true;
  }
  // Remembers the source position of the unlabeled statement being entered
  // so that later diagnostics can refer to it.  Always continues the
  // traversal.
  template <typename T> bool Pre(const parser::UnlabeledStatement<T> &stmt) {
    currentStatementSourcePosition_ = stmt.source;
    return true;
  }
  // Reports a CALL statement whose typed call references an impure
  // procedure.  Nothing is checked when the statement has no typed call.
  bool Pre(const parser::CallStmt &x) {
    if (const auto *typedCall{x.typedCall.get()}) {
      auto impure{FindImpureCall(context_.foldingContext(), *typedCall)};
      if (impure) {
        context_.Say(currentStatementSourcePosition_,
            "Impure procedure '%s' may not be referenced in DO CONCURRENT"_err_en_US,
            *impure);
      }
    }
    return true;
  }
  // Default post-visit hook: nothing to do for node types that have no
  // dedicated Post overload.
  template <typename T> void Post(const T &) {}

  // C1140 -- Can't deallocate a polymorphic entity in a DO CONCURRENT.
  // Deallocation can be caused by exiting a block that declares an allocatable
  // entity, assignment to an allocatable variable, or an actual DEALLOCATE
  // statement
  //
  // Note also that the deallocation of a derived type entity might cause the
  // invocation of an IMPURE final subroutine. (C1139)
  //

  // Predicate for deallocations caused by block exit and direct deallocation:
  // every symbol is considered to be deallocated.
  static bool DeallocateAll(const Symbol &) { return true; }

  // Predicate for deallocations caused by intrinsic assignment: a component
  // is considered to be deallocated unless it is a coarray.
  static bool DeallocateNonCoarray(const Symbol &component) {
    return !evaluate::IsCoarray(component);
  }

  // True when the predicate says `entity` will be deallocated and `entity`
  // is a polymorphic allocatable.
  static bool WillDeallocatePolymorphic(const Symbol &entity,
      const std::function<bool(const Symbol &)> &WillDeallocate) {
    if (!WillDeallocate(entity)) {
      return false;
    }
    return IsPolymorphicAllocatable(entity);
  }

  // Is it possible that we will we deallocate a polymorphic entity or one
  // of its components?
  static bool MightDeallocatePolymorphic(const Symbol &original,
      const std::function<bool(const Symbol &)> &WillDeallocate) {
    const Symbol &symbol{ResolveAssociations(original)};
    // Check the entity itself, no coarray exception here
    if (IsPolymorphicAllocatable(symbol)) {
      return true;
    }
    // Check the components
    if (const auto *details{symbol.detailsIf<ObjectEntityDetails>()}) {
      if (const DeclTypeSpec * entityType{details->type()}) {
        if (const DerivedTypeSpec * derivedType{entityType->AsDerived()}) {
          UltimateComponentIterator ultimates{*derivedType};
          for (const auto &ultimate : ultimates) {
            if (WillDeallocatePolymorphic(ultimate, WillDeallocate)) {
              return true;
            }
          }
        }
      }
    }
    return false;
  }

  // Reports, at the current statement, that deallocating `entity` for the
  // given `reason` would involve the IMPURE FINAL procedure `impure`.  The
  // message is emitted through SayWithDecl, which also receives `entity`.
  void SayDeallocateWithImpureFinal(
      const Symbol &entity, const char *reason, const Symbol &impure) {
    context_.SayWithDecl(entity, currentStatementSourcePosition_,
        "Deallocation of an entity with an IMPURE FINAL procedure '%s' caused by %s not allowed in DO CONCURRENT"_err_en_US,
        impure.name(), reason);
  }

  // Reports, at `location`, that the polymorphic `entity` would be
  // deallocated for the given `reason`.  The message is emitted through
  // SayWithDecl, which also receives `entity`.
  void SayDeallocateOfPolymorph(
      parser::CharBlock location, const Symbol &entity, const char *reason) {
    context_.SayWithDecl(entity, location,
        "Deallocation of a polymorphic entity caused by %s"
        " not allowed in DO CONCURRENT"_err_en_US,
        reason);
  }

  // Deallocation caused by block exit.
  // Only blocks whose scope DoesScopeContain places inside the DO CONCURRENT
  // scope are examined.  For each symbol of the block scope, an unsaved
  // allocatable entity that might deallocate something polymorphic is
  // reported at the END BLOCK statement; independently, any entity with an
  // impure final procedure is reported.  Unlike the assignment and
  // DEALLOCATE cases, an entity that is not itself allocatable is not
  // reported for its allocatable polymorphic components.
  void Post(const parser::BlockConstruct &blockConstruct) {
    const auto &endStmt{
        std::get<parser::Statement<parser::EndBlockStmt>>(blockConstruct.t)};
    const Scope &blockScope{context_.FindScope(endStmt.source)};
    const Scope &doScope{context_.FindScope(doConcurrentSourcePosition_)};
    if (!DoesScopeContain(&doScope, blockScope)) {
      return;
    }
    const char *reason{"block exit"};
    for (auto &nameAndSymbol : blockScope) {
      const Symbol &entity{*nameAndSymbol.second};
      const bool unsavedAllocatable{IsAllocatable(entity) && !IsSaved(entity)};
      if (unsavedAllocatable &&
          MightDeallocatePolymorphic(entity, DeallocateAll)) {
        SayDeallocateOfPolymorph(endStmt.source, entity, reason);
      }
      const Symbol *impure{HasImpureFinal(entity)};
      if (impure) {
        SayDeallocateWithImpureFinal(entity, reason, *impure);
      }
    }
  }

  // Deallocation caused by assignment.
  // Coarray components are exempt here (see DeallocateNonCoarray).  Also
  // reports an impure final procedure for the assigned entity and an impure
  // defined-assignment subroutine.
  void Post(const parser::AssignmentStmt &stmt) {
    const auto *assignment{GetAssignment(stmt)};
    const auto &variable{std::get<parser::Variable>(stmt.t)};
    if (const Symbol *entity{GetLastName(variable).symbol}) {
      const char *reason{"assignment"};
      if (MightDeallocatePolymorphic(*entity, DeallocateNonCoarray)) {
        SayDeallocateOfPolymorph(variable.GetSource(), *entity, reason);
      }
      if (assignment) {
        const Symbol *impure{HasImpureFinal(*entity, assignment->lhs.Rank())};
        if (impure) {
          SayDeallocateWithImpureFinal(*entity, reason, *impure);
        }
      }
    }
    if (!assignment) {
      return;
    }
    // A ProcedureRef here denotes a defined assignment.
    const auto *call{std::get_if<evaluate::ProcedureRef>(&assignment->u)};
    if (!call) {
      return;
    }
    if (auto impureCall{FindImpureCall(context_.foldingContext(), *call)}) {
      context_.Say(currentStatementSourcePosition_,
          "The defined assignment subroutine '%s' is not pure"_err_en_US,
          *impureCall);
    }
  }

  // Deallocation from a DEALLOCATE statement.
  // DEALLOCATE applies to both ALLOCATABLE and POINTER entities, so an
  // object whose declared type is polymorphic is reported even when it is
  // not a polymorphic allocatable.  Objects whose last name has no symbol
  // are skipped.
  void Post(const parser::DeallocateStmt &stmt) {
    const char *reason{"a DEALLOCATE statement"};
    for (const auto &object :
        std::get<std::list<parser::AllocateObject>>(stmt.t)) {
      const parser::Name &name{GetLastName(object)};
      if (!name.symbol) {
        continue;
      }
      const Symbol &entity{*name.symbol};
      const DeclTypeSpec *entityType{entity.GetType()};
      const bool polymorphicType{entityType && entityType->IsPolymorphic()};
      if (polymorphicType || // POINTER case
          MightDeallocatePolymorphic(entity, DeallocateAll)) {
        SayDeallocateOfPolymorph(
            currentStatementSourcePosition_, entity, reason);
      }
      const Symbol *impure{HasImpureFinal(entity)};
      if (impure) {
        SayDeallocateWithImpureFinal(entity, reason, *impure);
      }
    }
  }

  // C1137 -- No image control statements in a DO CONCURRENT.
  // The error is placed at the image control statement; an optional note
  // about coarrays and a note for the enclosing DO CONCURRENT are attached.
  void Post(const parser::ExecutableConstruct &construct) {
    if (!IsImageControlStmt(construct)) {
      return;
    }
    const parser::CharBlock at{GetImageControlStmtLocation(construct)};
    auto &message{context_.Say(at,
        "An image control statement is not allowed in DO CONCURRENT"_err_en_US)};
    if (auto coarrayNote{GetImageControlStmtCoarrayMsg(construct)}) {
      message.Attach(at, *coarrayNote);
    }
    message.Attach(doConcurrentSourcePosition_, GetEnclosingDoMsg());
  }

  // C1136 -- No RETURN statements in a DO CONCURRENT.
  // The error is placed at the current statement, with a note pointing at
  // the enclosing DO CONCURRENT.
  void Post(const parser::ReturnStmt &) {
    auto &message{context_.Say(currentStatementSourcePosition_,
        "RETURN is not allowed in DO CONCURRENT"_err_en_US)};
    message.Attach(doConcurrentSourcePosition_, GetEnclosingDoMsg());
  }

  // C1145, C1146: cannot call ieee_[gs]et_flag, ieee_[gs]et_halting_mode,
  // ieee_[gs]et_status, ieee_set_rounding_mode, or ieee_set_underflow_mode.
  // Applies only to a named procedure whose ultimate symbol is owned by the
  // __fortran_ieee_arithmetic or __fortran_ieee_exceptions module.  The
  // ultimate name is matched by substring, and only the first match in the
  // table is reported.
  void Post(const parser::ProcedureDesignator &procedureDesignator) {
    const auto *name{std::get_if<parser::Name>(&procedureDesignator.u)};
    if (!name || !name->symbol) {
      return;
    }
    const Symbol &ultimate{name->symbol->GetUltimate()};
    const Scope &owner{ultimate.owner()};
    const Symbol *moduleSymbol{owner.IsModule() ? owner.symbol() : nullptr};
    if (!moduleSymbol) {
      return;
    }
    const bool inIeeeModule{
        moduleSymbol->name() == "__fortran_ieee_arithmetic" ||
        moduleSymbol->name() == "__fortran_ieee_exceptions"};
    if (!inIeeeModule) {
      return;
    }
    static constexpr const char *forbidden[]{"ieee_get_flag", "ieee_set_flag",
        "ieee_get_halting_mode", "ieee_set_halting_mode", "ieee_get_status",
        "ieee_set_status", "ieee_set_rounding_mode",
        "ieee_set_underflow_mode"};
    const std::string procedureName{ultimate.name().ToString()};
    for (const char *candidate : forbidden) {
      if (procedureName.find(candidate) != std::string::npos) {
        auto &message{context_.Say(name->source,
            "'%s' may not be called in DO CONCURRENT"_err_en_US, candidate)};
        message.Attach(doConcurrentSourcePosition_, GetEnclosingDoMsg());
        break;
      }
    }
  }

  // 11.1.7.5, paragraph 5, no ADVANCE specifier in a DO CONCURRENT.
  // Only a character-expression control spec of kind Advance is reported.
  void Post(const parser::IoControlSpec &ioControlSpec) {
    using CharExpr = parser::IoControlSpec::CharExpr;
    const auto *charExpr{std::get_if<CharExpr>(&ioControlSpec.u)};
    if (!charExpr ||
        std::get<CharExpr::Kind>(charExpr->t) != CharExpr::Kind::Advance) {
      return;
    }
    SayWithDo(context_, currentStatementSourcePosition_,
        "ADVANCE specifier is not allowed in DO CONCURRENT"_err_en_US,
        doConcurrentSourcePosition_);
  }

private:
  std::set<parser::Label> labels_;
  parser::CharBlock currentStatementSourcePosition_;
  SemanticsContext &context_;
  parser::CharBlock doConcurrentSourcePosition_;
}; // class DoConcurrentBodyEnforce

// Parse-tree visitor enforcing C1130: in a DO CONCURRENT with DEFAULT(NONE),
// a variable from an enclosing scope must have its locality specified.
class DoConcurrentVariableEnforce {
public:
  // `doConcurrentSourcePosition` locates the DO CONCURRENT statement; the
  // scope found at that position is the one variable owners are tested
  // against.
  DoConcurrentVariableEnforce(
      SemanticsContext &context, parser::CharBlock doConcurrentSourcePosition)
      : context_{context},
        doConcurrentSourcePosition_{doConcurrentSourcePosition},
        blockScope_{context.FindScope(doConcurrentSourcePosition_)} {}

  // Default hooks: visit everything, do nothing on the way out.
  template <typename T> bool Pre(const T &) { return true; }
  template <typename T> void Post(const T &) {}

  // Reports a name that resolves to a variable whose owning scope is found
  // by DoesScopeContain to contain the DO CONCURRENT scope.
  void Post(const parser::Name &name) {
    const Symbol *symbol{name.symbol};
    if (!symbol || !IsVariableName(*symbol)) {
      return;
    }
    const Scope &owner{symbol->owner()};
    if (!DoesScopeContain(&owner, blockScope_)) {
      return;
    }
    context_.SayWithDecl(*symbol, name.source,
        "Variable '%s' from an enclosing scope referenced in DO "
        "CONCURRENT with DEFAULT(NONE) must appear in a "
        "locality-spec"_err_en_US,
        symbol->name());
  }

private:
  SemanticsContext &context_;
  parser::CharBlock doConcurrentSourcePosition_;
  const Scope &blockScope_;
}; // class DoConcurrentVariableEnforce

// Find a DO or FORALL and enforce semantics checks on its body
class DoContext {
public:
  // `kind` selects between DO and FORALL (see LoopKindName) and is the kind
  // under which index variables are activated; `isNested` enables the
  // impure-call checks on concurrent controls in CheckConcurrentHeader.
  DoContext(SemanticsContext &context, IndexVarKind kind, bool isNested)
      : context_{context}, kind_{kind}, isNested_{isNested} {}

  // Activates the index variables of a DO construct in the semantics
  // context: the single DO variable of a normal DO loop, or every index-name
  // of a DO CONCURRENT.  Other forms of DO are left alone.  The context is
  // handed the parser::Name of each variable (presumably so that both the
  // variable and its source location are available to it -- confirm in
  // SemanticsContext::ActivateIndexVar).
  void DefineDoVariables(const parser::DoConstruct &doConstruct) {
    if (doConstruct.IsDoNormal()) {
      const parser::Name &doVariable{GetDoVariable(doConstruct)};
      context_.ActivateIndexVar(doVariable, IndexVarKind::DO);
      return;
    }
    if (!doConstruct.IsDoConcurrent()) {
      return;
    }
    if (const auto &loopControl{doConstruct.GetLoopControl()}) {
      ActivateIndexVars(GetControls(*loopControl));
    }
  }

  // Counterpart of DefineDoVariables: deactivates the DO variable of a
  // normal DO loop or every index-name of a DO CONCURRENT.
  void ResetDoVariables(const parser::DoConstruct &doConstruct) {
    if (doConstruct.IsDoNormal()) {
      const parser::Name &doVariable{GetDoVariable(doConstruct)};
      context_.DeactivateIndexVar(doVariable);
      return;
    }
    if (!doConstruct.IsDoConcurrent()) {
      return;
    }
    if (const auto &loopControl{doConstruct.GetLoopControl()}) {
      DeactivateIndexVars(GetControls(*loopControl));
    }
  }

  // Activates the index-name of every concurrent-control, using this
  // context's kind.
  void ActivateIndexVars(const std::list<parser::ConcurrentControl> &controls) {
    for (const parser::ConcurrentControl &concurrentControl : controls) {
      const auto &indexName{std::get<parser::Name>(concurrentControl.t)};
      context_.ActivateIndexVar(indexName, kind_);
    }
  }
  // Deactivates the index-name of every concurrent-control.
  void DeactivateIndexVars(
      const std::list<parser::ConcurrentControl> &controls) {
    for (const parser::ConcurrentControl &concurrentControl : controls) {
      const auto &indexName{std::get<parser::Name>(concurrentControl.t)};
      context_.DeactivateIndexVar(indexName);
    }
  }

  // Dispatches a DO construct to the DO CONCURRENT or the normal DO checks.
  void Check(const parser::DoConstruct &doConstruct) {
    if (doConstruct.IsDoConcurrent()) {
      CheckDoConcurrent(doConstruct);
      return;
    }
    if (doConstruct.IsDoNormal()) {
      CheckDoNormal(doConstruct);
      return;
    }
    // TODO: handle the other cases
  }

  // Checks the concurrent-header of a FORALL statement.
  void Check(const parser::ForallStmt &stmt) {
    const parser::ConcurrentHeader &header{GetConcurrentHeader(stmt)};
    CheckConcurrentHeader(header);
  }
  // Checks the concurrent-header of a FORALL construct.
  void Check(const parser::ForallConstruct &construct) {
    const parser::ConcurrentHeader &header{GetConcurrentHeader(construct)};
    CheckConcurrentHeader(header);
  }

  // Checks one assignment of a FORALL.  Nothing is done when the statement
  // has no typed assignment.  Otherwise, in this order:
  //  - each FORALL index should be used on the left-hand side;
  //  - neither side may reference an impure procedure;
  //  - a variable on the left-hand side must not have an impure final
  //    procedure;
  //  - a defined-assignment procedure, or the bounds expressions of a
  //    pointer assignment, may not reference an impure procedure.
  void Check(const parser::ForallAssignmentStmt &stmt) {
    if (const evaluate::Assignment *
        assignment{common::visit(
            common::visitors{[&](const auto &x) { return GetAssignment(x); }},
            stmt.u)}) {
      CheckForallIndexesUsed(*assignment);
      CheckForImpureCall(assignment->lhs);
      CheckForImpureCall(assignment->rhs);

      if (IsVariable(assignment->lhs)) {
        if (const Symbol * symbol{GetLastSymbol(assignment->lhs)}) {
          if (auto impureFinal{
                  HasImpureFinal(*symbol, assignment->lhs.Rank())}) {
            context_.SayWithDecl(*symbol, parser::FindSourceLocation(stmt),
                "Impure procedure '%s' is referenced by finalization in a %s"_err_en_US,
                impureFinal->name(), LoopKindName());
          }
        }
      }

      // The ProcedureRef alternative is handled by the visitor below.
      // Checking it separately as well would run the same impure-call check,
      // and emit the same diagnostic, twice.
      common::visit(
          common::visitors{
              [](const evaluate::Assignment::Intrinsic &) {},
              [&](const evaluate::ProcedureRef &proc) {
                CheckForImpureCall(proc);
              },
              [&](const evaluate::Assignment::BoundsSpec &bounds) {
                for (const auto &bound : bounds) {
                  CheckForImpureCall(SomeExpr{bound});
                }
              },
              [&](const evaluate::Assignment::BoundsRemapping &bounds) {
                for (const auto &bound : bounds) {
                  CheckForImpureCall(SomeExpr{bound.first});
                  CheckForImpureCall(SomeExpr{bound.second});
                }
              },
          },
          assignment->u);
    }
  }

private:
  // Emits the error for a DO control that is neither INTEGER nor covered by
  // the REAL extension handled in CheckDoControl.
  void SayBadDoControl(parser::CharBlock sourceLocation) {
    context_.Say(sourceLocation, "DO controls should be INTEGER"_err_en_US);
  }

  // Diagnoses a non-INTEGER DO control.  A REAL control yields only a
  // portability message, and only when RealDoControls warnings are enabled;
  // any other type is an error.
  void CheckDoControl(const parser::CharBlock &sourceLocation, bool isReal) {
    if (!isReal) {
      SayBadDoControl(sourceLocation);
      return;
    }
    if (context_.ShouldWarn(common::LanguageFeature::RealDoControls)) {
      context_.Say(sourceLocation, "DO controls should be INTEGER"_port_en_US);
    }
  }

  // Checks the DO variable of a normal DO loop: it must name a variable, it
  // must be definable, and its type must be INTEGER (REAL is tolerated via
  // CheckDoControl).  Only the first failing check is reported.  A name
  // without a symbol is ignored.
  void CheckDoVariable(const parser::ScalarName &scalarName) {
    const parser::CharBlock &at{scalarName.thing.source};
    const Symbol *symbol{scalarName.thing.symbol};
    if (!symbol) {
      return;
    }
    if (!IsVariableName(*symbol)) {
      context_.Say(at, "DO control must be an INTEGER variable"_err_en_US);
      return;
    }
    if (auto why{WhyNotDefinable(
            at, context_.FindScope(at), DefinabilityFlags{}, *symbol)}) {
      auto &message{context_.Say(at,
          "'%s' may not be used as a DO variable"_err_en_US, symbol->name())};
      message.Attach(std::move(why->set_severity(parser::Severity::Because)));
      return;
    }
    const DeclTypeSpec *symType{symbol->GetType()};
    if (!symType) {
      SayBadDoControl(at);
    } else if (!symType->IsNumeric(TypeCategory::Integer)) {
      // No messages for INTEGER
      CheckDoControl(at, symType->IsNumeric(TypeCategory::Real));
    }
  }

  // Semantic checks for the limit and step expressions: an expression that
  // has a typed form and is not INTEGER is passed to CheckDoControl.
  void CheckDoExpression(const parser::ScalarExpr &scalarExpression) {
    const SomeExpr *typed{GetExpr(context_, scalarExpression)};
    // No warnings or errors for type INTEGER
    if (!typed || ExprHasTypeCategory(*typed, TypeCategory::Integer)) {
      return;
    }
    const parser::CharBlock &at{scalarExpression.thing.value().source};
    const bool isReal{ExprHasTypeCategory(*typed, TypeCategory::Real)};
    CheckDoControl(at, isReal);
  }

  // C1120 -- types of DO variables must be INTEGER, extended by allowing
  // REAL and DOUBLE PRECISION.  Checks the variable, the lower and upper
  // bounds, and the step if present; a zero step gets a warning when
  // ZeroDoStep warnings are enabled.
  void CheckDoNormal(const parser::DoConstruct &doConstruct) {
    const Bounds &bounds{GetBounds(doConstruct)};
    CheckDoVariable(bounds.name);
    CheckDoExpression(bounds.lower);
    CheckDoExpression(bounds.upper);
    if (!bounds.step) {
      return;
    }
    const auto &step{*bounds.step};
    CheckDoExpression(step);
    if (IsZero(step) &&
        context_.ShouldWarn(common::UsageWarning::ZeroDoStep)) {
      context_.Say(step.thing.value().source,
          "DO step expression should not be zero"_warn_en_US);
    }
  }

  // Runs all DO CONCURRENT checks on a construct, in this order: the body
  // constraints, the label constraints (using the labels collected by the
  // body walk), the loop control, and the locality-specs.
  void CheckDoConcurrent(const parser::DoConstruct &doConstruct) {
    auto &doStmt{
        std::get<parser::Statement<parser::NonLabelDoStmt>>(doConstruct.t)};
    // Later checks in this class locate scopes and messages via this position.
    currentStatementSourcePosition_ = doStmt.source;

    const parser::Block &block{std::get<parser::Block>(doConstruct.t)};
    DoConcurrentBodyEnforce doConcurrentBodyEnforce{context_, doStmt.source};
    parser::Walk(block, doConcurrentBodyEnforce);

    // Second walk over the body; it needs the labels gathered by the first.
    LabelEnforce doConcurrentLabelEnforce{context_,
        doConcurrentBodyEnforce.labels(), currentStatementSourcePosition_,
        "DO CONCURRENT"};
    parser::Walk(block, doConcurrentLabelEnforce);

    // The loop control is dereferenced without a check; this function is
    // called only when IsDoConcurrent() holds (see Check).
    const auto &loopControl{doConstruct.GetLoopControl()};
    CheckConcurrentLoopControl(*loopControl);
    CheckLocalitySpecs(*loopControl, block);
  }

  // Collects the symbols named in the LOCAL locality-specs.  Each name is
  // looked up in the parent of the scope found at the current statement
  // position, so that the local versions are not picked up, and is then
  // resolved through its associations to the root symbol.  Names that are
  // not found are skipped.
  UnorderedSymbolSet GatherLocals(
      const std::list<parser::LocalitySpec> &localitySpecs) const {
    const Scope &enclosing{
        context_.FindScope(currentStatementSourcePosition_).parent()};
    UnorderedSymbolSet locals;
    for (const auto &spec : localitySpecs) {
      const auto *local{std::get_if<parser::LocalitySpec::Local>(&spec.u)};
      if (!local) {
        continue; // not a LOCAL locality-spec
      }
      for (const parser::Name &name : local->v) {
        if (const Symbol *found{enclosing.FindSymbol(name.source)}) {
          locals.insert(ResolveAssociations(*found));
        }
      }
    }
    return locals;
  }

  // Collects the symbols of a parsed expression's typed form, each resolved
  // through its associations.  The result is empty when there is no typed
  // form.
  UnorderedSymbolSet GatherSymbolsFromExpression(
      const parser::Expr &expression) const {
    UnorderedSymbolSet collected;
    const auto *typed{GetExpr(context_, expression)};
    if (!typed) {
      return collected;
    }
    for (const Symbol &symbol : evaluate::CollectSymbols(*typed)) {
      collected.insert(ResolveAssociations(symbol));
    }
    return collected;
  }

  // C1121 - procedures in mask must be pure.  Only the first impure
  // procedure, in source-position order, is reported.
  void CheckMaskIsPure(const parser::ScalarLogicalExpr &mask) const {
    const UnorderedSymbolSet referenced{
        GatherSymbolsFromExpression(mask.thing.thing.value())};
    for (const Symbol &symbol : OrderBySourcePosition(referenced)) {
      if (!IsProcedure(symbol) || IsPureProcedure(symbol)) {
        continue;
      }
      context_.SayWithDecl(symbol, parser::Unwrap<parser::Expr>(mask)->source,
          "%s mask expression may not reference impure procedure '%s'"_err_en_US,
          LoopKindName(), symbol.name());
      return;
    }
  }

  // Reports, at `refPosition`, the first symbol of `refs` (in source-position
  // order) that also appears in `uses`.  `errorMessage` is formatted with the
  // loop kind name and the symbol name, in that order.
  void CheckNoCollisions(const UnorderedSymbolSet &refs,
      const UnorderedSymbolSet &uses, parser::MessageFixedText &&errorMessage,
      const parser::CharBlock &refPosition) const {
    for (const Symbol &candidate : OrderBySourcePosition(refs)) {
      if (uses.find(candidate) == uses.end()) {
        continue;
      }
      context_.SayWithDecl(candidate, refPosition, std::move(errorMessage),
          LoopKindName(), candidate.name());
      return;
    }
  }

  // Reports a limit expression that references one of the index variables.
  void HasNoReferences(const UnorderedSymbolSet &indexNames,
      const parser::ScalarIntExpr &expr) const {
    const parser::Expr &limit{expr.thing.thing.value()};
    CheckNoCollisions(GatherSymbolsFromExpression(limit), indexNames,
        "%s limit expression may not reference index variable '%s'"_err_en_US,
        limit.source);
  }

  // C1129, names in local locality-specs can't be in mask expressions
  void CheckMaskDoesNotReferenceLocal(const parser::ScalarLogicalExpr &mask,
      const UnorderedSymbolSet &localVars) const {
    const parser::Expr &maskExpr{mask.thing.thing.value()};
    CheckNoCollisions(GatherSymbolsFromExpression(maskExpr), localVars,
        "%s mask expression references variable '%s' in LOCAL locality-spec"_err_en_US,
        maskExpr.source);
  }

  // C1129, names in local locality-specs can't be in limit or step
  // expressions
  void CheckExprDoesNotReferenceLocal(const parser::ScalarIntExpr &expr,
      const UnorderedSymbolSet &localVars) const {
    const parser::Expr &parsed{expr.thing.thing.value()};
    CheckNoCollisions(GatherSymbolsFromExpression(parsed), localVars,
        "%s expression references variable '%s' in LOCAL locality-spec"_err_en_US,
        parsed.source);
  }

  // C1130, DEFAULT(NONE) locality requires names to be in locality-specs to
  // be used in the body of the DO loop.  A repeated DEFAULT(NONE) gets one
  // portability message (when BenignRedundancy warnings are enabled).  The
  // body is walked only when DEFAULT(NONE) is present.
  void CheckDefaultNoneImpliesExplicitLocality(
      const std::list<parser::LocalitySpec> &localitySpecs,
      const parser::Block &block) const {
    bool sawDefaultNone{false};
    for (const auto &spec : localitySpecs) {
      if (!std::holds_alternative<parser::LocalitySpec::DefaultNone>(spec.u)) {
        continue;
      }
      if (sawDefaultNone) {
        // F'2023 C1129, you can only have one DEFAULT(NONE)
        if (context_.ShouldWarn(common::LanguageFeature::BenignRedundancy)) {
          context_.Say(currentStatementSourcePosition_,
              "Only one DEFAULT(NONE) may appear"_port_en_US);
        }
        break;
      }
      sawDefaultNone = true;
    }
    if (!sawDefaultNone) {
      return;
    }
    DoConcurrentVariableEnforce variableEnforce{
        context_, currentStatementSourcePosition_};
    parser::Walk(block, variableEnforce);
  }

  // Checks every name of a REDUCE locality-spec against its reduction
  // operator.  A name without a symbol or without a type is reported as an
  // invalid identifier; a name whose type does not suit the operator is
  // reported as a type mismatch.  Both messages are placed at the current
  // statement position.
  void CheckReduce(const parser::LocalitySpec::Reduce &reduce) const {
    const parser::ReductionOperator &reductionOperator{
        std::get<parser::ReductionOperator>(reduce.t)};
    // F'2023 C1132, reduction variables should have suitable intrinsic type
    for (const parser::Name &x : std::get<std::list<parser::Name>>(reduce.t)) {
      bool supportedIdentifier{false};
      if (x.symbol && x.symbol->GetType()) {
        const auto *type{x.symbol->GetType()};
        // `suitable_types` is spliced between the quotes of ('%s') in the
        // message, so a list of types carries its own inner quotes.
        auto typeMismatch{[&](const char *suitable_types) {
          context_.Say(currentStatementSourcePosition_,
              "Reduction variable '%s' ('%s') does not have a suitable type ('%s')."_err_en_US,
              x.symbol->name(), type->AsFortran(), suitable_types);
        }};
        supportedIdentifier = true;
        switch (reductionOperator.v) {
        // Arithmetic operators: any numeric type
        case parser::ReductionOperator::Operator::Plus:
        case parser::ReductionOperator::Operator::Multiply:
          if (!(type->IsNumeric(TypeCategory::Complex) ||
                  type->IsNumeric(TypeCategory::Integer) ||
                  type->IsNumeric(TypeCategory::Real))) {
            typeMismatch("COMPLEX', 'INTEGER', or 'REAL");
          }
          break;
        // Logical operators: LOGICAL only
        case parser::ReductionOperator::Operator::And:
        case parser::ReductionOperator::Operator::Or:
        case parser::ReductionOperator::Operator::Eqv:
        case parser::ReductionOperator::Operator::Neqv:
          if (type->category() != DeclTypeSpec::Category::Logical) {
            typeMismatch("LOGICAL");
          }
          break;
        // MAX and MIN: INTEGER or REAL
        case parser::ReductionOperator::Operator::Max:
        case parser::ReductionOperator::Operator::Min:
          if (!(type->IsNumeric(TypeCategory::Integer) ||
                  type->IsNumeric(TypeCategory::Real))) {
            typeMismatch("INTEGER', or 'REAL");
          }
          break;
        // Bitwise intrinsics: INTEGER only
        case parser::ReductionOperator::Operator::Iand:
        case parser::ReductionOperator::Operator::Ior:
        case parser::ReductionOperator::Operator::Ieor:
          if (!type->IsNumeric(TypeCategory::Integer)) {
            typeMismatch("INTEGER");
          }
          break;
        }
      }
      if (!supportedIdentifier) {
        context_.Say(currentStatementSourcePosition_,
            "Invalid identifier in REDUCE clause."_err_en_US);
      }
    }
  }

  // Checks a concurrent-header (shared by DO CONCURRENT and FORALL):
  //  - C1121: the mask, if any, may reference only pure procedures;
  //  - when this context is nested, the elements 1, 2 and (if present) 3 of
  //    each control may not reference impure procedures;
  //  - C1123, concurrent limit or step expressions can't reference
  //    index-names;
  //  - a step expression may not be zero.
  // The last two checks are skipped when no index-name has a symbol.
  void CheckConcurrentHeader(const parser::ConcurrentHeader &header) const {
    if (const auto &mask{
            std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) {
      CheckMaskIsPure(*mask);
    }
    const auto &controls{
        std::get<std::list<parser::ConcurrentControl>>(header.t)};
    // First pass: gather all index-names, so that the second pass can detect
    // a reference to any of them from any control.
    UnorderedSymbolSet indexNames;
    for (const parser::ConcurrentControl &control : controls) {
      const auto &indexName{std::get<parser::Name>(control.t)};
      if (indexName.symbol) {
        indexNames.insert(*indexName.symbol);
      }
      if (isNested_) {
        CheckForImpureCall(std::get<1>(control.t));
        CheckForImpureCall(std::get<2>(control.t));
        if (const auto &stride{std::get<3>(control.t)}) {
          CheckForImpureCall(*stride);
        }
      }
    }
    if (!indexNames.empty()) {
      for (const parser::ConcurrentControl &control : controls) {
        HasNoReferences(indexNames, std::get<1>(control.t));
        HasNoReferences(indexNames, std::get<2>(control.t));
        if (const auto &intExpr{
                std::get<std::optional<parser::ScalarIntExpr>>(control.t)}) {
          const parser::Expr &expr{intExpr->thing.thing.value()};
          CheckNoCollisions(GatherSymbolsFromExpression(expr), indexNames,
              "%s step expression may not reference index variable '%s'"_err_en_US,
              expr.source);
          if (IsZero(expr)) {
            context_.Say(expr.source,
                "%s step expression may not be zero"_err_en_US, LoopKindName());
          }
        }
      }
    }
  }

  // Checks the locality-specs of a DO CONCURRENT loop control; nothing is
  // done when there are none.  In order: limit and step expressions and
  // then the mask may not reference LOCAL variables, each REDUCE spec is
  // type-checked, and DEFAULT(NONE) is enforced over the loop body.
  void CheckLocalitySpecs(
      const parser::LoopControl &control, const parser::Block &block) const {
    const auto &concurrent{
        std::get<parser::LoopControl::Concurrent>(control.u)};
    const auto &localitySpecs{
        std::get<std::list<parser::LocalitySpec>>(concurrent.t)};
    if (localitySpecs.empty()) {
      return;
    }
    const UnorderedSymbolSet localVars{GatherLocals(localitySpecs)};
    for (const auto &concurrentControl : GetControls(control)) {
      CheckExprDoesNotReferenceLocal(
          std::get<1>(concurrentControl.t), localVars);
      CheckExprDoesNotReferenceLocal(
          std::get<2>(concurrentControl.t), localVars);
      const auto &step{std::get<std::optional<parser::ScalarIntExpr>>(
          concurrentControl.t)};
      if (step) {
        CheckExprDoesNotReferenceLocal(*step, localVars);
      }
    }
    const auto &header{std::get<parser::ConcurrentHeader>(concurrent.t)};
    const auto &mask{
        std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)};
    if (mask) {
      CheckMaskDoesNotReferenceLocal(*mask, localVars);
    }
    for (const auto &spec : localitySpecs) {
      const auto *reduce{std::get_if<parser::LocalitySpec::Reduce>(&spec.u)};
      if (reduce) {
        CheckReduce(*reduce);
      }
    }
    CheckDefaultNoneImpliesExplicitLocality(localitySpecs, block);
  }

  // Check constraints [C1121 .. C1130] by checking the concurrent-header of
  // a DO CONCURRENT loop control.
  void CheckConcurrentLoopControl(const parser::LoopControl &control) const {
    const parser::ConcurrentHeader &header{GetConcurrentHeader(control)};
    CheckConcurrentHeader(header);
  }

  // Reports an impure procedure referenced by `x`.  No source location is
  // passed to Say, so the message goes wherever the context's Say puts it
  // by default.
  template <typename T> void CheckForImpureCall(const T &x) const {
    auto impure{FindImpureCall(context_.foldingContext(), x)};
    if (!impure) {
      return;
    }
    context_.Say(
        "Impure procedure '%s' may not be referenced in a %s"_err_en_US,
        *impure, LoopKindName());
  }
  // Overload for a parsed scalar integer expression.  The context's location
  // is set to the expression's source while its typed form (if any) is
  // checked, and is restored afterwards.
  void CheckForImpureCall(const parser::ScalarIntExpr &x) const {
    const auto &parsedExpr{x.thing.thing.value()};
    // Save the location so that it can be restored on every path below.
    auto oldLocation{context_.location()};
    context_.set_location(parsedExpr.source);
    if (const auto &typedExpr{parsedExpr.typedExpr}) {
      if (const auto &expr{typedExpr->v}) {
        CheckForImpureCall(*expr);
      }
    }
    context_.set_location(oldLocation);
  }

  // Each index should be used on the LHS of each assignment in a FORALL.
  // The symbols considered are those of the left-hand side plus, for a
  // pointer assignment, those of its bounds or bounds-remapping expressions.
  // A warning is emitted for each active FORALL index that is not among
  // them, when UnusedForallIndex warnings are enabled.
  void CheckForallIndexesUsed(const evaluate::Assignment &assignment) {
    SymbolVector indexVars{context_.GetIndexVars(IndexVarKind::FORALL)};
    if (!indexVars.empty()) {
      UnorderedSymbolSet symbols{evaluate::CollectSymbols(assignment.lhs)};
      common::visit(
          common::visitors{
              [&](const evaluate::Assignment::BoundsSpec &spec) {
                for (const auto &bound : spec) {
// TODO: this is working around missing std::set::merge in some versions of
// clang that we are building with
#ifdef __clang__
                  auto boundSymbols{evaluate::CollectSymbols(bound)};
                  symbols.insert(boundSymbols.begin(), boundSymbols.end());
#else
                  symbols.merge(evaluate::CollectSymbols(bound));
#endif
                }
              },
              [&](const evaluate::Assignment::BoundsRemapping &remapping) {
                for (const auto &bounds : remapping) {
#ifdef __clang__
                  auto lbSymbols{evaluate::CollectSymbols(bounds.first)};
                  symbols.insert(lbSymbols.begin(), lbSymbols.end());
                  auto ubSymbols{evaluate::CollectSymbols(bounds.second)};
                  symbols.insert(ubSymbols.begin(), ubSymbols.end());
#else
                  symbols.merge(evaluate::CollectSymbols(bounds.first));
                  symbols.merge(evaluate::CollectSymbols(bounds.second));
#endif
                }
              },
              // Other kinds of assignment contribute no extra symbols.
              [](const auto &) {},
          },
          assignment.u);
      for (const Symbol &index : indexVars) {
        if (symbols.count(index) == 0 &&
            context_.ShouldWarn(common::UsageWarning::UnusedForallIndex)) {
          context_.Say("FORALL index variable '%s' not used on left-hand side"
                       " of assignment"_warn_en_US,
              index.name());
        }
      }
    }
  }

  // Name of the construct kind for use in messages.  A DO kind is spelled
  // out as "DO CONCURRENT" to make explicit which form of DO is meant.
  const char *LoopKindName() const {
    if (kind_ == IndexVarKind::DO) {
      return "DO CONCURRENT";
    }
    return "FORALL";
  }

  // Semantics context through which messages are emitted and index variables
  // are queried.
  SemanticsContext &context_;
  // Kind of construct being checked; LoopKindName() maps it to the text
  // used in messages.
  const IndexVarKind kind_;
  // NOTE(review): not referenced in this part of the file; presumably the
  // source position of the statement currently being checked -- confirm.
  parser::CharBlock currentStatementSourcePosition_;
  // NOTE(review): presumably set from the constructor's nesting argument;
  // its uses are outside this part of the file -- confirm.
  bool isNested_{false};
}; // class DoContext

// On entry to a DO construct, have a DoContext define its DO variables.
void DoForallChecker::Enter(const parser::DoConstruct &doConstruct) {
  const bool nested{constructNesting_ > 0};
  DoContext loopContext{context_, IndexVarKind::DO, nested};
  loopContext.DefineDoVariables(doConstruct);
}

// On leaving a DO construct, run the DO checks and then reset the DO
// variables.  The nesting counter is raised only while those two calls run;
// the DoContext itself is built from the counter's value before the bump.
void DoForallChecker::Leave(const parser::DoConstruct &doConstruct) {
  const bool nested{constructNesting_ > 0};
  DoContext loopContext{context_, IndexVarKind::DO, nested};
  ++constructNesting_;
  loopContext.Check(doConstruct);
  loopContext.ResetDoVariables(doConstruct);
  --constructNesting_;
}

// On entry to a FORALL construct: activate its index variables first, then
// bump the nesting counter and check the construct.
void DoForallChecker::Enter(const parser::ForallConstruct &construct) {
  const bool nested{constructNesting_ > 0};
  DoContext forallContext{context_, IndexVarKind::FORALL, nested};
  forallContext.ActivateIndexVars(GetControls(construct));
  ++constructNesting_;
  forallContext.Check(construct);
}
// On leaving a FORALL construct: deactivate its index variables and undo the
// nesting bump made on entry.
void DoForallChecker::Leave(const parser::ForallConstruct &construct) {
  const bool nested{constructNesting_ > 0};
  DoContext forallContext{context_, IndexVarKind::FORALL, nested};
  forallContext.DeactivateIndexVars(GetControls(construct));
  --constructNesting_;
}

// On entry to a FORALL statement: bump the nesting counter, check the
// statement, and only then activate its index variables.
void DoForallChecker::Enter(const parser::ForallStmt &stmt) {
  const bool nested{constructNesting_ > 0};
  DoContext forallContext{context_, IndexVarKind::FORALL, nested};
  ++constructNesting_;
  forallContext.Check(stmt);
  forallContext.ActivateIndexVars(GetControls(stmt));
}
// On leaving a FORALL statement: deactivate its index variables and undo the
// nesting bump made on entry.
void DoForallChecker::Leave(const parser::ForallStmt &stmt) {
  const bool nested{constructNesting_ > 0};
  DoContext forallContext{context_, IndexVarKind::FORALL, nested};
  forallContext.DeactivateIndexVars(GetControls(stmt));
  --constructNesting_;
}
// After a FORALL assignment statement has been visited, run the FORALL
// checks on it.
void DoForallChecker::Leave(const parser::ForallAssignmentStmt &stmt) {
  const bool nested{constructNesting_ > 0};
  DoContext forallContext{context_, IndexVarKind::FORALL, nested};
  forallContext.Check(stmt);
}

// Source position of a construct: the source of the first element of its
// tuple member `t`.
template <typename A>
static parser::CharBlock GetConstructPosition(const A &a) {
  const auto &firstPart{std::get<0>(a.t)};
  return firstPart.source;
}

// Source position of whichever construct the node currently holds.
static parser::CharBlock GetNodePosition(const ConstructNode &construct) {
  const auto positionOf{
      [](const auto &node) { return GetConstructPosition(*node); }};
  return common::visit(positionOf, construct);
}

// Emit the error for a statement that would leave the named kind of
// enclosing construct, and attach the position of that construct to it.
void DoForallChecker::SayBadLeave(StmtType stmtType,
    const char *enclosingStmtName, const ConstructNode &construct) const {
  auto &&message{context_.Say("%s must not leave a %s statement"_err_en_US,
      EnumToString(stmtType), enclosingStmtName)};
  message.Attach(
      GetNodePosition(construct), "The construct that was left"_en_US);
}

// The DO construct held by the node, or nullptr when the node holds some
// other kind of construct.
static const parser::DoConstruct *MaybeGetDoConstruct(
    const ConstructNode &construct) {
  const auto *held{std::get_if<const parser::DoConstruct *>(&construct)};
  return held ? *held : nullptr;
}

static bool ConstructIsDoConcurrent(const ConstructNode &construct) {
  const parser::DoConstruct *doConstruct{MaybeGetDoConstruct(construct)};
  return doConstruct && doConstruct->IsDoConcurrent();
}

// Diagnose a CYCLE or EXIT statement whose flow of control would leave the
// given construct when that construct is a DO CONCURRENT, a CRITICAL, or a
// CHANGE TEAM construct.  Other kinds of construct are ignored.
void DoForallChecker::CheckForBadLeave(
    StmtType stmtType, const ConstructNode &construct) const {
  // Common reporting path; only the name of the enclosing construct differs.
  const auto complain{[&](const char *enclosingName) {
    SayBadLeave(stmtType, enclosingName, construct);
  }};
  common::visit(
      common::visitors{
          [&](const parser::DoConstruct *loop) {
            // C1135 and C1167 -- CYCLE and EXIT statements can't leave a
            // DO CONCURRENT
            if (loop->IsDoConcurrent()) {
              complain("DO CONCURRENT");
            }
          },
          [&](const parser::CriticalConstruct *) {
            // C1135 and C1168 -- similarly, for CRITICAL
            complain("CRITICAL");
          },
          [&](const parser::ChangeTeamConstruct *) {
            // C1135 and C1168 -- similarly, for CHANGE TEAM
            complain("CHANGE TEAM");
          },
          [](const auto *) {},
      },
      construct);
}

// Does a CYCLE/EXIT statement (with optional construct name) refer to this
// construct?  A statement without a name matches any DO construct.  A named
// statement matches a construct with the same name, provided the statement
// is an EXIT or the construct is a DO construct.
static bool StmtMatchesConstruct(const parser::Name *stmtName,
    StmtType stmtType, const std::optional<parser::Name> &constructName,
    const ConstructNode &construct) {
  const bool isDo{MaybeGetDoConstruct(construct) != nullptr};
  if (!stmtName) {
    return isDo;
  }
  const bool sameName{
      constructName && constructName->source == stmtName->source};
  if (!sameName) {
    return false;
  }
  return isDo || stmtType == StmtType::EXIT;
}

// C1167 -- an EXIT statement may not belong to a DO CONCURRENT construct.
// Statements other than EXIT are not diagnosed here.
void DoForallChecker::CheckDoConcurrentExit(
    StmtType stmtType, const ConstructNode &construct) const {
  if (stmtType != StmtType::EXIT) {
    return;
  }
  if (ConstructIsDoConcurrent(construct)) {
    SayBadLeave(StmtType::EXIT, "DO CONCURRENT", construct);
  }
}

// Check nesting violations for a CYCLE or EXIT statement.  Walk up the
// nesting levels, from the end of the construct stack toward its beginning,
// looking for a construct that matches the CYCLE or EXIT statement.  Every
// construct passed over on the way is checked for a violation.  If a match
// is found, the check is complete; otherwise an error reports that there is
// no matching construct.
//
// stmtType: whether the statement is a CYCLE or an EXIT.
// stmtName: the construct name on the statement, or nullptr if it has none.
void DoForallChecker::CheckNesting(
    StmtType stmtType, const parser::Name *stmtName) const {
  const ConstructStack &stack{context_.constructStack()};
  // Reverse iterators visit the stack back to front without ever moving an
  // iterator in front of begin(), which also covers an empty stack.
  for (auto iter{stack.crbegin()}; iter != stack.crend(); ++iter) {
    const ConstructNode &construct{*iter};
    const std::optional<parser::Name> &constructName{
        MaybeGetNodeName(construct)};
    if (StmtMatchesConstruct(stmtName, stmtType, constructName, construct)) {
      CheckDoConcurrentExit(stmtType, construct);
      return; // We got a match, so we're finished checking
    }
    CheckForBadLeave(stmtType, construct);
  }

  // We haven't found a match in the enclosing constructs
  if (stmtType == StmtType::EXIT) {
    context_.Say("No matching construct for EXIT statement"_err_en_US);
  } else {
    context_.Say("No matching DO construct for CYCLE statement"_err_en_US);
  }
}

// C1135 -- check the nesting of a CYCLE statement, passing along its
// optional construct name.
void DoForallChecker::Enter(const parser::CycleStmt &cycleStmt) {
  const parser::Name *constructName{
      common::GetPtrFromOptional(cycleStmt.v)};
  CheckNesting(StmtType::CYCLE, constructName);
}

// C1167 and C1168 -- check the nesting of an EXIT statement, passing along
// its optional construct name.
void DoForallChecker::Enter(const parser::ExitStmt &exitStmt) {
  const parser::Name *constructName{common::GetPtrFromOptional(exitStmt.v)};
  CheckNesting(StmtType::EXIT, constructName);
}

// The variable assigned to by an assignment statement is checked against
// redefinition of an index variable.
void DoForallChecker::Leave(const parser::AssignmentStmt &stmt) {
  context_.CheckIndexVarRedefine(std::get<parser::Variable>(stmt.t));
}

// If the actual argument is a whole symbol associated with a dummy argument
// whose intent is OUT or INOUT, hand it to the context's index-variable
// redefinition handling: CheckIndexVarRedefine() for INTENT(OUT),
// WarnIndexVarRedefine() for INTENT(INOUT).  Other arguments are ignored.
static void CheckIfArgIsDoVar(const evaluate::ActualArgument &arg,
    const parser::CharBlock location, SemanticsContext &context) {
  const common::Intent intent{arg.dummyIntent()};
  const bool isOut{intent == common::Intent::Out};
  if (!isOut && intent != common::Intent::InOut) {
    return;
  }
  const SomeExpr *argExpr{arg.UnwrapExpr()};
  if (!argExpr) {
    return;
  }
  const Symbol *var{evaluate::UnwrapWholeSymbolDataRef(*argExpr)};
  if (!var) {
    return;
  }
  if (isOut) {
    context.CheckIndexVarRedefine(location, *var);
  } else {
    context.WarnIndexVarRedefine(location, *var); // INTENT(INOUT)
  }
}

// Check whether a DO variable is passed as an actual argument to a dummy
// argument whose intent is OUT or INOUT.  The intents come from the checked
// arguments of the "typedCall" field of the CallStmt, while the source
// locations used in messages come from the parser::Expr forms of the actual
// arguments, so both argument lists are walked in step.  The walk stops as
// soon as either list runs out.
void DoForallChecker::Leave(const parser::CallStmt &callStmt) {
  const auto &typedCall{callStmt.typedCall};
  if (!typedCall) {
    return;
  }
  const auto &parsedArgs{
      std::get<std::list<parser::ActualArgSpec>>(callStmt.call.t)};
  const evaluate::ActualArguments &checkedArgs{typedCall->arguments()};
  auto parsedIter{parsedArgs.begin()};
  for (auto checkedIter{checkedArgs.begin()};
       checkedIter != checkedArgs.end() && parsedIter != parsedArgs.end();
       ++checkedIter, ++parsedIter) {
    const auto &checked{*checkedIter};
    if (!checked) {
      continue; // No checked argument in this position.
    }
    const auto &parsedArg{std::get<parser::ActualArg>(parsedIter->t)};
    // Only parsed arguments that are expressions supply a source location.
    if (const auto *parsedExpr{
            std::get_if<common::Indirection<parser::Expr>>(&parsedArg.u)}) {
      CheckIfArgIsDoVar(*checked, parsedExpr->value().source, context_);
    }
  }
}

// A NEWUNIT= specifier's variable is checked against redefinition of an
// index variable; other connect specifiers are ignored.
void DoForallChecker::Leave(const parser::ConnectSpec &connectSpec) {
  if (const auto *newunit{
          std::get_if<parser::ConnectSpec::Newunit>(&connectSpec.u)}) {
    const auto &unitVariable{newunit->v.thing.thing};
    context_.CheckIndexVarRedefine(unitVariable);
  }
}

// Set of references to actual arguments.  Its ordering relies on the
// operator< for ActualArgumentRef defined near the top of this file, which
// compares the addresses of the referenced arguments.
using ActualArgumentSet = std::set<evaluate::ActualArgumentRef>;

// Traversal helper, built on evaluate::SetTraverse, that gathers actual
// arguments into an ActualArgumentSet.
struct CollectActualArgumentsHelper
    : public evaluate::SetTraverse<CollectActualArgumentsHelper,
          ActualArgumentSet> {
  using Base = SetTraverse<CollectActualArgumentsHelper, ActualArgumentSet>;
  CollectActualArgumentsHelper() : Base{*this} {}
  // Keep the base class's operator() overloads visible alongside the one
  // declared below.
  using Base::operator();
  // For an actual argument, the result combines a set holding the argument
  // itself with whatever a fresh helper collects from the argument's
  // unwrapped expression.
  ActualArgumentSet operator()(const evaluate::ActualArgument &arg) const {
    return Combine(ActualArgumentSet{arg},
        CollectActualArgumentsHelper{}(arg.UnwrapExpr()));
  }
};

// Collect the actual arguments found in x by running a
// CollectActualArgumentsHelper over it.
template <typename A> ActualArgumentSet CollectActualArguments(const A &x) {
  const CollectActualArgumentsHelper collector;
  return collector(x);
}

// Explicit instantiation for whole expressions.
template ActualArgumentSet CollectActualArguments(const SomeExpr &);

// Track how deeply parsed expressions are nested so that the matching Leave
// can tell when a top-level expression has been completed.
void DoForallChecker::Enter(const parser::Expr &parsedExpr) {
  ++exprDepth_;
}

// When a top-level parsed expression has been completely visited, look at
// every actual argument collected from its typed form and check whether it
// is an index variable passed to an INTENT(OUT) or INTENT(INOUT) dummy.
// Messages use the source of the whole top-level expression.
void DoForallChecker::Leave(const parser::Expr &parsedExpr) {
  CHECK(exprDepth_ > 0);
  --exprDepth_;
  if (exprDepth_ != 0) {
    return; // Only check top level expressions
  }
  const SomeExpr *typed{GetExpr(context_, parsedExpr)};
  if (!typed) {
    return;
  }
  const ActualArgumentSet actuals{CollectActualArguments(*typed)};
  for (const evaluate::ActualArgumentRef &actual : actuals) {
    CheckIfArgIsDoVar(*actual, parsedExpr.source, context_);
  }
}

// For an inquire specifier of the IntVar form, its scalar integer variable
// is checked against redefinition of an index variable.
void DoForallChecker::Leave(const parser::InquireSpec &inquireSpec) {
  if (const auto *intVar{
          std::get_if<parser::InquireSpec::IntVar>(&inquireSpec.u)}) {
    const auto &scalar{std::get<parser::ScalarIntVariable>(intVar->t)};
    context_.CheckIndexVarRedefine(scalar.thing.thing);
  }
}

// For a SIZE= I/O control specifier, its variable is checked against
// redefinition of an index variable.
void DoForallChecker::Leave(const parser::IoControlSpec &ioControlSpec) {
  if (const auto *size{
          std::get_if<parser::IoControlSpec::Size>(&ioControlSpec.u)}) {
    const auto &sizeVariable{size->v.thing.thing};
    context_.CheckIndexVarRedefine(sizeVariable);
  }
}

// The control variable of an output implied DO is checked against
// redefinition of an index variable.  The name's symbol pointer is tested
// before it is dereferenced, so a name without a symbol is skipped rather
// than dereferencing a null pointer.
void DoForallChecker::Leave(const parser::OutputImpliedDo &outputImpliedDo) {
  const auto &control{std::get<parser::IoImpliedDoControl>(outputImpliedDo.t)};
  const parser::Name &name{control.name.thing.thing};
  if (const Symbol * symbol{name.symbol}) {
    context_.CheckIndexVarRedefine(name.source, *symbol);
  }
}

// A STAT= variable is checked against redefinition of an index variable.
void DoForallChecker::Leave(const parser::StatVariable &statVariable) {
  const auto &variable{statVariable.v.thing.thing};
  context_.CheckIndexVarRedefine(variable);
}

} // namespace Fortran::semantics