//===-- lib/Evaluate/fold-reduction.h -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
#include "fold-implementation.h"
namespace Fortran::evaluate {
// DOT_PRODUCT
template <typename T>
static Expr<T> FoldDotProduct(
FoldingContext &context, FunctionRef<T> &&funcRef) {
using Element = typename Constant<T>::Element;
auto args{funcRef.arguments()};
CHECK(args.size() == 2);
Folder<T> folder{context};
Constant<T> *va{folder.Folding(args[0])};
Constant<T> *vb{folder.Folding(args[1])};
if (va && vb) {
CHECK(va->Rank() == 1 && vb->Rank() == 1);
if (va->size() != vb->size()) {
context.messages().Say(
"Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US,
va->size(), vb->size());
return MakeInvalidIntrinsic(std::move(funcRef));
}
Element sum{};
bool overflow{false};
if constexpr (T::category == TypeCategory::Complex) {
std::vector<Element> conjugates;
for (const Element &x : va->values()) {
conjugates.emplace_back(x.CONJG());
}
Constant<T> conjgA{
std::move(conjugates), ConstantSubscripts{va->shape()}};
Expr<T> products{Fold(
context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
[[maybe_unused]] Element correction{};
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
auto next{correction.Add(x, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
} else {
auto added{sum.Add(x, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
sum = std::move(added.value);
}
}
} else if constexpr (T::category == TypeCategory::Logical) {
Expr<T> conjunctions{Fold(context,
Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And,
Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})};
Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))};
for (const Element &x : cConjunctions.values()) {
if (x.IsTrue()) {
sum = Element{true};
break;
}
}
} else if constexpr (T::category == TypeCategory::Integer) {
Expr<T> products{
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
for (const Element &x : cProducts.values()) {
auto next{sum.AddSigned(x)};
overflow |= next.overflow;
sum = std::move(next.value);
}
} else {
static_assert(T::category == TypeCategory::Real);
Expr<T> products{
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
[[maybe_unused]] Element correction{};
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
auto next{correction.Add(x, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
} else {
auto added{sum.Add(x, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
sum = std::move(added.value);
}
}
}
if (overflow &&
context.languageFeatures().ShouldWarn(
common::UsageWarning::FoldingException)) {
context.messages().Say(common::UsageWarning::FoldingException,
"DOT_PRODUCT of %s data overflowed during computation"_warn_en_US,
T::AsFortran());
}
return Expr<T>{Constant<T>{std::move(sum)}};
}
return Expr<T>{std::move(funcRef)};
}
// Fold and validate a DIM= argument. Returns false on error.
bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &,
ActualArguments &, std::optional<int> dimIndex, int rank);
// Fold and validate a MASK= argument. Return null on error, absent MASK=, or
// non-constant MASK=.
Constant<LogicalResult> *GetReductionMASK(
std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape,
FoldingContext &);
// Common preprocessing for reduction transformational intrinsic function
// folding. If the intrinsic can have DIM= &/or MASK= arguments, extract
// and check them. If a MASK= is present, apply it to the array data and
// substitute replacement values for elements corresponding to .FALSE. in
// the mask. If the result is present, the intrinsic call can be folded.
template <typename T> struct ArrayAndMask {
Constant<T> array;
Constant<LogicalResult> mask;
};
template <typename T>
static std::optional<ArrayAndMask<T>> ProcessReductionArgs(
FoldingContext &context, ActualArguments &arg, std::optional<int> &dim,
int arrayIndex, std::optional<int> dimIndex = std::nullopt,
std::optional<int> maskIndex = std::nullopt) {
if (arg.empty()) {
return std::nullopt;
}
Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])};
if (!folded || folded->Rank() < 1) {
return std::nullopt;
}
if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) {
return std::nullopt;
}
std::size_t n{folded->size()};
std::vector<Scalar<LogicalResult>> maskElement;
if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() &&
arg[*maskIndex]) {
if (const Constant<LogicalResult> *origMask{
GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) {
if (auto scalarMask{origMask->GetScalarValue()}) {
maskElement =
std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue());
} else {
maskElement = origMask->values();
}
} else {
return std::nullopt;
}
} else {
maskElement = std::vector<Scalar<LogicalResult>>(n, true);
}
return ArrayAndMask<T>{Constant<T>(*folded),
Constant<LogicalResult>{
std::move(maskElement), ConstantSubscripts{folded->shape()}}};
}
// Generalized reduction to an array of one dimension fewer (w/ DIM=)
// or to a scalar (w/o DIM=). The ACCUMULATOR type must define
// operator()(Scalar<T> &, const ConstantSubscripts &, bool first)
// and Done(Scalar<T> &).
template <typename T, typename ACCUMULATOR, typename ARRAY>
static Constant<T> DoReduction(const Constant<ARRAY> &array,
const Constant<LogicalResult> &mask, std::optional<int> &dim,
const Scalar<T> &identity, ACCUMULATOR &accumulator) {
ConstantSubscripts at{array.lbounds()};
ConstantSubscripts maskAt{mask.lbounds()};
std::vector<typename Constant<T>::Element> elements;
ConstantSubscripts resultShape; // empty -> scalar
if (dim) { // DIM= is present, so result is an array
resultShape = array.shape();
resultShape.erase(resultShape.begin() + (*dim - 1));
ConstantSubscript dimExtent{array.shape().at(*dim - 1)};
CHECK(dimExtent == mask.shape().at(*dim - 1));
ConstantSubscript &dimAt{at[*dim - 1]};
ConstantSubscript dimLbound{dimAt};
ConstantSubscript &maskDimAt{maskAt[*dim - 1]};
ConstantSubscript maskDimLbound{maskDimAt};
for (auto n{GetSize(resultShape)}; n-- > 0;
array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
elements.push_back(identity);
if (dimExtent > 0) {
dimAt = dimLbound;
maskDimAt = maskDimLbound;
bool firstUnmasked{true};
for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) {
if (mask.At(maskAt).IsTrue()) {
accumulator(elements.back(), at, firstUnmasked);
firstUnmasked = false;
}
}
--dimAt, --maskDimAt;
}
accumulator.Done(elements.back());
}
} else { // no DIM=, result is scalar
elements.push_back(identity);
bool firstUnmasked{true};
for (auto n{array.size()}; n-- > 0;
array.IncrementSubscripts(at), mask.IncrementSubscripts(maskAt)) {
if (mask.At(maskAt).IsTrue()) {
accumulator(elements.back(), at, firstUnmasked);
firstUnmasked = false;
}
}
accumulator.Done(elements.back());
}
if constexpr (T::category == TypeCategory::Character) {
return {static_cast<ConstantSubscript>(identity.size()),
std::move(elements), std::move(resultShape)};
} else {
return {std::move(elements), std::move(resultShape)};
}
}
// MAXVAL & MINVAL
template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
public:
MaxvalMinvalAccumulator(
RelationalOperator opr, FoldingContext &context, const Constant<T> &array)
: opr_{opr}, context_{context}, array_{array} {};
void operator()(Scalar<T> &element, const ConstantSubscripts &at,
[[maybe_unused]] bool firstUnmasked) const {
auto aAt{array_.At(at)};
if constexpr (ABS) {
aAt = aAt.ABS();
}
if constexpr (T::category == TypeCategory::Real) {
if (firstUnmasked || element.IsNotANumber()) {
// Return NaN if and only if all unmasked elements are NaNs and
// at least one unmasked element is visible.
element = aAt;
return;
}
}
Expr<LogicalResult> test{PackageRelation(
opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})};
auto folded{GetScalarConstantValue<LogicalResult>(
test.Rewrite(context_, std::move(test)))};
CHECK(folded.has_value());
if (folded->IsTrue()) {
element = aAt;
}
}
void Done(Scalar<T> &) const {}
private:
RelationalOperator opr_;
FoldingContext &context_;
const Constant<T> &array_;
};
template <typename T>
static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
RelationalOperator opr, const Scalar<T> &identity) {
static_assert(T::category == TypeCategory::Integer ||
T::category == TypeCategory::Real ||
T::category == TypeCategory::Character);
std::optional<int> dim;
if (std::optional<ArrayAndMask<T>> arrayAndMask{
ProcessReductionArgs<T>(context, ref.arguments(), dim,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
MaxvalMinvalAccumulator accumulator{opr, context, arrayAndMask->array};
return Expr<T>{DoReduction<T>(
arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
}
return Expr<T>{std::move(ref)};
}
// PRODUCT
template <typename T> class ProductAccumulator {
public:
ProductAccumulator(const Constant<T> &array) : array_{array} {}
void operator()(
Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
if constexpr (T::category == TypeCategory::Integer) {
auto prod{element.MultiplySigned(array_.At(at))};
overflow_ |= prod.SignedMultiplicationOverflowed();
element = prod.lower;
} else { // Real & Complex
auto prod{element.Multiply(array_.At(at))};
overflow_ |= prod.flags.test(RealFlag::Overflow);
element = prod.value;
}
}
bool overflow() const { return overflow_; }
void Done(Scalar<T> &) const {}
private:
const Constant<T> &array_;
bool overflow_{false};
};
template <typename T>
static Expr<T> FoldProduct(
FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) {
static_assert(T::category == TypeCategory::Integer ||
T::category == TypeCategory::Real ||
T::category == TypeCategory::Complex);
std::optional<int> dim;
if (std::optional<ArrayAndMask<T>> arrayAndMask{
ProcessReductionArgs<T>(context, ref.arguments(), dim,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
ProductAccumulator accumulator{arrayAndMask->array};
auto result{Expr<T>{DoReduction<T>(
arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
if (accumulator.overflow() &&
context.languageFeatures().ShouldWarn(
common::UsageWarning::FoldingException)) {
context.messages().Say(common::UsageWarning::FoldingException,
"PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
}
return result;
}
return Expr<T>{std::move(ref)};
}
// SUM
template <typename T> class SumAccumulator {
using Element = typename Constant<T>::Element;
public:
SumAccumulator(const Constant<T> &array, Rounding rounding)
: array_{array}, rounding_{rounding} {}
void operator()(
Element &element, const ConstantSubscripts &at, bool /*first*/) {
if constexpr (T::category == TypeCategory::Integer) {
auto sum{element.AddSigned(array_.At(at))};
overflow_ |= sum.overflow;
element = sum.value;
} else { // Real & Complex: use Kahan summation
auto next{array_.At(at).Add(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
// correction = (sum - element) - next; algebraically zero
correction_ = sum.value.Subtract(element, rounding_)
.value.Subtract(next.value, rounding_)
.value;
element = sum.value;
}
}
bool overflow() const { return overflow_; }
void Done([[maybe_unused]] Element &element) {
if constexpr (T::category != TypeCategory::Integer) {
auto corrected{element.Add(correction_, rounding_)};
overflow_ |= corrected.flags.test(RealFlag::Overflow);
correction_ = Scalar<T>{};
element = corrected.value;
}
}
private:
const Constant<T> &array_;
Rounding rounding_;
bool overflow_{false};
Element correction_{};
};
template <typename T>
static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
static_assert(T::category == TypeCategory::Integer ||
T::category == TypeCategory::Real ||
T::category == TypeCategory::Complex);
using Element = typename Constant<T>::Element;
std::optional<int> dim;
Element identity{};
if (std::optional<ArrayAndMask<T>> arrayAndMask{
ProcessReductionArgs<T>(context, ref.arguments(), dim,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
SumAccumulator accumulator{
arrayAndMask->array, context.targetCharacteristics().roundingMode()};
auto result{Expr<T>{DoReduction<T>(
arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
if (accumulator.overflow() &&
context.languageFeatures().ShouldWarn(
common::UsageWarning::FoldingException)) {
context.messages().Say(common::UsageWarning::FoldingException,
"SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
}
return result;
}
return Expr<T>{std::move(ref)};
}
// Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY
template <typename T> class OperationAccumulator {
public:
OperationAccumulator(const Constant<T> &array,
Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const)
: array_{array}, operation_{operation} {}
void operator()(
Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
element = (element.*operation_)(array_.At(at));
}
void Done(Scalar<T> &) const {}
private:
const Constant<T> &array_;
Scalar<T> (Scalar<T>::*operation_)(const Scalar<T> &) const;
};
} // namespace Fortran::evaluate
#endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_