//===- PolynomialApproximation.cpp - Approximate math operations ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements expansion of math operations to fast approximations // that do not rely on any of the library functions. // //===----------------------------------------------------------------------===// #include <climits> #include <cmath> #include <cstddef> #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Approximation.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" usingnamespacemlir; usingnamespacemlir::math; usingnamespacemlir::vector; // Helper to encapsulate a vector's shape (including scalable dims). struct VectorShape { … }; // Returns vector shape if the type is a vector. Returns an empty shape if it is // not a vector. static VectorShape vectorShape(Type type) { … } static VectorShape vectorShape(Value value) { … } //----------------------------------------------------------------------------// // Broadcast scalar types and values into vector types and values. //----------------------------------------------------------------------------// // Broadcasts scalar type into vector type (iff shape is non-scalar). static Type broadcast(Type type, VectorShape shape) { … } // Broadcasts scalar value into vector (iff shape is non-scalar). static Value broadcast(ImplicitLocOpBuilder &builder, Value value, VectorShape shape) { … } //----------------------------------------------------------------------------// // Helper function to handle n-D vectors with 1-D operations. //----------------------------------------------------------------------------// // Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors // and calls the compute function with 1-D vector operands. Stitches back all // results into the original n-D vector result. // // Examples: vectorWidth = 8 // - vector<4x8xf32> unrolled 4 times // - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times // - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times // // Some math approximations rely on ISA-specific operations that only accept // fixed size 1-D vectors (e.g. AVX expects vectors of width 8). // // It is the caller's responsibility to verify that the inner dimension is // divisible by the vectorWidth, and that all operands have the same vector // shape. static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref<Value(ValueRange)> compute) { … } //----------------------------------------------------------------------------// // Helper functions to create constants. //----------------------------------------------------------------------------// static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType) { … } static Value f32Cst(ImplicitLocOpBuilder &builder, double value) { … } static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { … } static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { … } //----------------------------------------------------------------------------// // Helper functions to build math functions approximations. //----------------------------------------------------------------------------// // Return the minimum of the two values or NaN if value is NaN static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) { … } // Return the maximum of the two values or NaN if value is NaN static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) { … } // Return the clamped value or NaN if value is NaN static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound) { … } // Decomposes given floating point value `arg` into a normalized fraction and // an integral power of two (see std::frexp). Returned values have float type. static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive = false) { … } // Computes exp2 for an i32 argument. static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { … } namespace { Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, llvm::ArrayRef<Value> coeffs, Value x) { … } } // namespace //----------------------------------------------------------------------------// // Helper function/pattern to insert casts for reusing F32 bit expansion. //----------------------------------------------------------------------------// template <typename T> LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) { … } namespace { // Pattern to cast to F32 to reuse F32 expansion as fallback for single-result // op. // TODO: Consider revising to avoid adding multiple casts for a subgraph that is // all in lower precision. Currently this is only fallback support and performs // simplistic casting. template <typename T> struct ReuseF32Expansion : public OpRewritePattern<T> { … }; } // namespace //----------------------------------------------------------------------------// // AtanOp approximation. //----------------------------------------------------------------------------// namespace { struct AtanApproximation : public OpRewritePattern<math::AtanOp> { … }; } // namespace LogicalResult AtanApproximation::matchAndRewrite(math::AtanOp op, PatternRewriter &rewriter) const { … } //----------------------------------------------------------------------------// // AtanOp approximation. //----------------------------------------------------------------------------// namespace { struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> { … }; } // namespace LogicalResult Atan2Approximation::matchAndRewrite(math::Atan2Op op, PatternRewriter &rewriter) const { … } //----------------------------------------------------------------------------// // TanhOp approximation. //----------------------------------------------------------------------------// namespace { struct TanhApproximation : public OpRewritePattern<math::TanhOp> { … }; } // namespace LogicalResult TanhApproximation::matchAndRewrite(math::TanhOp op, PatternRewriter &rewriter) const { … } #define LN2_VALUE … #define LOG2E_VALUE … //----------------------------------------------------------------------------// // LogOp and Log2Op approximation. //----------------------------------------------------------------------------// namespace { template <typename Op> struct LogApproximationBase : public OpRewritePattern<Op> { … }; } // namespace // This approximation comes from Julien Pommier's SSE math library. // Link: http://gruntthepeon.free.fr/ssemath template <typename Op> LogicalResult LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter, bool base2) const { … } namespace { struct LogApproximation : public LogApproximationBase<math::LogOp> { … }; } // namespace namespace { struct Log2Approximation : public LogApproximationBase<math::Log2Op> { … }; } // namespace //----------------------------------------------------------------------------// // Log1p approximation. //----------------------------------------------------------------------------// namespace { struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> { … }; } // namespace // Approximate log(1+x). LogicalResult Log1pApproximation::matchAndRewrite(math::Log1pOp op, PatternRewriter &rewriter) const { … } //----------------------------------------------------------------------------// // Asin approximation. //----------------------------------------------------------------------------// // Approximates asin(x). // This approximation is based on the following stackoverflow post: // https://stackoverflow.com/a/42683455 namespace { struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> { … }; } // namespace LogicalResult AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op, PatternRewriter &rewriter) const { … } //----------------------------------------------------------------------------// // Acos approximation. //----------------------------------------------------------------------------// // Approximates acos(x). // This approximation is based on the following stackoverflow post: // https://stackoverflow.com/a/42683455 namespace { struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> { … }; } // namespace LogicalResult AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op, PatternRewriter &rewriter) const { … } //----------------------------------------------------------------------------// // Erf approximation. //----------------------------------------------------------------------------// // Approximates erf(x) with // a - P(x)/Q(x) // where P and Q are polynomials of degree 4. // Different coefficients are chosen based on the value of x. // The approximation error is ~2.5e-07. // Boost's minimax tool that utilizes the Remez method was used to find the // coefficients. LogicalResult ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const { … } //----------------------------------------------------------------------------// // Exp approximation. //----------------------------------------------------------------------------// namespace { Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape, Value value, float lowerBound, float upperBound) { … } struct ExpApproximation : public OpRewritePattern<math::ExpOp> { … }; LogicalResult ExpApproximation::matchAndRewrite(math::ExpOp op, PatternRewriter &rewriter) const { … } } // namespace //----------------------------------------------------------------------------// // ExpM1 approximation. //----------------------------------------------------------------------------// namespace { struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> { … }; } // namespace LogicalResult ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, PatternRewriter &rewriter) const { … } //----------------------------------------------------------------------------// // Sin and Cos approximation. //----------------------------------------------------------------------------// namespace { template <bool isSine, typename OpTy> struct SinAndCosApproximation : public OpRewritePattern<OpTy> { … }; } // namespace #define TWO_OVER_PI … #define PI_OVER_2 … // Approximates sin(x) or cos(x) by finding the best approximation polynomial in // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y). template <bool isSine, typename OpTy> LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite( OpTy op, PatternRewriter &rewriter) const { … } //----------------------------------------------------------------------------// // Cbrt approximation. //----------------------------------------------------------------------------// namespace { struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> { … }; } // namespace // Estimation of cube-root using an algorithm defined in // Hacker's Delight 2nd Edition. LogicalResult CbrtApproximation::matchAndRewrite(math::CbrtOp op, PatternRewriter &rewriter) const { … } //----------------------------------------------------------------------------// // Rsqrt approximation. //----------------------------------------------------------------------------// namespace { struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> { … }; } // namespace LogicalResult RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, PatternRewriter &rewriter) const { … } //----------------------------------------------------------------------------// void mlir::populatePolynomialApproximateTanhPattern( RewritePatternSet &patterns) { … } void mlir::populatePolynomialApproximateErfPattern( RewritePatternSet &patterns) { … } void mlir::populateMathPolynomialApproximationPatterns( RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options) { … }