llvm/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h

//===-- FIROpPatterns.h -- FIR operation conversion patterns ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_OPTIMIZER_CODEGEN_FIROPPATTERNS_H
#define FORTRAN_OPTIMIZER_CODEGEN_FIROPPATTERNS_H

#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

namespace fir {

struct FIRToLLVMPassOptions;

/// Address space number used as the default (0).
///
/// Declared `inline constexpr` so that every translation unit including this
/// header refers to one and the same entity, instead of each getting its own
/// internal-linkage copy as happens with `static constexpr` at namespace scope.
inline constexpr unsigned defaultAddressSpace = 0u;

/// Common base class of the patterns that convert FIR operations to the LLVM
/// dialect. It keeps a reference to the pass options and declares the helpers
/// shared by the concrete patterns: type conversion, constant generation and
/// access to the fields of a `fir.box`.
class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
public:
  /// \p options is stored by reference (see the `options` member), so the
  /// referenced object has to outlive the pattern.
  ConvertFIRToLLVMPattern(llvm::StringRef rootOpName,
                          mlir::MLIRContext *context,
                          const fir::LLVMTypeConverter &typeConverter,
                          const fir::FIRToLLVMPassOptions &options,
                          mlir::PatternBenefit benefit = 1);

protected:
  /// Convert \p ty with the FIR type converter returned by lowerTy().
  mlir::Type convertType(mlir::Type ty) const {
    const fir::LLVMTypeConverter &converter = lowerTy();
    return converter.convertType(ty);
  }

  /// Convert a FIR type to LLVM without turning fir.box<T> into a memory
  /// reference.
  mlir::Type convertObjectType(mlir::Type firType) const;

  /// Create an LLVM constant operation for \p value at \p loc.
  /// NOTE(review): the name suggests an i32 result type; the definition is out
  /// of line, confirm there.
  mlir::LLVM::ConstantOp
  genI32Constant(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
                 int value) const;

  /// Create an LLVM constant operation holding \p offset at \p loc.
  /// NOTE(review): the result type is chosen by the out-of-line definition.
  mlir::LLVM::ConstantOp
  genConstantOffset(mlir::Location loc,
                    mlir::ConversionPatternRewriter &rewriter,
                    int offset) const;

  /// Extend or truncate the integer value \p val as needed to obtain type
  /// \p ty. Lowering to a specific target may require sign-extension or
  /// truncation, in particular when moving values between abstract box types
  /// and the reified structures.
  mlir::Value integerCast(mlir::Location loc,
                          mlir::ConversionPatternRewriter &rewriter,
                          mlir::Type ty, mlir::Value val,
                          bool fold = false) const;

  /// A FIR type together with an LLVM type.
  /// NOTE(review): presumably `llvm` is the converted form of `fir`.
  struct TypePair {
    mlir::Type fir;
    mlir::Type llvm;
  };

  /// Build the TypePair for the FIR box type \p firBoxTy.
  TypePair getBoxTypePair(mlir::Type firBoxTy) const;

  /// Emit the code sequence that extracts one specific value, selected by
  /// \p boxValue, from a `fir.box`.
  mlir::Value getValueFromBox(mlir::Location loc, TypePair boxTy,
                              mlir::Value box, mlir::Type resultTy,
                              mlir::ConversionPatternRewriter &rewriter,
                              int boxValue) const;

  /// Emit the code sequence that reads the triple describing dimension \p dim
  /// of a box. This overload takes the dimension as an SSA value.
  llvm::SmallVector<mlir::Value, 3>
  getDimsFromBox(mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys,
                 TypePair boxTy, mlir::Value box, mlir::Value dim,
                 mlir::ConversionPatternRewriter &rewriter) const;

  /// Same as above, with the dimension given as a C++ integer.
  llvm::SmallVector<mlir::Value, 3>
  getDimsFromBox(mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys,
                 TypePair boxTy, mlir::Value box, int dim,
                 mlir::ConversionPatternRewriter &rewriter) const;

