//===-- lib/Semantics/check-case.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "check-case.h"
#include "flang/Common/idioms.h"
#include "flang/Common/reference.h"
#include "flang/Common/template.h"
#include "flang/Evaluate/fold.h"
#include "flang/Evaluate/type.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/tools.h"
#include <tuple>
namespace Fortran::semantics {
// Collects the bounds of every CASE statement in one SELECT CASE construct,
// represented as scalar constants of type T, and diagnoses incompatible
// values (C1147), LOGICAL ranges (C1148), and overlapping cases (C1149).
template <typename T> class CaseValues {
public:
  // c receives every diagnostic; t is the type of the SELECT CASE expression.
  // Only a reference to t is kept, so it must outlive this object.
  CaseValues(SemanticsContext &c, const evaluate::DynamicType &t)
      : context_{c}, caseExprType_{t} {}

  // Gathers the bounds of every CASE in the construct and then, provided no
  // CASE value was rejected, verifies that no two cases overlap.
  void Check(const std::list<parser::CaseConstruct::Case> &cases) {
    for (const parser::CaseConstruct::Case &c : cases) {
      AddCase(c);
    }
    // The overlap check is meaningless once any value failed to evaluate,
    // so it is skipped to avoid cascading diagnostics.
    if (!hasErrors_) {
      cases_.sort(Comparator{});
      if (!AreCasesDisjoint()) { // C1149
        ReportConflictingCases();
      }
    }
  }

private:
  using Value = evaluate::Scalar<T>;

  // Records one entry in cases_ per value or range of a CASE statement.
  // CASE DEFAULT is recorded at the front with neither bound set.
  void AddCase(const parser::CaseConstruct::Case &c) {
    const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)};
    const parser::CaseStmt &caseStmt{stmt.statement};
    const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)};
    common::visit(
        common::visitors{
            [&](const std::list<parser::CaseValueRange> &ranges) {
              for (const auto &range : ranges) {
                auto pair{ComputeBounds(range)};
                if (pair.first && pair.second && *pair.first > *pair.second) {
                  // An empty range can never match, so it is only warned
                  // about and is not recorded for the overlap check.
                  if (context_.ShouldWarn(common::UsageWarning::EmptyCase)) {
                    context_.Say(stmt.source,
                        "CASE has lower bound greater than upper bound"_warn_en_US);
                  }
                } else {
                  if constexpr (T::category == TypeCategory::Logical) { // C1148
                    // A LOGICAL case is acceptable only when both bounds
                    // are present and equal, i.e., it is a single value.
                    if ((pair.first || pair.second) &&
                        (!pair.first || !pair.second ||
                            *pair.first != *pair.second)) {
                      context_.Say(stmt.source,
                          "CASE range is not allowed for LOGICAL"_err_en_US);
                    }
                  }
                  cases_.emplace_back(stmt);
                  cases_.back().lower = std::move(pair.first);
                  cases_.back().upper = std::move(pair.second);
                }
              }
            },
            [&](const parser::Default &) { cases_.emplace_front(stmt); },
        },
        selector.u);
  }

  // Folds one CASE value to a scalar constant of type T.  On success the
  // typed expression attached to the parse tree is replaced by the converted
  // constant.  On any failure, hasErrors_ is set and std::nullopt is
  // returned.
  std::optional<Value> GetValue(const parser::CaseValue &caseValue) {
    const parser::Expr &expr{caseValue.thing.thing.value()};
    auto *x{expr.typedExpr.get()};
    if (!x || !x->v) {
      // No analyzed expression is attached to this CASE value (presumably
      // because expression analysis already failed -- NOTE(review): confirm).
      // Flag the construct as erroneous: otherwise the value would be
      // recorded with no bounds, look like CASE DEFAULT, and could provoke
      // a bogus "conflicts with previous cases" message.
      hasErrors_ = true;
      return std::nullopt;
    }
    auto type{x->v->GetType()};
    // C1147: the category must match that of the SELECT CASE expression;
    // for CHARACTER, the kind must match as well.
    if (type && type->category() == caseExprType_.category() &&
        (type->category() != TypeCategory::Character ||
            type->kind() == caseExprType_.kind())) {
      parser::Messages buffer; // discarded folding messages
      parser::ContextualMessages foldingMessages{expr.source, &buffer};
      evaluate::FoldingContext foldingContext{
          context_.foldingContext(), foldingMessages};
      auto folded{evaluate::Fold(foldingContext, SomeExpr{*x->v})};
      if (auto converted{evaluate::Fold(foldingContext,
              evaluate::ConvertToType(T::GetType(), SomeExpr{folded}))}) {
        if (auto value{evaluate::GetScalarConstantValue<T>(*converted)}) {
          // Convert back to the value's own type; a round trip that does
          // not reproduce the folded value means it did not fit in T.
          auto back{evaluate::Fold(foldingContext,
              evaluate::ConvertToType(*type, SomeExpr{*converted}))};
          if (back == folded) {
            x->v = converted;
            return value;
          } else {
            if (context_.ShouldWarn(common::UsageWarning::CaseOverflow)) {
              context_.Say(expr.source,
                  "CASE value (%s) overflows type (%s) of SELECT CASE expression"_warn_en_US,
                  folded.AsFortran(), caseExprType_.AsFortran());
            }
            hasErrors_ = true;
            return std::nullopt;
          }
        }
      }
      context_.Say(expr.source,
          "CASE value (%s) must be a constant scalar"_err_en_US,
          x->v->AsFortran());
    } else {
      std::string typeStr{type ? type->AsFortran() : "typeless"s};
      context_.Say(expr.source,
          "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
          typeStr, caseExprType_.AsFortran());
    }
    hasErrors_ = true;
    return std::nullopt;
  }

  using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>;

  // Returns the (lower, upper) bounds of one case value range.  A single
  // value yields the same value for both bounds; an omitted bound of a range
  // is left empty.  Both bounds are empty when any value failed to evaluate.
  PairOfValues ComputeBounds(const parser::CaseValueRange &range) {
    return common::visit(
        common::visitors{
            [&](const parser::CaseValue &x) {
              auto value{GetValue(x)};
              return PairOfValues{value, value};
            },
            [&](const parser::CaseValueRange::Range &x) {
              std::optional<Value> lo, hi;
              if (x.lower) {
                lo = GetValue(*x.lower);
              }
              if (x.upper) {
                hi = GetValue(*x.upper);
              }
              if ((x.lower && !lo) || (x.upper && !hi)) {
                return PairOfValues{}; // error case
              }
              return PairOfValues{std::move(lo), std::move(hi)};
            },
        },
        range.u);
  }

  // One value or range of a CASE statement, or CASE DEFAULT when neither
  // bound is present.
  struct Case {
    explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {}
    bool IsDefault() const { return !lower && !upper; }
    // Formats the case for use in messages: "(v)", "(lo:hi)", "(lo:)",
    // "(:hi)", or "DEFAULT".
    std::string AsFortran() const {
      std::string result;
      {
        llvm::raw_string_ostream bs{result};
        if (lower) {
          evaluate::Constant<T>{*lower}.AsFortran(bs << '(');
          if (!upper) {
            bs << ':';
          } else if (*lower != *upper) {
            evaluate::Constant<T>{*upper}.AsFortran(bs << ':');
          }
          bs << ')';
        } else if (upper) {
          evaluate::Constant<T>{*upper}.AsFortran(bs << "(:") << ')';
        } else {
          bs << "DEFAULT";
        }
      }
      return result;
    }

    const parser::Statement<parser::CaseStmt> &stmt;
    std::optional<Value> lower, upper;
  };

  // Defines a comparator for use with std::list<>::sort().
  // Returns true if and only if the highest value in range x is less
  // than the least value in range y.  The DEFAULT case is arbitrarily
  // defined to be less than all others.  When two ranges overlap,
  // neither is less than the other.
  struct Comparator {
    bool operator()(const Case &x, const Case &y) const {
      if (x.IsDefault()) {
        return !y.IsDefault();
      } else {
        return x.upper && y.lower && *x.upper < *y.lower;
      }
    }
  };

  // Expects cases_ to have been sorted with Comparator.  Returns true when
  // every case is strictly less than its successor in the list.
  bool AreCasesDisjoint() const {
    auto endIter{cases_.end()};
    for (auto iter{cases_.begin()}; iter != endIter; ++iter) {
      auto next{iter};
      if (++next != endIter && !Comparator{}(*iter, *next)) {
        return false;
      }
    }
    return true;
  }

  // Emits one error per case that overlaps a case appearing earlier in the
  // source, attaching a note for each such earlier case.
  // This has quadratic time, but only runs in error cases
  void ReportConflictingCases() {
    for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) {
      parser::Message *msg{nullptr};
      for (auto p{cases_.begin()}; p != cases_.end(); ++p) {
        // Two cases conflict when neither orders before the other.
        if (p->stmt.source.begin() < iter->stmt.source.begin() &&
            !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) {
          if (!msg) {
            msg = &context_.Say(iter->stmt.source,
                "CASE %s conflicts with previous cases"_err_en_US,
                iter->AsFortran());
          }
          msg->Attach(
              p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran());
        }
      }
    }
  }

  SemanticsContext &context_;
  const evaluate::DynamicType &caseExprType_;
  std::list<Case> cases_; // any DEFAULT first until sorted by Check()
  bool hasErrors_{false}; // set when any CASE value could not be used
};
// Visitor for common::SearchTypes() over the kinds of category CAT: runs the
// CASE checks with the one type whose kind matches the SELECT CASE
// expression's kind.
template <TypeCategory CAT> struct TypeVisitor {
  using Result = bool;
  using Types = evaluate::CategoryTypes<CAT>;

  // Returns false when T has the wrong kind; otherwise checks the whole
  // case list using T and returns true.
  template <typename T> Result Test() {
    if (T::kind != exprType.kind()) {
      return false;
    }
    CaseValues<T> checker(context, exprType);
    checker.Check(caseList);
    return true;
  }

  SemanticsContext &context;
  const evaluate::DynamicType &exprType;
  const std::list<parser::CaseConstruct::Case> &caseList;
};
// Checks one SELECT CASE construct: dispatches on the category of the
// selector expression to validate its CASE statements, or reports an error
// when the selector is not INTEGER, LOGICAL, or CHARACTER.
void CaseChecker::Enter(const parser::CaseConstruct &construct) {
  const auto &selectStmt{
      std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)
          .statement};
  const auto &selector{
      std::get<parser::Scalar<parser::Expr>>(selectStmt.t).thing};
  const auto *analyzed{GetExpr(context_, selector)};
  if (!analyzed) {
    return; // expression semantics failed
  }
  if (auto selectorType{analyzed->GetType()}) {
    const auto &cases{
        std::get<std::list<parser::CaseConstruct::Case>>(construct.t)};
    const auto category{selectorType->category()};
    if (category == TypeCategory::Integer) {
      common::SearchTypes(
          TypeVisitor<TypeCategory::Integer>{context_, *selectorType, cases});
      return;
    }
    if (category == TypeCategory::Logical) {
      // LOGICAL values are always checked as kind 1, whatever the
      // selector's kind.
      CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{
          context_, *selectorType}
          .Check(cases);
      return;
    }
    if (category == TypeCategory::Character) {
      common::SearchTypes(TypeVisitor<TypeCategory::Character>{
          context_, *selectorType, cases});
      return;
    }
  }
  // Reached for a typeless selector or one of any other category.
  context_.Say(selector.source,
      "SELECT CASE expression must be integer, logical, or character"_err_en_US);
}
} // namespace Fortran::semantics