//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOFUNCS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir usingnamespacemlir; #define DEBUG_TYPE … #define DBGS() … namespace { // Pattern to convert vector operations to scalar operations. template <typename Op> struct VecOpToScalarOp : public OpRewritePattern<Op> { … }; // Callback type for getting pre-generated FuncOp implementing // an operation of the given type. GetFuncCallbackTy; // Pattern to convert scalar IPowIOp into a call of outlined // software implementation. class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> { … }; // Pattern to convert scalar FPowIOp into a call of outlined // software implementation. class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> { … }; // Pattern to convert scalar ctlz into a call of outlined software // implementation. 
class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> { … }; } // namespace template <typename Op> LogicalResult VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { … } static FunctionType getElementalFuncTypeForOp(Operation *op) { … } /// Create linkonce_odr function to implement the power function with /// the given \p elementType type inside \p module. The \p elementType /// must be IntegerType, and the created function has /// 'IntegerType (*)(IntegerType, IntegerType)' function type. /// /// template <typename T> /// T __mlir_math_ipowi_*(T b, T p) { /// if (p == T(0)) /// return T(1); /// if (p < T(0)) { /// if (b == T(0)) /// return T(1) / T(0); // trigger div-by-zero /// if (b == T(1)) /// return T(1); /// if (b == T(-1)) { /// if (p & T(1)) /// return T(-1); /// return T(1); /// } /// return T(0); /// } /// T result = T(1); /// while (true) { /// if (p & T(1)) /// result *= b; /// p >>= T(1); /// if (p == T(0)) /// return result; /// b *= b; /// } /// } static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { … } /// Convert IPowI into a call to a local function implementing /// the power operation. The local function computes a scalar result, /// so vector forms of IPowI are linearized. LogicalResult IPowIOpLowering::matchAndRewrite(math::IPowIOp op, PatternRewriter &rewriter) const { … } /// Create linkonce_odr function to implement the power function with /// the given \p funcType type inside \p module. The \p funcType must be /// 'FloatType (*)(FloatType, IntegerType)' function type.
/// /// template <typename Tb, typename Tp> /// Tb __mlir_math_fpowi_*(Tb b, Tp p) { /// if (p == Tp{0}) /// return Tb{1}; /// bool isNegativePower{p < Tp{0}}; /// bool isMin{p == std::numeric_limits<Tp>::min()}; /// if (isMin) { /// p = std::numeric_limits<Tp>::max(); /// } else if (isNegativePower) { /// p = -p; /// } /// Tb result = Tb{1}; /// Tb origBase = Tb{b}; /// while (true) { /// if (p & Tp{1}) /// result *= b; /// p >>= Tp{1}; /// if (p == Tp{0}) /// break; /// b *= b; /// } /// if (isMin) { /// result *= origBase; /// } /// if (isNegativePower) { /// result = Tb{1} / result; /// } /// return result; /// } static func::FuncOp createElementFPowIFunc(ModuleOp *module, FunctionType funcType) { … } /// Convert FPowI into a call to a local function implementing /// the power operation. The local function computes a scalar result, /// so vector forms of FPowI are linearized. LogicalResult FPowIOpLowering::matchAndRewrite(math::FPowIOp op, PatternRewriter &rewriter) const { … } /// Create function to implement the ctlz function for the given \p elementType type /// inside \p module. The \p elementType must be IntegerType, and the created /// function has 'IntegerType (*)(IntegerType)' function type.
/// /// template <typename T> /// T __mlir_math_ctlz_*(T x) { /// const T bits = sizeof(x) * 8; /// if (x == 0) /// return bits; /// /// T n = 0; /// for (int i = 1; i < bits; ++i) { /// if (x < 0) continue; // top bit reached: stop counting /// n++; /// x <<= 1; /// } /// return n; /// } /// /// Converts to (for i32): /// /// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 { /// %c_32 = arith.constant 32 : i32 /// %c_0 = arith.constant 0 : i32 /// %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i32 /// %out = scf.if %arg_eq_zero -> (i32) { /// scf.yield %c_32 : i32 /// } else { /// %c_1index = arith.constant 1 : index /// %c_32index = arith.constant 32 : index /// %c_1i32 = arith.constant 1 : i32 /// %n = arith.constant 0 : i32 /// %arg_out, %n_out = scf.for %i = %c_1index to %c_32index step %c_1index /// iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) { /// %cond = arith.cmpi slt, %arg_iter, %c_0 : i32 /// %yield_val:2 = scf.if %cond -> (i32, i32) { /// scf.yield %arg_iter, %n_iter : i32, i32 /// } else { /// %arg_next = arith.shli %arg_iter, %c_1i32 : i32 /// %n_next = arith.addi %n_iter, %c_1i32 : i32 /// scf.yield %arg_next, %n_next : i32, i32 /// } /// scf.yield %yield_val#0, %yield_val#1 : i32, i32 /// } /// scf.yield %n_out : i32 /// } /// return %out : i32 /// } static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { … } /// Convert ctlz into a call to a local function implementing the ctlz /// operation. LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op, PatternRewriter &rewriter) const { … } namespace { struct ConvertMathToFuncsPass : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> { … }; } // namespace bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) { … } void ConvertMathToFuncsPass::generateOpImplementations() { … } void ConvertMathToFuncsPass::runOnOperation() { … }