//===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Builders.h" namespace mlir { /// Rewrite pattern that replaces SourceOp with a CallOp to `f32Func`, `f64Func`, /// `f32ApproxFunc`, or `f16Func`, depending on the element type and the /// fastMathFlag of that Op. The function declaration is added if it has /// not been added before. /// /// If the input values are of bf16 type (or f16 type if f16Func is empty), the /// value is first cast to f32, the function is called, and then the result is /// cast back. /// /// Example with NVVM: /// %exp_f32 = math.exp %arg_f32 : f32 /// /// will be transformed into /// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32 /// /// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers /// to the approximate calculation function. /// /// Another example with NVVM: /// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32 /// /// will be transformed into /// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32 template <typename SourceOp> struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> { … }; } // namespace mlir #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_