llvm/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of math operations to fast approximations
// that do not rely on any of the library functions.
//
//===----------------------------------------------------------------------===//

#include <climits>
#include <cmath>
#include <cstddef>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Approximation.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

usingnamespacemlir;
usingnamespacemlir::math;
usingnamespacemlir::vector;

// Helper to encapsulate a vector's shape (including scalable dims).
struct VectorShape {};

// Returns vector shape if the type is a vector. Returns an empty shape if it is
// not a vector.
static VectorShape vectorShape(Type type) {}

static VectorShape vectorShape(Value value) {}

//----------------------------------------------------------------------------//
// Broadcast scalar types and values into vector types and values.
//----------------------------------------------------------------------------//

// Broadcasts scalar type into vector type (iff shape is non-scalar).
static Type broadcast(Type type, VectorShape shape) {}

// Broadcasts scalar value into vector (iff shape is non-scalar).
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
                       VectorShape shape) {}

//----------------------------------------------------------------------------//
// Helper function to handle n-D vectors with 1-D operations.
//----------------------------------------------------------------------------//

// Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
// and calls the compute function with 1-D vector operands. Stitches back all
// results into the original n-D vector result.
//
// Examples: vectorWidth = 8
//   - vector<4x8xf32> unrolled 4 times
//   - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
//   - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
//
// Some math approximations rely on ISA-specific operations that only accept
// fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
//
// It is the caller's responsibility to verify that the inner dimension is
// divisible by the vectorWidth, and that all operands have the same vector
// shape.
// NOTE(review): the body is empty in this view. The return type is `Value`, so
// reaching the closing brace without a `return` is undefined behavior; the
// unroll-and-stitch implementation described in the comment above must be
// restored before any caller relies on this helper.
static Value
handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
                              ValueRange operands, int64_t vectorWidth,
                              llvm::function_ref<Value(ValueRange)> compute) {}

//----------------------------------------------------------------------------//
// Helper functions to create constants.
//----------------------------------------------------------------------------//

// Creates a scalar floating-point `arith.constant` of type `elementType`.
// `value` is passed as `float`, so the constant carries at most f32 precision
// even when `elementType` is a wider float type.
static Value floatCst(ImplicitLocOpBuilder &builder, float value,
                      Type elementType) {
  return builder.create<arith::ConstantOp>(
      builder.getFloatAttr(elementType, value));
}

// Creates a scalar f32 `arith.constant`. The `double` argument is explicitly
// narrowed to `float`.
static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
  return builder.create<arith::ConstantOp>(
      builder.getF32FloatAttr(static_cast<float>(value)));
}

// Creates a scalar i32 `arith.constant`.
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
  return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
}

// Creates a scalar f32 value whose bit pattern is exactly `bits`. It does this
// by materializing an i32 constant and bitcasting it.
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
  Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
  return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value);
}

//----------------------------------------------------------------------------//
// Helper functions to build math functions approximations.
//----------------------------------------------------------------------------//

// Return the minimum of the two values or NaN if value is NaN.
// `ULT` is the unordered less-than predicate. It is true whenever an operand is
// NaN, so a NaN `value` is selected (propagated) instead of `bound`.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
  Value takeValue =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound);
  return builder.create<arith::SelectOp>(takeValue, value, bound);
}

// Return the maximum of the two values or NaN if value is NaN.
// `UGT` is the unordered greater-than predicate. It is true whenever an operand
// is NaN, so a NaN `value` is selected (propagated) instead of `bound`.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
  Value takeValue =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound);
  return builder.create<arith::SelectOp>(takeValue, value, bound);
}

// Return the clamped value or NaN if value is NaN.
// Applies `min` against the upper bound first, then `max` against the lower
// bound. Both helpers propagate a NaN `value`.
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
                   Value upperBound) {
  return max(builder, min(builder, value, upperBound), lowerBound);
}

