llvm/flang/lib/Optimizer/Builder/TemporaryStorage.cpp

//===-- Optimizer/Builder/TemporaryStorage.cpp ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Implementation of utility data structures to create and manipulate temporary
// storages to stack Fortran values or pointers in HLFIR.
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/TemporaryStorage.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Runtime/TemporaryStack.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"

//===----------------------------------------------------------------------===//
// fir::factory::Counter implementation.
//===----------------------------------------------------------------------===//

/// Build a counter that starts at \p initialValue. When
/// \p canCountThroughLoops is set, the count is kept in a temporary obtained
/// from builder.createTemporary and seeded with a fir::StoreOp; otherwise the
/// count is tracked directly as an mlir::Value held by this object.
fir::factory::Counter::Counter(mlir::Location loc, fir::FirOpBuilder &builder,
                               mlir::Value initialValue,
                               bool canCountThroughLoops)
    : canCountThroughLoops{canCountThroughLoops}, initialValue{initialValue} {
  // The increment constant uses the same type as the counter itself.
  mlir::Type counterType = initialValue.getType();
  one = builder.createIntegerConstant(loc, counterType, 1);
  if (!canCountThroughLoops) {
    index = initialValue;
    return;
  }
  index = builder.createTemporary(loc, counterType);
  builder.create<fir::StoreOp>(loc, initialValue, index);
}

/// Return the current value of the counter and advance the counter by one.
mlir::Value
fir::factory::Counter::getAndIncrementIndex(mlir::Location loc,
                                            fir::FirOpBuilder &builder) {
  if (!canCountThroughLoops) {
    // Value tracking: hand back the value held before the addition and keep
    // the incremented value for the next call.
    mlir::Value previous = index;
    index = builder.create<mlir::arith::AddIOp>(loc, previous, one);
    return previous;
  }
  // In-memory tracking: load, add one, store back, and return what was loaded.
  mlir::Value loaded = builder.create<fir::LoadOp>(loc, index);
  mlir::Value incremented =
      builder.create<mlir::arith::AddIOp>(loc, loaded, one);
  builder.create<fir::StoreOp>(loc, incremented, index);
  return loaded;
}

/// Bring the counter back to the initial value it was constructed with.
void fir::factory::Counter::reset(mlir::Location loc,
                                  fir::FirOpBuilder &builder) {
  if (!canCountThroughLoops) {
    // Value tracking: simply forget the incremented value.
    index = initialValue;
    return;
  }
  // In-memory tracking: overwrite the stored count with the initial value.
  builder.create<fir::StoreOp>(loc, initialValue, index);
}

//===----------------------------------------------------------------------===//
// fir::factory::HomogeneousScalarStack implementation.
//===----------------------------------------------------------------------===//

/// Create the storage backing the stack. The storage comes from
/// builder.createHeapTemporary when \p allocateOnHeap is set and from
/// builder.createTemporary otherwise, and is then declared with an
/// hlfir::DeclareOp. The position counter starts at 1 (index type) and counts
/// through loops only when \p stackThroughLoops is set.
fir::factory::HomogeneousScalarStack::HomogeneousScalarStack(
    mlir::Location loc, fir::FirOpBuilder &builder,
    fir::SequenceType declaredType, mlir::Value extent,
    llvm::ArrayRef<mlir::Value> lengths, bool allocateOnHeap,
    bool stackThroughLoops, llvm::StringRef tempName)
    : allocateOnHeap{allocateOnHeap},
      counter{loc, builder,
              builder.createIntegerConstant(loc, builder.getIndexType(), 1),
              stackThroughLoops} {
  // Allocate the temporary storage; a single extent describes its size.
  llvm::SmallVector<mlir::Value, 1> extents{extent};
  mlir::Value storage;
  if (!allocateOnHeap)
    storage =
        builder.createTemporary(loc, declaredType, tempName, extents, lengths);
  else
    storage = builder.createHeapTemporary(loc, declaredType, tempName, extents,
                                          lengths);

  // Declare the storage and keep the base value of the declaration.
  mlir::Value shape = builder.genShape(loc, extents);
  auto declareOp = builder.create<hlfir::DeclareOp>(
      loc, storage, tempName, shape, lengths, /*dummy_scope=*/nullptr,
      fir::FortranVariableFlagsAttr{});
  temp = declareOp.getBase();
}

