//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" #include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include <cassert> #include <memory> namespace mlir { #define GEN_PASS_DEF_CONVERTARITHTOSPIRV #include "mlir/Conversion/Passes.h.inc" } // namespace mlir #define DEBUG_TYPE … usingnamespacemlir; //===----------------------------------------------------------------------===// // Conversion Helpers //===----------------------------------------------------------------------===// /// Converts the given `srcAttr` into a boolean attribute if it holds an /// integral value. Returns null attribute if conversion fails. static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { … } /// Converts the given `srcAttr` to a new attribute of the given `dstType`. /// Returns null attribute if conversion fails. static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder) { … } /// Converts the given `srcAttr` to a new attribute of the given `dstType`. /// Returns null attribute if `dstType` is not 32-bit or conversion fails. 
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder) { … } /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { … } /// Creates a scalar/vector integer constant. static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc) { … } /// Returns true if scalar/vector types `a` and `b` have the same /// bitwidth. static bool hasSameBitwidth(Type a, Type b) { … } /// Returns a source type conversion failure for `srcType` and operation `op`. static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType) { … } /// Returns a source type conversion failure for the result type of `op`. static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) { … } // TODO: Move to some common place? static std::string getDecorationString(spirv::Decoration decor) { … } namespace { /// Converts elementwise unary, binary and ternary arith operations to SPIR-V /// operations. Op can potentially support overflow flags. template <typename Op, typename SPIRVOp> struct ElementwiseArithOpPattern final : OpConversionPattern<Op> { … }; //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// /// Converts composite arith.constant operation to spirv.Constant. struct ConstantCompositeOpPattern final : public OpConversionPattern<arith::ConstantOp> { … }; /// Converts scalar arith.constant operation to spirv.Constant.
struct ConstantScalarOpPattern final : public OpConversionPattern<arith::ConstantOp> { … }; //===----------------------------------------------------------------------===// // RemSIOp //===----------------------------------------------------------------------===// /// Returns signed remainder for `lhs` and `rhs` and lets the result follow /// the sign of `signOperand`. /// /// Note that this is needed for Vulkan. Per Vulkan's SPIR-V environment /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative /// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod /// if either operand can be negative. Emulate it via spirv.UMod. template <typename SignedAbsOp> static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Value signOperand, OpBuilder &builder) { … } /// Converts arith.remsi to GLSL SPIR-V ops. /// /// This cannot be merged into the template unary/binary pattern due to Vulkan /// restrictions on spirv.SRem and spirv.SMod. struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> { … }; /// Converts arith.remsi to OpenCL SPIR-V ops. struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> { … }; //===----------------------------------------------------------------------===// // BitwiseOp //===----------------------------------------------------------------------===// /// Converts bitwise operations to SPIR-V operations. This is a special pattern /// separate from ElementwiseArithOpPattern because if the operands are boolean /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> struct BitwiseOpPattern final : public OpConversionPattern<Op> { … }; //===----------------------------------------------------------------------===// // XOrIOp //===----------------------------------------------------------------------===// /// Converts arith.xori to SPIR-V operations. struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { … }; /// Converts arith.xori to SPIR-V operations if the type of source is i1 or /// vector of i1. struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { … }; //===----------------------------------------------------------------------===// // UIToFPOp //===----------------------------------------------------------------------===// /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector /// of i1. struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { … }; //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector /// of i1. struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> { … }; /// Converts arith.extsi for cases where the type of source is neither i1 nor /// vector of i1. struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> { … }; //===----------------------------------------------------------------------===// // ExtUIOp //===----------------------------------------------------------------------===// /// Converts arith.extui to spirv.Select if the type of source is i1 or vector /// of i1. struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { … }; /// Converts arith.extui for cases where the type of source is neither i1 nor /// vector of i1.
struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> { … }; //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector /// of i1. struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { … }; /// Converts arith.trunci for cases where the type of result is neither i1 /// nor vector of i1. struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> { … }; //===----------------------------------------------------------------------===// // TypeCastingOp //===----------------------------------------------------------------------===// static std::optional<spirv::FPRoundingMode> convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) { … } /// Converts type-casting standard operations to SPIR-V operations. template <typename Op, typename SPIRVOp> struct TypeCastingOpPattern final : public OpConversionPattern<Op> { … }; //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// /// Converts integer compare operation on i1 type operands to SPIR-V ops. class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { … }; /// Converts integer compare operation to SPIR-V ops. class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { … }; //===----------------------------------------------------------------------===// // CmpFOpPattern //===----------------------------------------------------------------------===// /// Converts floating-point comparison operations to SPIR-V ops. class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { … }; /// Converts floating point NaN check to SPIR-V ops. This pattern requires /// Kernel capability. 
class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { … }; /// Converts floating point NaN check to SPIR-V ops. This pattern does not /// require additional capability. class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { … }; //===----------------------------------------------------------------------===// // AddUIExtendedOp //===----------------------------------------------------------------------===// /// Converts arith.addui_extended to spirv.IAddCarry. class AddUIExtendedOpPattern final : public OpConversionPattern<arith::AddUIExtendedOp> { … }; //===----------------------------------------------------------------------===// // MulIExtendedOp //===----------------------------------------------------------------------===// /// Converts arith.mul*i_extended to spirv.*MulExtended. template <typename ArithMulOp, typename SPIRVMulOp> class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> { … }; //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// /// Converts arith.select to spirv.Select. class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> { … }; //===----------------------------------------------------------------------===// // MinimumFOp, MaximumFOp //===----------------------------------------------------------------------===// /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or /// spirv.CL.fmax/fmin. template <typename Op, typename SPIRVOp> class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> { … }; //===----------------------------------------------------------------------===// // MinNumFOp, MaxNumFOp //===----------------------------------------------------------------------===// /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or /// spirv.CL.fmax/fmin. 
template <typename Op, typename SPIRVOp> class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> { … }; } // namespace //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// void mlir::arith::populateArithToSPIRVPatterns( SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { … } //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// namespace { struct ConvertArithToSPIRVPass : public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> { … }; } // namespace std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() { … }