//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert Math dialect to SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #define DEBUG_TYPE … usingnamespacemlir; //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the /// given type is not a 32-bit scalar/vector type. static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc) { … } /// Check if the type is supported by math-to-spirv conversion. We expect to /// only see scalars and vectors at this point, with higher-level types already /// lowered. static bool isSupportedSourceType(Type originalType) { … } /// Check if all `sourceOp` types are supported by math-to-spirv conversion. /// Notify of a match failure otherwise and return a `failure` result. /// This is intended to simplify type checks in `OpConversionPattern`s. 
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp) { … } //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// // Note that DRR cannot be used for the patterns in this file: we may need to // convert types along the way, which requires ConversionPattern. DRR generates // plain RewritePatterns. namespace { /// Converts elementwise unary, binary, and ternary standard operations to /// SPIR-V operations. Checks that source `Op` types are supported. template <typename Op, typename SPIRVOp> struct CheckedElementwiseOpPattern final : public spirv::ElementwiseOpPattern<Op, SPIRVOp> { … }; /// Converts math.copysign to SPIR-V ops. struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> { … }; /// Converts math.ctlz to SPIR-V ops. /// /// SPIR-V does not have a direct operation for counting leading zeros. If /// Shader capability is supported, we can leverage GL FindUMsb to calculate /// it. struct CountLeadingZerosPattern final : public OpConversionPattern<math::CountLeadingZerosOp> { … }; /// Converts math.expm1 to SPIR-V ops. /// /// SPIR-V does not have a direct operation for exp(x)-1. Explicitly lower to /// the equivalent exp(x) - 1 expression. template <typename ExpOp> struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> { … }; /// Converts math.log1p to SPIR-V ops. /// /// SPIR-V does not have a direct operation for log(1+x). Explicitly lower to /// the equivalent log(1 + x) expression. template <typename LogOp> struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { … }; /// Converts math.log2 and math.log10 to SPIR-V ops. /// /// SPIR-V does not have direct operations for log2 and log10. 
Explicitly /// lower to these operations using: /// log2(x) = log(x) * 1/log(2) /// log10(x) = log(x) * 1/log(10) template <typename MathLogOp, typename SpirvLogOp> struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> { … }; /// Converts math.powf to SPIR-V ops. struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { … }; /// Converts math.round to GLSL SPIR-V extended ops. struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// namespace mlir { void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { … } } // namespace mlir