/// Assign the scalar \p value to the element of the temporary designated by
/// the counter, and advance the counter.
void fir::factory::HomogeneousScalarStack::pushValue(mlir::Location loc,
                                                     fir::FirOpBuilder &builder,
                                                     mlir::Value value) {
  hlfir::Entity pushed{value};
  assert(pushed.isScalar() && "cannot use inlined temp with array");
  mlir::Value position = counter.getAndIncrementIndex(loc, builder);
  hlfir::Entity stack{temp};
  hlfir::Entity slot =
      hlfir::getElementAt(loc, builder, stack, mlir::ValueRange{position});
  // TODO: "copy" would probably be better than assign to ensure there are no
  // side effects (user assignments, temp, lhs finalization)?
  // This only makes a difference for derived types, and for now derived types
  // use the runtime strategy to avoid any bad behaviors. The TODO below should
  // therefore not be hit; it is kept as a reminder and safety net.
  if (!pushed.hasIntrinsicType())
    TODO(loc, "creating inlined temporary stack for derived types");
  builder.create<hlfir::AssignOp>(loc, value, slot);
}

/// Reset the position counter to its initial value. The same counter is used
/// by pushValue and fetch to select the element of the temporary.
void fir::factory::HomogeneousScalarStack::resetFetchPosition(
    mlir::Location loc, fir::FirOpBuilder &builder) {
  counter.reset(loc, builder);
}

/// Return the element of the temporary designated by the counter (passed
/// through hlfir::loadTrivialScalar), and advance the counter.
mlir::Value
fir::factory::HomogeneousScalarStack::fetch(mlir::Location loc,
                                            fir::FirOpBuilder &builder) {
  mlir::Value position = counter.getAndIncrementIndex(loc, builder);
  hlfir::Entity stack{temp};
  hlfir::Entity slot =
      hlfir::getElementAt(loc, builder, stack, mlir::ValueRange{position});
  return hlfir::loadTrivialScalar(loc, builder, slot);
}

/// Free the temporary storage if it was allocated on the heap; otherwise
/// nothing is emitted.
void fir::factory::HomogeneousScalarStack::destroy(mlir::Location loc,
                                                   fir::FirOpBuilder &builder) {
  if (!allocateOnHeap)
    return;
  // The memory to free is the memref operand of the declare that defines temp.
  auto declareOp = temp.getDefiningOp<hlfir::DeclareOp>();
  assert(declareOp && "temp must have been declared");
  builder.create<fir::FreeMemOp>(loc, declareOp.getMemref());
}

/// Wrap the temporary in an hlfir::AsExprOp. The "must free" operand of the
/// operation is a boolean constant equal to allocateOnHeap.
hlfir::Entity fir::factory::HomogeneousScalarStack::moveStackAsArrayExpr(
    mlir::Location loc, fir::FirOpBuilder &builder) {
  mlir::Value mustFree = builder.createBool(loc, allocateOnHeap);
  return hlfir::Entity{builder.create<hlfir::AsExprOp>(loc, temp, mustFree)};
}

//===----------------------------------------------------------------------===//
// fir::factory::SimpleCopy implementation.
//===----------------------------------------------------------------------===//

/// Create a copy of \p source and keep the resulting association in `copy`.
fir::factory::SimpleCopy::SimpleCopy(mlir::Location loc,
                                     fir::FirOpBuilder &builder,
                                     hlfir::Entity source,
                                     llvm::StringRef tempName) {
  // Use hlfir.as_expr and hlfir.associate to create a copy, and let
  // bufferization deal with how best to make the copy.
  hlfir::Entity toCopy = source;
  if (toCopy.isVariable())
    toCopy = hlfir::Entity{builder.create<hlfir::AsExprOp>(loc, toCopy)};
  copy = hlfir::genAssociateExpr(loc, builder, toCopy,
                                 toCopy.getFortranElementType(), tempName);
}

