//===-- lib/Evaluate/fold-matmul.h ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_
#define FORTRAN_EVALUATE_FOLD_MATMUL_H_
#include "fold-implementation.h"
namespace Fortran::evaluate {
// Folds a reference to the intrinsic function MATMUL.
//
// Both actual arguments are folded first.  When either one does not fold to
// a constant, the original function reference is returned unchanged.  When
// the last extent of the first argument differs from the first extent of the
// second, an error message is emitted and the reference is turned into an
// invalid intrinsic.  Otherwise the product is computed element by element
// and returned as a constant:
//   - rank-2 by rank-2 yields a rank-2 result of shape [rows, columns];
//   - rank-1 by rank-2 yields a rank-1 result of shape [columns];
//   - rank-2 by rank-1 yields a rank-1 result of shape [rows].
// Real and complex products and sums are computed with the rounding mode of
// the folding context's target characteristics.  Any overflow detected while
// multiplying or accumulating is reported once, as a FoldingException usage
// warning, when that warning is enabled.
template <typename T>
static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
  using Element = typename Constant<T>::Element;
  auto args{funcRef.arguments()};
  CHECK(args.size() == 2);
  Folder<T> folder{context};
  Constant<T> *ma{folder.Folding(args[0])};
  Constant<T> *mb{folder.Folding(args[1])};
  if (!ma || !mb) {
    // At least one operand is not a constant; leave the reference as it is.
    return Expr<T>{std::move(funcRef)};
  }
  // Each operand must be a vector or a matrix, and at least one a matrix.
  CHECK(ma->Rank() >= 1 && ma->Rank() <= 2 && mb->Rank() >= 1 &&
      mb->Rank() <= 2 && (ma->Rank() == 2 || mb->Rank() == 2));
  // The contracted dimension: last of A, first of B.
  ConstantSubscript commonExtent{ma->shape().back()};
  if (mb->shape().front() != commonExtent) {
    context.messages().Say(
        "Arguments to MATMUL have distinct extents %zd and %zd on their last and first dimensions"_err_en_US,
        commonExtent, mb->shape().front());
    return MakeInvalidIntrinsic(std::move(funcRef));
  }
  // A vector operand is treated as a single row (A) or a single column (B).
  ConstantSubscript rows{ma->Rank() == 1 ? 1 : ma->shape()[0]};
  ConstantSubscript columns{mb->Rank() == 1 ? 1 : mb->shape()[1]};
  std::vector<Element> elements;
  elements.reserve(rows * columns);
  bool overflow{false};
  // Unused when T is INTEGER or LOGICAL.
  [[maybe_unused]] const auto &rounding{
      context.targetCharacteristics().roundingMode()};
  // result(j,k) = SUM(A(j,:) * B(:,k))
  // Elements are produced with the row index varying fastest.
  for (ConstantSubscript ci{0}; ci < columns; ++ci) {
    for (ConstantSubscript ri{0}; ri < rows; ++ri) {
      // Start at the first element of row ri of A and of column ci of B,
      // expressed relative to each constant's own lower bounds.
      ConstantSubscripts aAt{ma->lbounds()};
      if (ma->Rank() == 2) {
        aAt[0] += ri;
      }
      ConstantSubscripts bAt{mb->lbounds()};
      if (mb->Rank() == 2) {
        bAt[1] += ci;
      }
      Element sum{};
      // Running compensation term; only used by the Kahan summation path.
      [[maybe_unused]] Element correction{};
      for (ConstantSubscript j{0}; j < commonExtent; ++j) {
        Element aElt{ma->At(aAt)};
        Element bElt{mb->At(bAt)};
        if constexpr (T::category == TypeCategory::Real ||
            T::category == TypeCategory::Complex) {
          // Honor the target's rounding mode for the product as well as
          // for the accumulation below.
          auto product{aElt.Multiply(bElt, rounding)};
          overflow |= product.flags.test(RealFlag::Overflow);
          if constexpr (useKahanSummation) {
            // Kahan (compensated) summation.  "correction" holds
            // (added - sum) - next from the previous step, i.e. the
            // negative of the low-order bits that were lost, so it must
            // be subtracted from the incoming term to restore them.
            auto next{product.value.Subtract(correction, rounding)};
            overflow |= next.flags.test(RealFlag::Overflow);
            auto added{sum.Add(next.value, rounding)};
            overflow |= added.flags.test(RealFlag::Overflow);
            correction = added.value.Subtract(sum, rounding)
                             .value.Subtract(next.value, rounding)
                             .value;
            sum = std::move(added.value);
          } else {
            // Don't use Kahan summation in numeric MATMUL folding;
            // the runtime doesn't use it, and results should match.
            auto added{sum.Add(product.value, rounding)};
            overflow |= added.flags.test(RealFlag::Overflow);
            sum = std::move(added.value);
          }
        } else if constexpr (T::category == TypeCategory::Integer) {
          // Only the lower half of the signed product is accumulated; a
          // product that doesn't fit in it is recorded as an overflow.
          auto product{aElt.MultiplySigned(bElt)};
          overflow |= product.SignedMultiplicationOverflowed();
          auto added{sum.AddSigned(product.lower)};
          overflow |= added.overflow;
          sum = std::move(added.value);
        } else {
          // LOGICAL: the "sum of products" is an OR of ANDs.
          static_assert(T::category == TypeCategory::Logical);
          sum = sum.OR(aElt.AND(bElt));
        }
        // Advance along the contracted dimension of both operands.
        ++aAt.back();
        ++bAt.front();
      }
      elements.push_back(sum);
    }
  }
  if (overflow &&
      context.languageFeatures().ShouldWarn(
          common::UsageWarning::FoldingException)) {
    context.messages().Say(common::UsageWarning::FoldingException,
        "MATMUL of %s data overflowed during computation"_warn_en_US,
        T::AsFortran());
  }
  // A vector operand contributes no dimension to the result's shape.
  ConstantSubscripts shape;
  if (ma->Rank() == 2) {
    shape.push_back(rows);
  }
  if (mb->Rank() == 2) {
    shape.push_back(columns);
  }
  return Expr<T>{Constant<T>{std::move(elements), std::move(shape)}};
}
} // namespace Fortran::evaluate
#endif // FORTRAN_EVALUATE_FOLD_MATMUL_H_