llvm/flang/include/flang/Optimizer/Support/Utils.h

//===-- Optimizer/Support/Utils.h -------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
#define FORTRAN_OPTIMIZER_SUPPORT_UTILS_H

#include "flang/Common/default-kinds.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FatalError.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"

namespace fir {
/// Return the integer value of an arith::ConstantOp.
///
/// The op's value attribute is cast to mlir::IntegerAttr (mlir::cast asserts
/// on a mismatch) and the resulting APInt is sign-extended to 64 bits.
inline std::int64_t toInt(mlir::arith::ConstantOp cop) {
  auto intAttr = mlir::cast<mlir::IntegerAttr>(cop.getValue());
  auto bits = intAttr.getValue();
  return bits.getSExtValue();
}

// Reconstruct binding tables for dynamic dispatch.
// BindingTable: method name of a fir.dt_entry -> its zero-based position in the
// type's dispatch table. BindingTables: fir.type_info symbol name -> its table.
using BindingTable = llvm::DenseMap<llvm::StringRef, unsigned>;
using BindingTables = llvm::DenseMap<llvm::StringRef, BindingTable>;

/// Fill \p bindingTables from the fir.type_info operations found in \p mod.
///
/// The binding tables are defined in FIR after lowering inside fir.type_info
/// operations. Every fir.type_info gets an entry keyed by its symbol name, and
/// an existing entry with the same key is overwritten. The entry maps each
/// fir.dt_entry method name to its binding index, which is its position among
/// the fir.dt_entry ops of the dispatch table's first block. A type without a
/// dispatch table gets an empty table. The result is meant for later use by
/// the fir.dispatch conversion pattern.
///
/// NOTE(review): the StringRef keys presumably point into attribute storage
/// owned by the MLIR context; confirm the tables never outlive it.
inline void buildBindingTables(BindingTables &bindingTables,
                               mlir::ModuleOp mod) {
  for (fir::TypeInfoOp typeInfo : mod.getOps<fir::TypeInfoOp>()) {
    BindingTable bindings;
    mlir::Region &dispatchTable = typeInfo.getDispatchTable();
    // Region::front() needs a block, so only walk a non-empty dispatch table;
    // otherwise the type still gets an (empty) entry below.
    if (!dispatchTable.empty()) {
      unsigned bindingIdx = 0;
      for (fir::DTEntryOp dtEntry :
           dispatchTable.front().getOps<fir::DTEntryOp>())
        bindings[dtEntry.getMethod()] = bindingIdx++;
    }
    // The local table is dead after this point: move it instead of copying
    // the whole DenseMap.
    bindingTables[typeInfo.getSymName()] = std::move(bindings);
  }
}

// Translate front-end KINDs for use in the IR and code gen.
//
// The result lists the default KIND of, in order: CHARACTER, COMPLEX,
// DOUBLE PRECISION, INTEGER, LOGICAL and REAL. (Presumably this order is what
// the consumer of this vector expects — confirm against its caller before
// reordering.)
inline std::vector<fir::KindTy>
fromDefaultKinds(const Fortran::common::IntrinsicTypeDefaultKinds &defKinds) {
  using Category = Fortran::common::TypeCategory;
  auto defaultKindOf = [&defKinds](Category category) {
    return static_cast<fir::KindTy>(defKinds.GetDefaultKind(category));
  };
  const auto doublePrecision =
      static_cast<fir::KindTy>(defKinds.doublePrecisionKind());
  return {defaultKindOf(Category::Character),
          defaultKindOf(Category::Complex),
          doublePrecision,
          defaultKindOf(Category::Integer),
          defaultKindOf(Category::Logical),
          defaultKindOf(Category::Real)};
}

/// Return the textual MLIR form of \p type, as produced by mlir::Type::print.
inline std::string mlirTypeToString(mlir::Type type) {
  std::string text;
  llvm::raw_string_ostream os(text);
  type.print(os);
  return text;
}

/// Spell \p type as a Fortran intrinsic type (e.g. f64 -> "REAL(KIND=8)") for
/// use in diagnostics.
///
/// \p builder only supplies the MLIR context used to build the FIR complex and
/// logical types being compared against. \p name is used only in the message of
/// the fatal error emitted at \p loc when \p type has no spelling here.
inline std::string mlirTypeToIntrinsicFortran(fir::FirOpBuilder &builder,
                                              mlir::Type type,
                                              mlir::Location loc,
                                              const llvm::Twine &name) {
  // Floating-point types.
  if (type.isF16())
    return "REAL(KIND=2)";
  if (type.isBF16())
    return "REAL(KIND=3)";
  if (type.isTF32())
    return "REAL(KIND=unknown)"; // No Fortran KIND is assigned to tf32.
  if (type.isF32())
    return "REAL(KIND=4)";
  if (type.isF64())
    return "REAL(KIND=8)";
  if (type.isF80())
    return "REAL(KIND=10)";
  if (type.isF128())
    return "REAL(KIND=16)";

  // Integer types, by bit width.
  if (type.isInteger(8))
    return "INTEGER(KIND=1)";
  if (type.isInteger(16))
    return "INTEGER(KIND=2)";
  if (type.isInteger(32))
    return "INTEGER(KIND=4)";
  if (type.isInteger(64))
    return "INTEGER(KIND=8)";
  if (type.isInteger(128))
    return "INTEGER(KIND=16)";

  // FIR complex and logical types: only the kinds listed are recognized.
  mlir::MLIRContext *ctx = builder.getContext();
  if (type == fir::ComplexType::get(ctx, 2))
    return "COMPLEX(KIND=2)";
  if (type == fir::ComplexType::get(ctx, 3))
    return "COMPLEX(KIND=3)";
  if (type == fir::ComplexType::get(ctx, 4))
    return "COMPLEX(KIND=4)";
  if (type == fir::ComplexType::get(ctx, 8))
    return "COMPLEX(KIND=8)";
  if (type == fir::ComplexType::get(ctx, 10))
    return "COMPLEX(KIND=10)";
  if (type == fir::ComplexType::get(ctx, 16))
    return "COMPLEX(KIND=16)";
  if (type == fir::LogicalType::get(ctx, 1))
    return "LOGICAL(KIND=1)";
  if (type == fir::LogicalType::get(ctx, 2))
    return "LOGICAL(KIND=2)";
  if (type == fir::LogicalType::get(ctx, 4))
    return "LOGICAL(KIND=4)";
  if (type == fir::LogicalType::get(ctx, 8))
    return "LOGICAL(KIND=8)";

  fir::emitFatalError(loc, "unsupported type in " + name + ": " +
                               fir::mlirTypeToString(type));
}

/// Report through the TODO macro that intrinsic \p intrinsicName is not yet
/// implemented for \p type. The type is spelled as a Fortran intrinsic type.
/// An unsupported \p type makes mlirTypeToIntrinsicFortran emit a fatal error
/// before the TODO is reached.
inline void intrinsicTypeTODO(fir::FirOpBuilder &builder, mlir::Type type,
                              mlir::Location loc,
                              const llvm::Twine &intrinsicName) {
  const std::string fortranType =
      fir::mlirTypeToIntrinsicFortran(builder, type, loc, intrinsicName);
  TODO(loc, "intrinsic: " + fortranType + " in " + intrinsicName);
}

/// Report through the TODO macro that intrinsic \p intrinsicName is not yet
/// implemented for the type pair {\p type1, \p type2}. Both types are spelled
/// as Fortran intrinsic types.
///
/// The previous version rendered \p type2 twice, so \p type1 never appeared
/// in the diagnostic. An unsupported type makes mlirTypeToIntrinsicFortran
/// emit a fatal error before the TODO is reached.
inline void intrinsicTypeTODO2(fir::FirOpBuilder &builder, mlir::Type type1,
                               mlir::Type type2, mlir::Location loc,
                               const llvm::Twine &intrinsicName) {
  const std::string fortranType1 =
      fir::mlirTypeToIntrinsicFortran(builder, type1, loc, intrinsicName);
  const std::string fortranType2 =
      fir::mlirTypeToIntrinsicFortran(builder, type2, loc, intrinsicName);
  TODO(loc, "intrinsic: {" + fortranType1 + ", " + fortranType2 + "} in " +
                intrinsicName);
}

/// Classify \p type as a Fortran (type category, kind) pair, e.g.
/// f64 -> {Real, 8} and i32 -> {Integer, 4}.
///
/// FIR complex and logical types report the kind they carry (getFKind()).
/// Any other type not listed here, including tf32, triggers a fatal error at
/// \p loc.
inline std::pair<Fortran::common::TypeCategory, KindMapping::KindTy>
mlirTypeToCategoryKind(mlir::Location loc, mlir::Type type) {
  using Category = Fortran::common::TypeCategory;

  // Floating-point types.
  if (type.isF16())
    return {Category::Real, 2};
  if (type.isBF16())
    return {Category::Real, 3};
  if (type.isF32())
    return {Category::Real, 4};
  if (type.isF64())
    return {Category::Real, 8};
  if (type.isF80())
    return {Category::Real, 10};
  if (type.isF128())
    return {Category::Real, 16};

  // Integer types: Fortran kind is the width in bytes.
  if (type.isInteger(8))
    return {Category::Integer, 1};
  if (type.isInteger(16))
    return {Category::Integer, 2};
  if (type.isInteger(32))
    return {Category::Integer, 4};
  if (type.isInteger(64))
    return {Category::Integer, 8};
  if (type.isInteger(128))
    return {Category::Integer, 16};

  // FIR complex/logical types carry their Fortran kind directly.
  if (auto complexTy = mlir::dyn_cast<fir::ComplexType>(type))
    return {Category::Complex, complexTy.getFKind()};
  if (auto logicalTy = mlir::dyn_cast<fir::LogicalType>(type))
    return {Category::Logical, logicalTy.getFKind()};

  fir::emitFatalError(loc,
                      "unsupported type: " + fir::mlirTypeToString(type));
}

/// Find the fir.type_info that was created for this \p recordType in \p module,
/// if any. \p symbolTable can be provided to speed up the lookup. This lookup
/// matches record types even if they have been "altered" by type conversion
/// passes.
fir::TypeInfoOp
lookupTypeInfoOp(fir::RecordType recordType, mlir::ModuleOp module,
                 const mlir::SymbolTable *symbolTable = nullptr);

/// Find the fir.type_info named \p name in \p module, if any. \p symbolTable
/// can be provided to speed up the lookup. Prefer the overload taking a
/// RecordType argument unless it is certain that \p name has not been altered
/// by a pass rewriting fir.type (see NameUniquer::dropTypeConversionMarkers).
fir::TypeInfoOp
lookupTypeInfoOp(llvm::StringRef name, mlir::ModuleOp module,
                 const mlir::SymbolTable *symbolTable = nullptr);

/// Returns all lower bounds of \p component if it is an array component of \p
/// recordType with non-default lower bounds. Returns nullopt if this is not an
/// array component of \p recordType or if its lower bounds are all ones.
/// \p symbolTable can be provided to speed up the lookup.
/// NOTE(review): the returned ArrayRef is non-owning; its storage is presumably
/// owned by IR in \p module (or its MLIR context) — confirm before caching it.
std::optional<llvm::ArrayRef<int64_t>> getComponentLowerBoundsIfNonDefault(
    fir::RecordType recordType, llvm::StringRef component,
    mlir::ModuleOp module, const mlir::SymbolTable *symbolTable = nullptr);

} // namespace fir

#endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H