/// Emit the hlfir::EndAssociateOp for the association created by the
/// constructor.
void fir::factory::SimpleCopy::destroy(mlir::Location loc,
                                       fir::FirOpBuilder &builder) {
  builder.create<hlfir::EndAssociateOp>(loc, copy);
}

//===----------------------------------------------------------------------===//
// fir::factory::AnyValueStack implementation.
//===----------------------------------------------------------------------===//

/// Create a runtime value stack (fir::runtime::genCreateValueStack) and the
/// temporary descriptor that fetch() uses to retrieve values. The fetch
/// counter starts at 0 (i64) and always counts through loops.
fir::factory::AnyValueStack::AnyValueStack(mlir::Location loc,
                                           fir::FirOpBuilder &builder,
                                           mlir::Type valueStaticType)
    : valueStaticType{valueStaticType},
      counter{loc, builder,
              builder.createIntegerConstant(loc, builder.getI64Type(), 0),
              /*canCountThroughLoops=*/true} {
  opaquePtr = fir::runtime::genCreateValueStack(loc, builder);
  // Compute the storage type. I1 are stored as fir.logical<1>. This is required
  // to use descriptor.
  mlir::Type i1Type = builder.getI1Type();
  mlir::Type storageType =
      hlfir::getFortranElementOrSequenceType(valueStaticType);
  if (storageType == i1Type)
    storageType = fir::LogicalType::get(builder.getContext(), 1);
  assert(hlfir::getFortranElementType(storageType) != i1Type &&
         "array of i1 should not be used");
  // The descriptor addresses heap storage; polymorphic values need a class
  // descriptor type instead of a plain box type.
  mlir::Type heapType = fir::HeapType::get(storageType);
  mlir::Type descriptorType;
  if (!hlfir::isPolymorphicType(valueStaticType))
    descriptorType = fir::BoxType::get(heapType);
  else
    descriptorType = fir::ClassType::get(heapType);
  retValueBox = builder.createTemporary(loc, descriptorType);
}

/// Convert \p value to a box whose element type is the one of retValueBox,
/// push that box on the runtime value stack, and run the clean-up returned by
/// the conversion, if any.
void fir::factory::AnyValueStack::pushValue(mlir::Location loc,
                                            fir::FirOpBuilder &builder,
                                            mlir::Value value) {
  mlir::Type elementType = hlfir::getFortranElementType(retValueBox.getType());
  auto [boxed, cleanUp] =
      hlfir::convertToBox(loc, builder, hlfir::Entity{value}, elementType);
  fir::runtime::genPushValue(loc, builder, opaquePtr, fir::getBase(boxed));
  if (cleanUp)
    (*cleanUp)();
}

/// Reset the counter used by fetch() to its initial value.
void fir::factory::AnyValueStack::resetFetchPosition(
    mlir::Location loc, fir::FirOpBuilder &builder) {
  counter.reset(loc, builder);
}

/// Retrieve the value at the current fetch position from the runtime stack,
/// advance the fetch position, and return the value, converted back to
/// valueStaticType when that type is trivial and differs from the fetched one.
mlir::Value fir::factory::AnyValueStack::fetch(mlir::Location loc,
                                               fir::FirOpBuilder &builder) {
  mlir::Value position = counter.getAndIncrementIndex(loc, builder);
  fir::runtime::genValueAt(loc, builder, opaquePtr, position, retValueBox);
  // Dereference the allocatable "retValueBox", and load if trivial scalar
  // value.
  mlir::Value fetched =
      hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{retValueBox});
  if (valueStaticType == fetched.getType())
    return fetched;
  // Cast back saved simple scalars stored with another type to their original
  // type (like i1).
  if (fir::isa_trivial(valueStaticType))
    return builder.createConvert(loc, valueStaticType, fetched);
  // Memory type mismatches (e.g. fir.ref vs fir.heap) or hlfir.expr vs
  // variable type mismatches are OK, but the base Fortran type must be the
  // same.
  assert(hlfir::getFortranElementOrSequenceType(valueStaticType) ==
             hlfir::getFortranElementOrSequenceType(fetched.getType()) &&
         "non trivial values must be saved with their original type");
  return fetched;
}

