llvm/flang/lib/Optimizer/Transforms/CufImplicitDeviceGlobal.cpp

//===-- CufImplicitDeviceGlobal.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Common/Fortran.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/allocatable.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace fir {
#define GEN_PASS_DEF_CUFIMPLICITDEVICEGLOBAL
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

namespace {

/// Attach an implicit CUF data attribute to the globals referenced from
/// \p funcOp when that function is a non-host CUDA Fortran procedure.
///
/// A function is processed only if it carries a CUF procedure attribute whose
/// value is not `Host`. For every `fir.address_of` found anywhere inside the
/// function (including operations nested in the regions of other operations),
/// the referenced `fir.global` is looked up in \p symbolTable. A global that
/// has no data attribute yet receives `Constant` if it is a constant global
/// and `Device` otherwise.
///
/// \param funcOp        function whose global references are inspected.
/// \param symbolTable   symbol table used to resolve the referenced globals.
/// \param onlyConstant  when true, only constant globals are considered;
///                      when false, every global without a data attribute is.
static void prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
                                         mlir::SymbolTable &symbolTable,
                                         bool onlyConstant = true) {
  auto cudaProcAttr{
      funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())};
  if (!cudaProcAttr || cudaProcAttr.getValue() == cuf::ProcAttribute::Host)
    return;

  // Walk the function recursively: iterating only the operations held
  // directly by the body region's blocks would miss address_of operations
  // nested inside region-holding operations.
  funcOp.walk([&](fir::AddrOfOp addrOfOp) {
    auto globalOp = symbolTable.lookup<fir::GlobalOp>(
        addrOfOp.getSymbol().getRootReference().getValue());
    // The symbol may not resolve to a fir.global; nothing to tag then.
    if (!globalOp)
      return;

    // Never override a data attribute that is already present.
    bool isCandidate{(onlyConstant ? globalOp.getConstant() : true) &&
                     !globalOp.getDataAttr()};
    if (!isCandidate)
      return;

    globalOp.setDataAttrAttr(cuf::DataAttributeAttr::get(
        funcOp.getContext(), globalOp.getConstant()
                                 ? cuf::DataAttribute::Constant
                                 : cuf::DataAttribute::Device));
  });
}

/// Pass that visits every `func.func` in a module and lets
/// prepareImplicitDeviceGlobals tag the globals those functions reference.
/// The pass signals failure when it is not run on a module operation.
class CufImplicitDeviceGlobal
    : public fir::impl::CufImplicitDeviceGlobalBase<CufImplicitDeviceGlobal> {
public:
  void runOnOperation() override {
    auto moduleOp = mlir::dyn_cast<mlir::ModuleOp>(getOperation());
    if (!moduleOp) {
      signalPassFailure();
      return;
    }

    // One symbol table for the whole module, shared by all visited functions.
    mlir::SymbolTable symbols(moduleOp);
    moduleOp.walk([&symbols](mlir::func::FuncOp func) {
      prepareImplicitDeviceGlobals(func, symbols);
    });
  }
};
} // namespace