llvm/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp

//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

namespace mlir {
// Defining GEN_PASS_DEF_CONVERTMATHTOFUNCS before including the generated
// file presumably makes it emit the impl::ConvertMathToFuncsBase base class
// that the pass in this file derives from -- confirm against Passes.h.inc.
#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

// NOTE(review): both macros are defined with empty replacement text here.
// LLVM's debug-output helpers conventionally expect DEBUG_TYPE to expand to a
// string literal naming this component -- confirm the intended values before
// relying on any debug output from this file.
#define DEBUG_TYPE
#define DBGS()

namespace {
// Rewrite pattern, parameterized over the operation type `Op`, that converts
// vector operations to scalar operations. Its matchAndRewrite hook is defined
// out of line further down in this file.
// NOTE(review): the struct body is empty in this view, so the in-class
// declaration of matchAndRewrite is not visible here -- confirm.
template <typename Op>
struct VecOpToScalarOp : public OpRewritePattern<Op> {};

// Callback type for getting pre-generated FuncOp implementing
// an operation of the given type.
// NOTE(review): as written this line is a bare identifier, not an alias
// declaration; the callable signature it is meant to name is not visible
// here -- confirm the intended `using` declaration.
GetFuncCallbackTy;

// Rewrite pattern for math::IPowIOp: converts a scalar IPowIOp into a call of
// an outlined software implementation. matchAndRewrite is defined out of line
// further down in this file.
class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {};

// Rewrite pattern for math::FPowIOp: converts a scalar FPowIOp into a call of
// an outlined software implementation. matchAndRewrite is defined out of line
// further down in this file.
class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {};

// Rewrite pattern for math::CountLeadingZerosOp: converts a scalar ctlz into
// a call of an outlined software implementation. matchAndRewrite is defined
// out of line further down in this file.
class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> {};
} // namespace

/// Rewrite hook of VecOpToScalarOp for an operation \p op of type `Op`,
/// using \p rewriter to perform the IR changes. The result type is
/// LogicalResult.
///
/// NOTE(review): the body is empty in this view, so no LogicalResult is
/// actually returned -- confirm the implementation.
template <typename Op>
LogicalResult
VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {}

/// Returns a FunctionType derived from \p op.
///
/// NOTE(review): judging by the name, this presumably builds the function
/// type from the element types of the operation's operands and results; the
/// body is empty in this view -- confirm.
static FunctionType getElementalFuncTypeForOp(Operation *op) {}

/// Create linkonce_odr function to implement the power function with
/// the given \p elementType type inside \p module. The \p elementType
/// must be IntegerType, and the created function has
/// 'IntegerType (*)(IntegerType, IntegerType)' function type.
///
/// The generated function is described by the following pseudo-code. A zero
/// exponent yields 1; a negative exponent is resolved by inspecting the base
/// (only 0, 1 and -1 give a non-zero result); a positive exponent uses
/// square-and-multiply, consuming one bit of the exponent per iteration.
///
/// template <typename T>
/// T __mlir_math_ipowi_*(T b, T p) {
///   if (p == T(0))
///     return T(1);
///   if (p < T(0)) {
///     if (b == T(0))
///       return T(1) / T(0); // trigger div-by-zero
///     if (b == T(1))
///       return T(1);
///     if (b == T(-1)) {
///       if (p & T(1))
///         return T(-1);
///       return T(1);
///     }
///     return T(0);
///   }
///   T result = T(1);
///   while (true) {
///     if (p & T(1))
///       result *= b;
///     p >>= T(1);
///     if (p == T(0))
///       return result;
///     b *= b;
///   }
/// }
///
/// NOTE(review): the body is empty in this view, so no func::FuncOp is
/// actually returned -- confirm the implementation.
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {}

/// Convert IPowI into a call to a local function implementing
/// the power operation. The local function computes a scalar result,
/// so vector forms of IPowI are linearized.
///
/// \p op is the math::IPowIOp being rewritten and \p rewriter performs the
/// IR changes; the result type is LogicalResult.
///
/// NOTE(review): the body is empty in this view, so no LogicalResult is
/// actually returned -- confirm the implementation.
LogicalResult
IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
                                 PatternRewriter &rewriter) const {}

/// Create linkonce_odr function to implement the power function with
/// the given \p funcType type inside \p module. The \p funcType must be
/// 'FloatType (*)(FloatType, IntegerType)' function type.
///
/// The generated function is described by the following pseudo-code, where
/// Tb is the floating-point base type and Tp is the integer power type.
/// Negating the minimum value of Tp is not representable, so that case runs
/// the loop with the maximum value instead and multiplies the result by the
/// original base once more afterwards (|min| == max + 1). Negative powers
/// are computed as the reciprocal of the positive power.
///
/// template <typename Tb, typename Tp>
/// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
///   if (p == Tp{0})
///     return Tb{1};
///   bool isNegativePower{p < Tp{0}};
///   bool isMin{p == std::numeric_limits<Tp>::min()};
///   if (isMin) {
///     p = std::numeric_limits<Tp>::max();
///   } else if (isNegativePower) {
///     p = -p;
///   }
///   Tb result = Tb{1};
///   Tb origBase = Tb{b};
///   while (true) {
///     if (p & Tp{1})
///       result *= b;
///     p >>= Tp{1};
///     if (p == Tp{0})
///       break;
///     b *= b;
///   }
///   if (isMin) {
///     result *= origBase;
///   }
///   if (isNegativePower) {
///     result = Tb{1} / result;
///   }
///   return result;
/// }
///
/// NOTE(review): the body is empty in this view, so no func::FuncOp is
/// actually returned -- confirm the implementation.
static func::FuncOp createElementFPowIFunc(ModuleOp *module,
                                           FunctionType funcType) {}