  /// Per-dimension field accessor taking the dimension as an SSA value.
  /// NOTE(review): \p off presumably selects the field inside the dimension
  /// entry and \p ty the type of the result; confirm against the definition.
  mlir::Value
  loadDimFieldFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box,
                      mlir::Value dim, int off, mlir::Type ty,
                      mlir::ConversionPatternRewriter &rewriter) const;

  /// Per-dimension field accessor taking the dimension as a C++ integer.
  /// NOTE(review): same caveat as loadDimFieldFromBox regarding \p off and
  /// \p ty.
  mlir::Value
  getDimFieldFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box,
                     int dim, int off, mlir::Type ty,
                     mlir::ConversionPatternRewriter &rewriter) const;

  /// Accessor for the stride of dimension \p dim of \p box.
  mlir::Value getStrideFromBox(mlir::Location loc, TypePair boxTy,
                               mlir::Value box, unsigned dim,
                               mlir::ConversionPatternRewriter &rewriter) const;

  /// Read the base address from a fir.box.
  mlir::Value
  getBaseAddrFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box,
                     mlir::ConversionPatternRewriter &rewriter) const;

  /// Accessor for the element size recorded in \p box; \p resultTy is the
  /// requested type of the returned value.
  mlir::Value
  getElementSizeFromBox(mlir::Location loc, mlir::Type resultTy, TypePair boxTy,
                        mlir::Value box,
                        mlir::ConversionPatternRewriter &rewriter) const;

  /// Accessor for the rank recorded in \p box.
  mlir::Value getRankFromBox(mlir::Location loc, TypePair boxTy,
                             mlir::Value box,
                             mlir::ConversionPatternRewriter &rewriter) const;

  /// Accessor for the "extra" field of \p box.
  /// NOTE(review): the meaning of that field is not visible in this header.
  mlir::Value getExtraFromBox(mlir::Location loc, TypePair boxTy,
                              mlir::Value box,
                              mlir::ConversionPatternRewriter &rewriter) const;

  /// Get the element type selected by \p indexes inside \p type, an LLVM type
  /// of the form (array|struct|vector)+.
  mlir::Type getBoxEleTy(mlir::Type type,
                         llvm::ArrayRef<std::int64_t> indexes) const;

  /// Return the LLVM type of the object described by a fir.box of \p boxType.
  mlir::Type getLlvmObjectTypeFromBoxType(mlir::Type boxType) const;

  /// Read the address of the type descriptor from a box.
  mlir::Value
  loadTypeDescAddress(mlir::Location loc, TypePair boxTy, mlir::Value box,
                      mlir::ConversionPatternRewriter &rewriter) const;

  /// Load the attribute from \p box and test it against \p maskValue. The
  /// final comparison is implemented as `(attribute & maskValue) != 0`.
  mlir::Value genBoxAttributeCheck(mlir::Location loc, TypePair boxTy,
                                   mlir::Value box,
                                   mlir::ConversionPatternRewriter &rewriter,
                                   unsigned maskValue) const;

  /// Compute the size of the descriptor in bytes. For an assumed-rank box the
  /// rank is read from the box, so the result is not guaranteed to be a
  /// compile-time constant.
  mlir::Value computeBoxSize(mlir::Location, TypePair boxTy, mlir::Value box,
                             mlir::ConversionPatternRewriter &rewriter) const;

  /// Create an `mlir::LLVM::GEPOp` at \p loc from \p base, with \p ty as the
  /// GEP element type and \p args as its indices. The result type is always an
  /// LLVM pointer in address space 0.
  template <typename... ARGS>
  mlir::LLVM::GEPOp genGEP(mlir::Location loc, mlir::Type ty,
                           mlir::ConversionPatternRewriter &rewriter,
                           mlir::Value base, ARGS... args) const {
    mlir::MLIRContext *ctx = ty.getContext();
    auto resultPtrTy =
        mlir::LLVM::LLVMPointerType::get(ctx, /*addressSpace=*/0);
    llvm::SmallVector<mlir::LLVM::GEPArg> indices = {args...};
    return rewriter.create<mlir::LLVM::GEPOp>(loc, resultPtrTy, ty, base,
                                              indices);
  }

  /// Find the block in which an alloca should be inserted. The candidates are
  /// searched recursively in this order:
  ///   1. An OpenMP op that will be outlined.
  ///   2. An OpenMP or OpenACC op with one or more regions holding executable
  ///      code.
  ///   3. An LLVMFuncOp.
  ///   4. The first ancestor that is one of the above.
  mlir::Block *getBlockForAllocaInsert(mlir::Operation *op,
                                       mlir::Region *parentRegion) const;

  /// Generate an alloca of size 1 for an object of type \p llvmObjectTy, in
  /// the allocation address space that the DataLayout specification gives for
  /// the architecture. When that address space differs from the device's
  /// program address space, a cast is emitted. On most architectures both are
  /// the default address space 0 and no cast is emitted.
  mlir::Value
  genAllocaAndAddrCastWithType(mlir::Location loc, mlir::Type llvmObjectTy,
                               unsigned alignment,
                               mlir::ConversionPatternRewriter &rewriter) const;

  /// Return the pattern's type converter viewed as a fir::LLVMTypeConverter.
  /// The downcast is unchecked.
  const fir::LLVMTypeConverter &lowerTy() const {
    const auto *converter =
        static_cast<const fir::LLVMTypeConverter *>(this->getTypeConverter());
    return *converter;
  }

  /// Forward to fir::LLVMTypeConverter::attachTBAATag on the converter
  /// returned by lowerTy().
  void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
                     mlir::Type baseFIRType, mlir::Type accessFIRType,
                     mlir::LLVM::GEPOp gep) const {
    const fir::LLVMTypeConverter &converter = lowerTy();
    converter.attachTBAATag(op, baseFIRType, accessFIRType, gep);
  }

  /// Address space to use for allocas (defined out of line).
  unsigned
  getAllocaAddressSpace(mlir::ConversionPatternRewriter &rewriter) const;

  /// Program address space (defined out of line).
  unsigned
  getProgramAddressSpace(mlir::ConversionPatternRewriter &rewriter) const;

  /// Pass options handed to the constructor; held by reference, not copied.
  const fir::FIRToLLVMPassOptions &options;

  // Re-declare the inherited entry points in this class's protected section.
  using ConvertToLLVMPattern::match;
  using ConvertToLLVMPattern::matchAndRewrite;
};