// Decomposes given floating point value `arg` into a normalized fraction and
// an integral power of two (see std::frexp). Returned values have float type.
//
// `isPositive`: presumably lets callers that already know `arg` is
// non-negative skip sign handling. TODO confirm against the implementation.
// NOTE(review): the body is empty in this view. The function returns
// `std::pair<Value, Value>`, so falling off the closing brace is undefined
// behavior.
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                     bool isPositive = false) {}

// Computes exp2 for an i32 argument.
// NOTE(review): the body is empty in this view. The return type is `Value`, so
// falling off the closing brace is undefined behavior.
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {}

namespace {
// Builds IR that evaluates a polynomial in `x` whose coefficients are
// `coeffs`.
// NOTE(review): whether `coeffs` is ordered lowest- or highest-degree first is
// not visible here. Confirm before adding callers.
// NOTE(review): unlike the neighbouring `static` helpers, this one gets
// internal linkage from the anonymous namespace. The body is empty in this
// view; falling off the end of a `Value`-returning function is undefined
// behavior.
Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
                                llvm::ArrayRef<Value> coeffs, Value x) {}
} // namespace

//----------------------------------------------------------------------------//
// Helper function/pattern to insert casts for reusing F32 bit expansion.
//----------------------------------------------------------------------------//

// Per the section header, rewrites `op` (an op of type `T`) so that it can
// reuse the F32 expansion by inserting casts around it.
// NOTE(review): the body is empty in this view. The function returns
// `LogicalResult`, so falling off the closing brace is undefined behavior.
// Unlike the file's other helpers, this template is neither `static` nor in an
// anonymous namespace.
template <typename T>
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {}

namespace {
// Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
// op.
// TODO: Consider revising to avoid adding multiple casts for a subgraph that is
// all in lower precision. Currently this is only fallback support and performs
// simplistic casting.
template <typename T>
struct ReuseF32Expansion : public OpRewritePattern<T> {
public:
  // Inherit the (MLIRContext *, benefit, ...) constructors so the pattern can
  // be registered with `RewritePatternSet::add`.
  using OpRewritePattern<T>::OpRewritePattern;

  // All of the work is done by `insertCasts<T>`.
  LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
    return insertCasts<T>(op, rewriter);
  }
};
} // namespace

//----------------------------------------------------------------------------//
// AtanOp approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrite pattern for `math.atan`. `matchAndRewrite` is defined out of line.
struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AtanOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Rewrites a `math.atan` op into its approximation.
// NOTE(review): the body is empty in this view. `matchAndRewrite` returns
// `LogicalResult`, so falling off the closing brace is undefined behavior.
// Restore the implementation (or `return failure();`) before use.
LogicalResult
AtanApproximation::matchAndRewrite(math::AtanOp op,
                                   PatternRewriter &rewriter) const {}

//----------------------------------------------------------------------------//
// Atan2Op approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrite pattern for `math.atan2`. `matchAndRewrite` is defined out of line.
struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Atan2Op op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Rewrites a `math.atan2` op into its approximation.
// NOTE(review): the body is empty in this view. `matchAndRewrite` returns
// `LogicalResult`, so falling off the closing brace is undefined behavior.
// Restore the implementation (or `return failure();`) before use.
LogicalResult
Atan2Approximation::matchAndRewrite(math::Atan2Op op,
                                    PatternRewriter &rewriter) const {}

//----------------------------------------------------------------------------//
// TanhOp approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrite pattern for `math.tanh`. `matchAndRewrite` is defined out of line.
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Rewrites a `math.tanh` op into its approximation.
// NOTE(review): the body is empty in this view. `matchAndRewrite` returns
// `LogicalResult`, so falling off the closing brace is undefined behavior.
// Restore the implementation (or `return failure();`) before use.
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {}

// ln(2) and log2(e) = 1/ln(2), written with more digits than a double holds so
// the literal rounds to the nearest representable value. An empty expansion
// would turn any use such as `f32Cst(builder, LN2_VALUE)` into a syntax error.
#define LN2_VALUE 0.693147180559945309417232121458176568
#define LOG2E_VALUE 1.442695040888963407359924681001892137