/// Emit the runtime call that destroys the value stack designated by
/// opaquePtr (see fir::runtime::genDestroyValueStack).
void fir::factory::AnyValueStack::destroy(mlir::Location loc,
                                          fir::FirOpBuilder &builder) {
  fir::runtime::genDestroyValueStack(loc, builder, opaquePtr);
}

//===----------------------------------------------------------------------===//
// fir::factory::AnyVariableStack implementation.
//===----------------------------------------------------------------------===//

/// Create a runtime descriptor stack (fir::runtime::genCreateDescriptorStack)
/// and the temporary descriptor that fetch() uses to retrieve variables. The
/// fetch counter starts at 0 (i64) and always counts through loops.
fir::factory::AnyVariableStack::AnyVariableStack(mlir::Location loc,
                                                 fir::FirOpBuilder &builder,
                                                 mlir::Type variableStaticType)
    : variableStaticType{variableStaticType},
      counter{loc, builder,
              builder.createIntegerConstant(loc, builder.getI64Type(), 0),
              /*canCountThroughLoops=*/true} {
  opaquePtr = fir::runtime::genCreateDescriptorStack(loc, builder);
  // The descriptor used to fetch variables wraps a fir.ptr to the Fortran
  // element or sequence type; polymorphic variables need a class descriptor
  // type instead of a plain box type.
  mlir::Type pointeeType =
      hlfir::getFortranElementOrSequenceType(variableStaticType);
  mlir::Type ptrType = fir::PointerType::get(pointeeType);
  mlir::Type descriptorType;
  if (!hlfir::isPolymorphicType(variableStaticType))
    descriptorType = fir::BoxType::get(ptrType);
  else
    descriptorType = fir::ClassType::get(ptrType);
  retValueBox = builder.createTemporary(loc, descriptorType);
}

/// Convert \p variable to a box whose element type is the one of retValueBox,
/// push that descriptor on the runtime descriptor stack, and run the clean-up
/// returned by the conversion, if any.
void fir::factory::AnyVariableStack::pushValue(mlir::Location loc,
                                               fir::FirOpBuilder &builder,
                                               mlir::Value variable) {
  mlir::Type elementType = hlfir::getFortranElementType(retValueBox.getType());
  auto [boxed, cleanUp] =
      hlfir::convertToBox(loc, builder, hlfir::Entity{variable}, elementType);
  fir::runtime::genPushDescriptor(loc, builder, opaquePtr,
                                  fir::getBase(boxed));
  if (cleanUp)
    (*cleanUp)();
}

/// Reset the counter used by fetch() to its initial value.
void fir::factory::AnyVariableStack::resetFetchPosition(
    mlir::Location loc, fir::FirOpBuilder &builder) {
  counter.reset(loc, builder);
}

/// Retrieve the descriptor at the current fetch position from the runtime
/// stack, advance the fetch position, and return the variable in the form
/// described by variableStaticType.
mlir::Value fir::factory::AnyVariableStack::fetch(mlir::Location loc,
                                                  fir::FirOpBuilder &builder) {
  mlir::Value position = counter.getAndIncrementIndex(loc, builder);
  fir::runtime::genDescriptorAt(loc, builder, opaquePtr, position,
                                retValueBox);
  hlfir::Entity fetched{builder.create<fir::LoadOp>(loc, retValueBox)};
  // The runtime always tracks variables as addresses, but the form of the
  // variable that was saved may be different (raw address, fir.boxchar).
  // Ensure the returned variable has the same form as the one that was saved.
  if (mlir::isa<fir::BaseBoxType>(variableStaticType))
    return builder.createConvert(loc, variableStaticType, fetched);
  if (mlir::isa<fir::BoxCharType>(variableStaticType))
    return hlfir::genVariableBoxChar(loc, builder, fetched);
  // Any other form is produced from the raw address of the variable.
  return builder.createConvert(loc, variableStaticType,
                               genVariableRawAddress(loc, builder, fetched));
}

