//===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements utilities used to lower to SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/OneToNTypeConversion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include <functional> #include <optional> #define DEBUG_TYPE … usingnamespacemlir; namespace { //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) { … } /// Checks that `candidates` extension requirements are possible to be satisfied /// with the given `targetEnv`. /// /// `candidates` is a vector of vector for extension requirements following /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) /// convention. template <typename LabelT> static LogicalResult checkExtensionRequirements( LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { … } /// Checks that `candidates`capability requirements are possible to be satisfied /// with the given `isAllowedFn`. /// /// `candidates` is a vector of vector for capability requirements following /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) /// convention. template <typename LabelT> static LogicalResult checkCapabilityRequirements( LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { … } /// Returns true if the given `storageClass` needs explicit layout when used in /// Shader environments. static bool needsExplicitLayout(spirv::StorageClass storageClass) { … } /// Wraps the given `elementType` in a struct and gets the pointer to the /// struct. This is used to satisfy Vulkan interface requirements. static spirv::PointerType wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) { … } //===----------------------------------------------------------------------===// // Type Conversion //===----------------------------------------------------------------------===// static spirv::ScalarType getIndexType(MLIRContext *ctx, const SPIRVConversionOptions &options) { … } // TODO: This is a utility function that should probably be exposed by the // SPIR-V dialect. Keeping it local till the use case arises. static std::optional<int64_t> getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { … } /// Converts a scalar `type` to a suitable type under the given `targetEnv`. static Type convertScalarType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, spirv::ScalarType type, std::optional<spirv::StorageClass> storageClass = { … } /// Converts a sub-byte integer `type` to i32 regardless of target environment. /// /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use /// the above given that these sub-byte types are not supported at all in /// SPIR-V; there are no compute/storage capability for them like other /// supported integer types. static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, IntegerType type) { … } /// Returns a type with the same shape but with any index element type converted /// to the matching integer type. This is a noop when the element type is not /// the index type. static ShapedType convertIndexElementType(ShapedType type, const SPIRVConversionOptions &options) { … } /// Converts a vector `type` to a suitable type under the given `targetEnv`. static Type convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, std::optional<spirv::StorageClass> storageClass = { … } static Type convertComplexType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, ComplexType type, std::optional<spirv::StorageClass> storageClass = { … } /// Converts a tensor `type` to a suitable type under the given `targetEnv`. /// /// Note that this is mainly for lowering constant tensors. In SPIR-V one can /// create composite constants with OpConstantComposite to embed relative large /// constant values and use OpCompositeExtract and OpCompositeInsert to /// manipulate, like what we do for vectors. static Type convertTensorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, TensorType type) { … } static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type, spirv::StorageClass storageClass) { … } static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type, spirv::StorageClass storageClass) { … } static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type) { … } //===----------------------------------------------------------------------===// // Type casting materialization //===----------------------------------------------------------------------===// /// Converts the given `inputs` to the original source `type` considering the /// `targetEnv`'s capabilities. /// /// This function is meant to be used for source materialization in type /// converters. When the type converter needs to materialize a cast op back /// to some original source type, we need to check whether the original source /// type is supported in the target environment. If so, we can insert legal /// SPIR-V cast ops accordingly. /// /// Note that in SPIR-V the capabilities for storage and compute are separate. /// This function is meant to handle the **compute** side; so it does not /// involve storage classes in its logic. The storage side is expected to be /// handled by MemRef conversion logic. static std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv, OpBuilder &builder, Type type, ValueRange inputs, Location loc) { … } //===----------------------------------------------------------------------===// // Builtin Variables //===----------------------------------------------------------------------===// static spirv::GlobalVariableOp getBuiltinVariable(Block &body, spirv::BuiltIn builtin) { … } /// Gets name of global variable for a builtin. std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix, StringRef suffix) { … } /// Gets or inserts a global variable for a builtin within `body` block. static spirv::GlobalVariableOp getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix, StringRef suffix) { … } //===----------------------------------------------------------------------===// // Push constant storage //===----------------------------------------------------------------------===// /// Returns the pointer type for the push constant storage containing /// `elementCount` 32-bit integer values. static spirv::PointerType getPushConstantStorageType(unsigned elementCount, Builder &builder, Type indexType) { … } /// Returns the push constant varible containing `elementCount` 32-bit integer /// values in `body`. Returns null op if such an op does not exit. static spirv::GlobalVariableOp getPushConstantVariable(Block &body, unsigned elementCount) { … } /// Gets or inserts a global variable for push constant storage containing /// `elementCount` 32-bit integer values in `block`. static spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc, Block &block, unsigned elementCount, OpBuilder &b, Type indexType) { … } //===----------------------------------------------------------------------===// // func::FuncOp Conversion Patterns //===----------------------------------------------------------------------===// /// A pattern for rewriting function signature to convert arguments of functions /// to be of valid SPIR-V types. struct FuncOpConversion final : OpConversionPattern<func::FuncOp> { … }; /// A pattern for rewriting function signature to convert vector arguments of /// functions to be of valid types struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> { … }; //===----------------------------------------------------------------------===// // func::ReturnOp Conversion Patterns //===----------------------------------------------------------------------===// /// A pattern for rewriting function signature and the return op to convert /// vectors to be of valid types. struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> { … }; } // namespace //===----------------------------------------------------------------------===// // Public function for builtin variables //===----------------------------------------------------------------------===// Value mlir::spirv::getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix, StringRef suffix) { … } //===----------------------------------------------------------------------===// // Public function for pushing constant storage //===----------------------------------------------------------------------===// Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder) { … } //===----------------------------------------------------------------------===// // Public functions for index calculation //===----------------------------------------------------------------------===// Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder) { … } Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder) { … } Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder) { … } Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder) { … } //===----------------------------------------------------------------------===// // Public functions for vector unrolling //===----------------------------------------------------------------------===// int mlir::spirv::getComputeVectorSize(int64_t size) { … } SmallVector<int64_t> mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) { … } SmallVector<int64_t> mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) { … } std::optional<SmallVector<int64_t>> mlir::spirv::getNativeVectorShape(Operation *op) { … } LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) { … } LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { … } //===----------------------------------------------------------------------===// // SPIR-V TypeConverter //===----------------------------------------------------------------------===// SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options) : … { … } Type SPIRVTypeConverter::getIndexType() const { … } MLIRContext *SPIRVTypeConverter::getContext() const { … } bool SPIRVTypeConverter::allows(spirv::Capability capability) const { … } //===----------------------------------------------------------------------===// // SPIR-V ConversionTarget //===----------------------------------------------------------------------===// std::unique_ptr<SPIRVConversionTarget> SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { … } SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr) : … { … } bool SPIRVConversionTarget::isLegalOp(Operation *op) { … } //===----------------------------------------------------------------------===// // Public functions for populating patterns //===----------------------------------------------------------------------===// void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { … } void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) { … } void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) { … }