//----------------------------------------------------------------------------//
// LogOp and Log2Op approximation.
//----------------------------------------------------------------------------//

namespace {
template <typename Op>
struct LogApproximationBase : public OpRewritePattern<Op> {};
} // namespace

// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                             bool base2) const {}

namespace {
struct LogApproximation : public LogApproximationBase<math::LogOp> {};
} // namespace

namespace {
struct Log2Approximation : public LogApproximationBase<math::Log2Op> {};
} // namespace

//----------------------------------------------------------------------------//
// Log1p approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrite pattern for `math.log1p`. `matchAndRewrite` is defined out of line.
struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Log1pOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Approximate log(1+x).
// NOTE(review): the body is empty in this view. `matchAndRewrite` returns
// `LogicalResult`, so falling off the closing brace is undefined behavior.
// Restore the implementation (or `return failure();`) before use.
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                    PatternRewriter &rewriter) const {}

//----------------------------------------------------------------------------//
// Asin approximation.
//----------------------------------------------------------------------------//

// Approximates asin(x).
// This approximation is based on the following stackoverflow post:
// https://stackoverflow.com/a/42683455
namespace {
// Rewrite pattern for `math.asin`. `matchAndRewrite` is defined out of line.
struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AsinOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
// Rewrites a `math.asin` op into its polynomial approximation.
// NOTE(review): the body is empty in this view. `matchAndRewrite` returns
// `LogicalResult`, so falling off the closing brace is undefined behavior.
// Restore the implementation (or `return failure();`) before use.
LogicalResult
AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
                                             PatternRewriter &rewriter) const {}

//----------------------------------------------------------------------------//
// Acos approximation.
//----------------------------------------------------------------------------//

// Approximates acos(x).
// This approximation is based on the following stackoverflow post:
// https://stackoverflow.com/a/42683455
namespace {
// Rewrite pattern for `math.acos`. `matchAndRewrite` is defined out of line.
struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AcosOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
// Rewrites a `math.acos` op into its polynomial approximation.
// NOTE(review): the body is empty in this view. `matchAndRewrite` returns
// `LogicalResult`, so falling off the closing brace is undefined behavior.
// Restore the implementation (or `return failure();`) before use.
LogicalResult
AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
                                             PatternRewriter &rewriter) const {}

//----------------------------------------------------------------------------//
// Erf approximation.
//----------------------------------------------------------------------------//

// Approximates erf(x) with
//   a - P(x)/Q(x)
// where P and Q are polynomials of degree 4.
// Different coefficients are chosen based on the value of x.
// The approximation error is ~2.5e-07.
// Boost's minimax tool, which uses the Remez method, was used to find the
// coefficients.
//
// NOTE(review): `ErfPolynomialApproximation` is not declared in this view.
// Presumably it is declared in an included header (e.g. Approximation.h);
// confirm. The body is empty here, and falling off the end of a
// `LogicalResult`-returning function is undefined behavior.
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                            PatternRewriter &rewriter) const {}

//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//

namespace {

// Presumably clamps `value` (of vector shape `shape`) into
// [lowerBound, upperBound]. TODO confirm what "with normals" refers to.
// NOTE(review): the body is empty in this view. Falling off the end of a
// `Value`-returning function is undefined behavior.
Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
                       Value value, float lowerBound, float upperBound) {}

// Rewrite pattern for `math.exp`. `matchAndRewrite` is defined out of line
// below, inside this anonymous namespace.
struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpOp op,
                                PatternRewriter &rewriter) const final;
};

// NOTE(review): the body is empty in this view. Falling off the end of a
// `LogicalResult`-returning function is undefined behavior. Restore the
// implementation (or `return failure();`) before use.
LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
                                  PatternRewriter &rewriter) const {}

} // namespace