/// Convert FPowI into a call to a local function implementing
/// the power operation. The local function computes a scalar result,
/// so vector forms of FPowI are linearized.
///
/// \p op is the math::FPowIOp being rewritten and \p rewriter performs the
/// IR changes; the result type is LogicalResult.
///
/// NOTE(review): the body is empty in this view, so no LogicalResult is
/// actually returned -- confirm the implementation.
LogicalResult
FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
                                 PatternRewriter &rewriter) const {}

/// Create function to implement the ctlz function for the given
/// \p elementType type inside \p module. The \p elementType must be
/// IntegerType, and the created function has 'IntegerType (*)(IntegerType)'
/// function type.
///
/// The generated function is described by the following pseudo-code. A zero
/// input yields the bit width; otherwise the value is shifted left until its
/// sign bit is set (x < 0), counting the shifts performed.
///
/// template <typename T>
/// T __mlir_math_ctlz_*(T x) {
///     bits = sizeof(x) * 8;
///     if (x == 0)
///       return bits;
///
///     uint32_t n = 0;
///     for (int i = 1; i < bits; ++i) {
///         if (x < 0) continue;
///         n++;
///         x <<= 1;
///     }
///     return n;
/// }
///
/// Converts to (for i32):
///
/// NOTE(review): the listing below is schematic rather than verifier-exact
/// IR; for example %c_32 is declared as `index` (the loop bound) but is also
/// yielded as an `i32` result -- confirm against the builder code.
///
/// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 {
///   %c_32 = arith.constant 32 : index
///   %c_0 = arith.constant 0 : i32
///   %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1
///   %out = scf.if %arg_eq_zero {
///     scf.yield %c_32 : i32
///   } else {
///     %c_1index = arith.constant 1 : index
///     %c_1i32 = arith.constant 1 : i32
///     %n = arith.constant 0 : i32
///     %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index
///         iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) {
///       %cond = arith.cmpi slt, %arg_iter, %c_0 : i32
///       %yield_val = scf.if %cond {
///         scf.yield %arg_iter, %n_iter : i32, i32
///       } else {
///         %arg_next = arith.shli %arg_iter, %c_1i32 : i32
///         %n_next = arith.addi %n_iter, %c_1i32 : i32
///         scf.yield %arg_next, %n_next : i32, i32
///       }
///       scf.yield %yield_val: i32, i32
///     }
///     scf.yield %n_out : i32
///   }
///   return %out: i32
/// }
///
/// NOTE(review): the C++ body is empty in this view, so no func::FuncOp is
/// actually returned -- confirm the implementation.
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {}

/// Convert ctlz into a call to a local function implementing the ctlz
/// operation.
///
/// \p op is the math::CountLeadingZerosOp being rewritten and \p rewriter
/// performs the IR changes; the result type is LogicalResult.
///
/// NOTE(review): the body is empty in this view, so no LogicalResult is
/// actually returned -- confirm the implementation.
LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
                                              PatternRewriter &rewriter) const {}

namespace {
// Pass type for the Math-to-funcs conversion, derived (CRTP) from the
// impl::ConvertMathToFuncsBase base class. Its member functions
// isFPowIConvertible, generateOpImplementations and runOnOperation are
// defined out of line below.
// NOTE(review): the struct body is empty in this view, so the in-class
// declarations of those members are not visible here -- confirm.
struct ConvertMathToFuncsPass
    : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {};
} // namespace

/// Predicate over a math::FPowIOp \p op; judging by the name, it reports
/// whether this pass can convert the given operation.
///
/// NOTE(review): the body is empty in this view, so the criteria are not
/// visible and no value is actually returned -- confirm the implementation.
bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {}

/// Member of ConvertMathToFuncsPass taking no arguments and returning nothing.
///
/// NOTE(review): judging by the name, this presumably creates the outlined
/// implementation functions that the lowering patterns call; the body is
/// empty in this view -- confirm.
void ConvertMathToFuncsPass::generateOpImplementations() {}

/// Member of ConvertMathToFuncsPass taking no arguments and returning nothing.
///
/// NOTE(review): judging by the name, this is presumably the pass entry
/// point that applies the conversion; the body is empty in this view --
/// confirm.
void ConvertMathToFuncsPass::runOnOperation() {}