//===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert SPIR-V dialect to LLVM dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #define DEBUG_TYPE … usingnamespacemlir; //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// /// Returns true if the given type is a signed integer or vector type. static bool isSignedIntegerOrVector(Type type) { … } /// Returns true if the given type is an unsigned integer or vector type static bool isUnsignedIntegerOrVector(Type type) { … } /// Returns the width of an integer or of the element type of an integer vector, /// if applicable. 
static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) { … }

/// Returns the bit width of integer, float or vector of float or integer
/// values.
static unsigned getBitWidth(Type type) { … }

/// Returns the bit width of LLVMType integer or vector.
static unsigned getLLVMTypeBitWidth(Type type) { … }

/// Creates an `IntegerAttr` with all bits set for the given type.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { … }

/// Creates `llvm.mlir.constant` with all bits set for the given type.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
                                      PatternRewriter &rewriter) { … }

/// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
static Value createFPConstant(Location loc, Type srcType, Type dstType,
                              PatternRewriter &rewriter, double value) { … }

/// Utility function for bitfield ops:
///   - `BitFieldInsert`
///   - `BitFieldSExtract`
///   - `BitFieldUExtract`
/// Truncates or extends the value. If the bitwidth of the value is the same as
/// `llvmType` bitwidth, the value remains unchanged.
static Value optionallyTruncateOrExtend(Location loc, Value value,
                                        Type llvmType,
                                        PatternRewriter &rewriter) { … }

/// Broadcasts the value to a vector with `numElements` elements.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
                       LLVMTypeConverter &typeConverter,
                       ConversionPatternRewriter &rewriter) { … }

/// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
static Value optionallyBroadcast(Location loc, Value value, Type srcType,
                                 LLVMTypeConverter &typeConverter,
                                 ConversionPatternRewriter &rewriter) { … }

/// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
/// `BitFieldUExtract`.
/// Broadcasts `Offset` and `Count` to match the type of `Base`. If `Base` is of
/// a vector type, constructs a vector that has:
///   - the same number of elements as `Base`
///   - each element of the same type as `Offset` or `Count`
///   - each element holding the same value as `Offset` or `Count`
/// Then casts `Offset` and `Count` if their bit width differs from the `Base`
/// bit width.
static Value processCountOrOffset(Location loc, Value value, Type srcType,
                                  Type dstType, LLVMTypeConverter &converter,
                                  ConversionPatternRewriter &rewriter) { … }

/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
/// offset to LLVM struct. Otherwise, the conversion is not supported.
static Type convertStructTypeWithOffset(spirv::StructType type,
                                        LLVMTypeConverter &converter) { … }

/// Converts SPIR-V struct with no offset to packed LLVM struct.
static Type convertStructTypePacked(spirv::StructType type,
                                    LLVMTypeConverter &converter) { … }

/// Creates an LLVM dialect constant with the given value.
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
                                 unsigned value) { … }

/// Utility for `spirv.Load` and `spirv.Store` conversion.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
                                            ConversionPatternRewriter &rewriter,
                                            LLVMTypeConverter &typeConverter,
                                            unsigned alignment, bool isVolatile,
                                            bool isNonTemporal) { … }

//===----------------------------------------------------------------------===//
// Type conversion
//===----------------------------------------------------------------------===//

/// Converts SPIR-V array type to LLVM array. Natural stride (according to
/// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
/// when converting ops that manipulate array types.
static std::optional<Type> convertArrayType(spirv::ArrayType type,
                                            TypeConverter &converter) { … }

/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
/// modeled at the moment.
static Type convertPointerType(spirv::PointerType type,
                               LLVMTypeConverter &converter,
                               spirv::ClientAPI clientAPI) { … }

/// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
/// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
/// no modeling of array stride at the moment.
static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
                                                   TypeConverter &converter) { … }

/// Converts SPIR-V struct to LLVM struct. There is no support for structs with
/// member decorations. Also, only natural offset is supported.
static Type convertStructType(spirv::StructType type,
                              LLVMTypeConverter &converter) { … }

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