/// Emit the runtime call that destroys the descriptor stack designated by
/// opaquePtr (see fir::runtime::genDestroyDescriptorStack).
void fir::factory::AnyVariableStack::destroy(mlir::Location loc,
                                             fir::FirOpBuilder &builder) {
  fir::runtime::genDestroyDescriptorStack(loc, builder, opaquePtr);
}

//===----------------------------------------------------------------------===//
// fir::factory::AnyVectorSubscriptStack implementation.
//===----------------------------------------------------------------------===//

/// Create the variable stack (base class) and the storage used to save
/// shapes: an SSARegister when \p shapeCanBeSavedAsRegister is set, otherwise
/// a second AnyVariableStack holding descriptors of \p rank dimensions.
fir::factory::AnyVectorSubscriptStack::AnyVectorSubscriptStack(
    mlir::Location loc, fir::FirOpBuilder &builder,
    mlir::Type variableStaticType, bool shapeCanBeSavedAsRegister, int rank)
    : AnyVariableStack{loc, builder, variableStaticType} {
  if (!shapeCanBeSavedAsRegister) {
    // The shape will be tracked as the dimensions inside a descriptor because
    // that is the easiest from a lowering point of view, and this is an edge
    // case that will probably not be exercised very much.
    mlir::Type shapeBoxType =
        fir::BoxType::get(builder.getVarLenSeqTy(builder.getI32Type(), rank));
    boxType = shapeBoxType;
    shapeTemp = std::make_unique<TemporaryStorage>(
        AnyVariableStack{loc, builder, shapeBoxType});
    return;
  }
  shapeTemp = std::make_unique<TemporaryStorage>(SSARegister{});
}

/// Save \p shape, either directly in the SSARegister or, when boxType is set,
/// as the dimensions of a descriptor pushed on the shape stack.
void fir::factory::AnyVectorSubscriptStack::pushShape(
    mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value shape) {
  if (!boxType) {
    // Simply keep track of the fir.shape itself, it is invariant.
    shapeTemp->cast<SSARegister>().pushValue(loc, builder, shape);
    return;
  }
  // The shape is saved as the dimensions inside a descriptor. The descriptor
  // is built over a null address since only its dimensions are of interest.
  mlir::Type nullAddrType = fir::ReferenceType::get(
      hlfir::getFortranElementOrSequenceType(*boxType));
  mlir::Value nullAddr = builder.createNullConstant(loc, nullAddrType);
  mlir::Value shapeDescriptor =
      builder.create<fir::EmboxOp>(loc, *boxType, nullAddr, shape);
  shapeTemp->pushValue(loc, builder, shapeDescriptor);
}

/// Reset the fetch position of both the variable stack (base class) and the
/// shape storage.
void fir::factory::AnyVectorSubscriptStack::resetFetchPosition(
    mlir::Location loc, fir::FirOpBuilder &builder) {
  // Explicitly call the base class implementation, which this member hides.
  AnyVariableStack::resetFetchPosition(loc, builder);
  shapeTemp->resetFetchPosition(loc, builder);
}

/// Return the next saved shape, either from the SSARegister or, when boxType
/// is set, rebuilt with hlfir::genShape from the fetched descriptor.
mlir::Value
fir::factory::AnyVectorSubscriptStack::fetchShape(mlir::Location loc,
                                                  fir::FirOpBuilder &builder) {
  if (!boxType)
    return shapeTemp->cast<SSARegister>().fetch(loc, builder);
  hlfir::Entity shapeDescriptor{shapeTemp->fetch(loc, builder)};
  return hlfir::genShape(loc, builder, shapeDescriptor);
}

/// Destroy both the variable stack (base class) and the shape storage.
void fir::factory::AnyVectorSubscriptStack::destroy(
    mlir::Location loc, fir::FirOpBuilder &builder) {
  // Explicitly call the base class implementation, which this member hides.
  AnyVariableStack::destroy(loc, builder);
  shapeTemp->destroy(loc, builder);
}