//----------------------------------------------------------------------------//
// ExpM1 approximation.
//----------------------------------------------------------------------------//

namespace {

// Rewrite pattern for `math.expm1`. `matchAndRewrite` is defined out of line.
struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpM1Op op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Rewrites a `math.expm1` op into its approximation.
// NOTE(review): the body is empty in this view. `matchAndRewrite` returns
// `LogicalResult`, so falling off the closing brace is undefined behavior.
// Restore the implementation (or `return failure();`) before use.
LogicalResult
ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
                                    PatternRewriter &rewriter) const {}

//----------------------------------------------------------------------------//
// Sin and Cos approximation.
//----------------------------------------------------------------------------//

namespace {

// Shared rewrite pattern for sine (`isSine == true`) and cosine
// (`isSine == false`) ops of type `OpTy`. `matchAndRewrite` is defined out of
// line.
template <bool isSine, typename OpTy>
struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
};
} // namespace

// 2/pi and pi/2, written with more digits than a double holds so each literal
// rounds to the nearest representable value. An empty expansion would turn any
// use of these macros into a syntax error.
#define TWO_OVER_PI 0.6366197723675813430755350934589170
#define PI_OVER_2 1.57079632679489661923132169163975144

// Approximates sin(x) or cos(x) by finding the best approximation polynomial in
// the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the
// reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y).
//
// NOTE(review): the body is empty in this view. `matchAndRewrite` returns
// `LogicalResult`, so falling off the closing brace is undefined behavior.
// Restore the implementation (or `return failure();`) before use.
template <bool isSine, typename OpTy>
LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
    OpTy op, PatternRewriter &rewriter) const {}

//----------------------------------------------------------------------------//
// Cbrt approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrite pattern for `math.cbrt`. `matchAndRewrite` is defined out of line.
struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::CbrtOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Estimation of cube-root using an algorithm defined in
// Hacker's Delight 2nd Edition.
//
// NOTE(review): the body is empty in this view. `matchAndRewrite` returns
// `LogicalResult`, so falling off the closing brace is undefined behavior.
// Restore the implementation (or `return failure();`) before use.
LogicalResult
CbrtApproximation::matchAndRewrite(math::CbrtOp op,
                                   PatternRewriter &rewriter) const {}

//----------------------------------------------------------------------------//
// Rsqrt approximation.
//----------------------------------------------------------------------------//

namespace {
// Rewrite pattern for `math.rsqrt`. `matchAndRewrite` is defined out of line.
struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::RsqrtOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace

// Rewrites a `math.rsqrt` op into its approximation.
// NOTE(review): the body is empty in this view. `matchAndRewrite` returns
// `LogicalResult`, so falling off the closing brace is undefined behavior.
// Restore the implementation (or `return failure();`) before use.
LogicalResult
RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
                                    PatternRewriter &rewriter) const {}

//----------------------------------------------------------------------------//

// Per its name, intended to add the tanh approximation pattern to `patterns`.
// NOTE(review): the body is empty in this view, so no pattern is added.
// Presumably it should be
// `patterns.add<TanhApproximation>(patterns.getContext());`. Confirm.
void mlir::populatePolynomialApproximateTanhPattern(
    RewritePatternSet &patterns) {}

// Per its name, intended to add the erf approximation pattern to `patterns`.
// NOTE(review): the body is empty in this view, so no pattern is added.
// Presumably it should add `ErfPolynomialApproximation`. Confirm.
void mlir::populatePolynomialApproximateErfPattern(
    RewritePatternSet &patterns) {}

// Per its name, intended to add this file's polynomial-approximation patterns
// to `patterns`, configured by `options`.
// NOTE(review): the body is empty in this view. No patterns are added and
// `options` is unused.
void mlir::populateMathPolynomialApproximationPatterns(
    RewritePatternSet &patterns,
    const MathPolynomialApproximationOptions &options) {}