namespace {

/// Converts `spirv.AccessChain` to the LLVM dialect.
class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> { … };

/// Converts `spirv.mlir.addressof` to the LLVM dialect.
class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> { … };

/// Converts `spirv.BitFieldInsert` to the LLVM dialect.
class BitFieldInsertPattern
    : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> { … };

/// Converts SPIR-V ConstantOp with scalar or vector type.
class ConstantScalarAndVectorPattern
    : public SPIRVToLLVMConversion<spirv::ConstantOp> { … };

/// Converts `spirv.BitFieldSExtract` to the LLVM dialect.
class BitFieldSExtractPattern
    : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> { … };

/// Converts `spirv.BitFieldUExtract` to the LLVM dialect.
class BitFieldUExtractPattern
    : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> { … };

/// Converts `spirv.Branch` to the LLVM dialect.
class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> { … };

/// Converts `spirv.BranchConditional` to the LLVM dialect.
class BranchConditionalConversionPattern
    : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> { … };

/// Converts `spirv.CompositeExtract` to `llvm.extractvalue` if the container
/// type is an aggregate type (struct or array). Otherwise, converts to
/// `llvm.extractelement` that operates on vectors.
class CompositeExtractPattern
    : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> { … };

/// Converts `spirv.CompositeInsert` to `llvm.insertvalue` if the container
/// type is an aggregate type (struct or array). Otherwise, converts to
/// `llvm.insertelement` that operates on vectors.
class CompositeInsertPattern
    : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> { … };

/// Converts SPIR-V operations that have straightforward LLVM equivalent
/// into LLVM dialect operations.
template <typename SPIRVOp, typename LLVMOp>
class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> { … };

/// Converts `spirv.ExecutionMode` into a global struct constant that holds
/// execution mode information.
class ExecutionModePattern
    : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> { … };

/// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
/// global returns a pointer, whereas in LLVM dialect the global holds an actual
/// value. This difference is handled by `spirv.mlir.addressof` and
/// `llvm.mlir.addressof` ops that both return a pointer.
class GlobalVariablePattern
    : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> { … };

/// Converts SPIR-V cast ops that do not have straightforward LLVM
/// equivalent in LLVM dialect.
template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> { … };

/// Converts `spirv.FunctionCall` to the LLVM dialect.
class FunctionCallPattern
    : public SPIRVToLLVMConversion<spirv::FunctionCallOp> { … };

/// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate".
template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { … };

/// Converts SPIR-V integer comparisons to llvm.icmp "predicate".
template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { … };

/// Converts `spirv.GL.InverseSqrt` to the LLVM dialect.
class InverseSqrtPattern
    : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> { … };

/// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
template <typename SPIRVOp>
class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> { … };

/// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
template <typename SPIRVOp>
class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> { … };

/// A template pattern that erases the given `SPIRVOp`.
template <typename SPIRVOp>
class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> { … };

/// Converts `spirv.Return` to the LLVM dialect.
class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> { … };

/// Converts `spirv.ReturnValue` to the LLVM dialect.
class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> { … };

/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within the loop
/// should be reachable for conversion to succeed. The structure of the loop in
/// LLVM dialect will be the following:
///
///           +------------------------------------+
///           | <code before spirv.mlir.loop>      |
///           | llvm.br ^header                    |
///           +------------------------------------+
///                                 |
///   +----------------+            |
///   |                |            |
///   |                V            V
///   |       +------------------------------------+
///   |       | ^header:                           |
///   |       |   <header code>                    |
///   |       |   llvm.cond_br %cond, ^body, ^exit |
///   |       +------------------------------------+
///   |                         |
///   |                         |---------------------+
///   |                         |                     |
///   |                         V                     |
///   |       +------------------------------------+  |
///   |       | ^body:                             |  |
///   |       |   <body code>                      |  |
///   |       |   llvm.br ^continue                |  |
///   |       +------------------------------------+  |
///   |                         |                     |
///   |                         V                     |
///   |       +------------------------------------+  |
///   |       | ^continue:                         |  |
///   |       |   <continue code>                  |  |
///   |       |   llvm.br ^header                  |  |
///   |       +------------------------------------+  |
///   |                         |                     |
///   +-------------------------+                     |
///                                                   |
///                             +---------------------+
///                             |
///                             V
///           +------------------------------------+
///           | ^exit:                             |
///           |   llvm.br ^remaining               |
///           +------------------------------------+
///                             |
///                             V
///           +------------------------------------+
///           | ^remaining:                        |
///           |   <code after spirv.mlir.loop>     |
///           +------------------------------------+
///
class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> { … };

/// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
/// block. All blocks within selection should be reachable for conversion to
/// succeed.
class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> { … };

/// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
/// puts a restriction on `Shift` and `Base` to have the same bit width,
/// `Shift` is zero or sign extended to match this specification. Cases when
/// `Shift` bit width > `Base` bit width are considered to be illegal.
template <typename SPIRVOp, typename LLVMOp>
class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> { … };

/// Converts `spirv.GL.Tan` to the LLVM dialect.
class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> { … };

/// Converts `spirv.GL.Tanh` to
///
///   exp(2x) - 1
///   -----------
///   exp(2x) + 1
///
class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> { … };

/// Converts `spirv.Variable` to the LLVM dialect.
class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> { … };

//===----------------------------------------------------------------------===//
// BitcastOp conversion
//===----------------------------------------------------------------------===//

/// Converts `spirv.Bitcast` to the LLVM dialect.
class BitcastConversionPattern
    : public SPIRVToLLVMConversion<spirv::BitcastOp> { … };

//===----------------------------------------------------------------------===//
// FuncOp conversion
//===----------------------------------------------------------------------===//

/// Converts `spirv.func` to the LLVM dialect.
class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> { … };

//===----------------------------------------------------------------------===//
// ModuleOp conversion
//===----------------------------------------------------------------------===//

/// Converts `spirv.module` to the LLVM dialect.
class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> { … };

//===----------------------------------------------------------------------===//
// VectorShuffleOp conversion
//===----------------------------------------------------------------------===//

/// Converts `spirv.VectorShuffle` to the LLVM dialect.
class VectorShufflePattern
    : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> { … };

} // namespace

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

/// Populates `typeConverter` with SPIR-V type conversions.
/// NOTE(review): presumably wires up the `convert*Type` helpers above, with
/// `clientAPI` forwarded to `convertPointerType` (the only helper taking one);
/// confirm.
void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
                                             spirv::ClientAPI clientAPI) { … }

/// Populates `patterns` with the SPIR-V op conversion patterns.
/// NOTE(review): function and module patterns presumably come from the two
/// dedicated populate functions below instead; confirm.
void mlir::populateSPIRVToLLVMConversionPatterns(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    spirv::ClientAPI clientAPI) { … }

/// Populates `patterns` with the `spirv.func` conversion pattern
/// (presumably `FuncConversionPattern`; confirm).
void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { … }

/// Populates `patterns` with the `spirv.module` conversion pattern
/// (presumably `ModuleConversionPattern`; confirm).
void mlir::populateSPIRVToLLVMModuleConversionPatterns(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { … }

//===----------------------------------------------------------------------===//
// Pre-conversion hooks
//===----------------------------------------------------------------------===//

/// Hook for descriptor set and binding number encoding.
static constexpr StringRef kBinding = …;
static constexpr StringRef kDescriptorSet = …;
void mlir::encodeBindAttribute(ModuleOp module) { … }