/// Conversion pattern rooted at the single operation type \p SourceOp.
///
/// The overrides taking an `mlir::Operation *` are `final`: they cast the
/// operation to \p SourceOp, build the matching adaptor from the operands and
/// forward to the typed virtual methods below. A derived pattern overrides
/// either the typed matchAndRewrite, or both the typed match and rewrite.
template <typename SourceOp>
class FIROpConversion : public ConvertFIRToLLVMPattern {
public:
  using OpAdaptor = typename SourceOp::Adaptor;

  /// Root the pattern at `SourceOp::getOperationName()`, using the context of
  /// \p typeConverter.
  explicit FIROpConversion(const LLVMTypeConverter &typeConverter,
                           const fir::FIRToLLVMPassOptions &options,
                           mlir::PatternBenefit benefit = 1)
      : ConvertFIRToLLVMPattern(SourceOp::getOperationName(),
                                &typeConverter.getContext(), typeConverter,
                                options, benefit) {}

  /// Type-erased entry points: cast \p op to SourceOp and dispatch to the
  /// typed overload of the same name.
  void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
               mlir::ConversionPatternRewriter &rewriter) const final {
    SourceOp typedOp = mlir::cast<SourceOp>(op);
    rewrite(typedOp, OpAdaptor(operands, typedOp), rewriter);
  }
  llvm::LogicalResult match(mlir::Operation *op) const final {
    SourceOp typedOp = mlir::cast<SourceOp>(op);
    return match(typedOp);
  }
  llvm::LogicalResult
  matchAndRewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
                  mlir::ConversionPatternRewriter &rewriter) const final {
    SourceOp typedOp = mlir::cast<SourceOp>(op);
    return matchAndRewrite(typedOp, OpAdaptor(operands, typedOp), rewriter);
  }

  /// Typed hooks for the derived pattern. The default match and rewrite are
  /// unreachable; the default matchAndRewrite is "match, then rewrite", so a
  /// derived class must override matchAndRewrite or both of the others.
  virtual llvm::LogicalResult match(SourceOp op) const {
    llvm_unreachable("must override match or matchAndRewrite");
  }
  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
                       mlir::ConversionPatternRewriter &rewriter) const {
    llvm_unreachable("must override rewrite or matchAndRewrite");
  }
  virtual llvm::LogicalResult
  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const {
    const llvm::LogicalResult matched = match(op);
    if (mlir::failed(matched))
      return mlir::failure();
    // Only rewrite once the match has succeeded.
    rewrite(op, adaptor, rewriter);
    return mlir::success();
  }

private:
  // Name the inherited overloads here, in the private section.
  using ConvertFIRToLLVMPattern::matchAndRewrite;
  using ConvertToLLVMPattern::match;
};

/// FIR conversion pattern that also converts the type of the operation.
///
/// matchAndRewrite is `final` here: it converts `op.getType()` with
/// convertType() and passes the result to doRewrite(), which is the hook a
/// derived pattern implements.
template <typename FromOp>
class FIROpAndTypeConversion : public FIROpConversion<FromOp> {
public:
  using FIROpConversion<FromOp>::FIROpConversion;
  using OpAdaptor = typename FromOp::Adaptor;

  /// Convert the type of \p op and delegate to doRewrite(). The converted type
  /// is forwarded unchecked.
  llvm::LogicalResult
  matchAndRewrite(FromOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const final {
    return doRewrite(op, this->convertType(op.getType()), adaptor, rewriter);
  }

  /// Rewrite the operation given in the first argument; \p ty is the result of
  /// converting its type and \p adaptor gives access to the converted operands.
  virtual llvm::LogicalResult
  doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor,
            mlir::ConversionPatternRewriter &rewriter) const = 0;
};

} // namespace fir

#endif // FORTRAN_OPTIMIZER_CODEGEN_FIROPPATTERNS_H