// llvm/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

//===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert SPIR-V dialect to LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE

usingnamespacemlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns true if `type` is a signed integer type or a vector type.
/// NOTE(review): presumably the vector case requires signed integer elements —
/// confirm against the implementation.
static bool isSignedIntegerOrVector(Type type) {}

/// Returns true if `type` is an unsigned integer type or a vector type.
/// NOTE(review): presumably the vector case requires unsigned integer
/// elements — confirm against the implementation.
static bool isUnsignedIntegerOrVector(Type type) {}

/// Returns the bit width of `type` when it is an integer, or the bit width of
/// its element type when it is a vector of integers. The result is optional
/// because the query is not applicable to every type (presumably
/// `std::nullopt` is returned in that case — TODO confirm).
static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {}

/// Returns the bit width of `type`, where `type` is an integer, a float, or a
/// vector of integer or float values.
/// NOTE(review): for vectors this is presumably the element bit width rather
/// than the total vector size — confirm.
static unsigned getBitWidth(Type type) {}

/// Returns the bit width of `type`, which is expected to be an LLVM dialect
/// integer type or a vector of such integers.
static unsigned getLLVMTypeBitWidth(Type type) {}

/// Creates an `IntegerAttr` for `type` with all bits set (i.e. the value -1,
/// as the function name indicates). `builder` is the attribute builder
/// supplied by the caller.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {}

/// Creates an `llvm.mlir.constant` with all bits set and returns its value.
///   - `loc`: location for the created op.
///   - `srcType` / `dstType`: presumably the original SPIR-V type and its
///     converted LLVM type — TODO confirm which one drives the attribute.
///   - `rewriter`: rewriter used to create the op.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
                                      PatternRewriter &rewriter) {}

/// Creates an `llvm.mlir.constant` holding a floating-point scalar or vector
/// and returns its value.
///   - `loc`: location for the created op.
///   - `srcType` / `dstType`: presumably the original SPIR-V type and its
///     converted LLVM type — TODO confirm.
///   - `rewriter`: rewriter used to create the op.
///   - `value`: the floating-point value of the constant (presumably splatted
///     across all elements in the vector case — TODO confirm).
static Value createFPConstant(Location loc, Type srcType, Type dstType,
                              PatternRewriter &rewriter, double value) {}

/// Utility function for the bitfield ops:
///   - `BitFieldInsert`
///   - `BitFieldSExtract`
///   - `BitFieldUExtract`
/// Truncates or extends `value` so that its bit width matches that of
/// `llvmType`. If the bit widths are already equal, `value` is returned
/// unchanged.
/// NOTE(review): the kind of extension (zero vs. sign) is not visible from the
/// signature — confirm before relying on it.
static Value optionallyTruncateOrExtend(Location loc, Value value,
                                        Type llvmType,
                                        PatternRewriter &rewriter) {}

/// Broadcasts `toBroadcast` to a vector with `numElements` elements and
/// returns the resulting vector value.
///   - `loc`: location for any created ops.
///   - `typeConverter`: type converter available to the helper (presumably
///     used to obtain the LLVM vector/index types — TODO confirm).
///   - `rewriter`: rewriter used to create ops.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
                       const TypeConverter &typeConverter,
                       ConversionPatternRewriter &rewriter) {}

/// Broadcasts `value` according to `srcType`. If `srcType` is a scalar type,
/// `value` is returned unchanged.
/// NOTE(review): for vector `srcType` the element count presumably comes from
/// `srcType` and the work is delegated to `broadcast` — confirm.
static Value optionallyBroadcast(Location loc, Value value, Type srcType,
                                 const TypeConverter &typeConverter,
                                 ConversionPatternRewriter &rewriter) {}

/// Utility function for the bitfield ops `BitFieldInsert`, `BitFieldSExtract`
/// and `BitFieldUExtract`.
///
/// Prepares an `Offset` or `Count` operand (`value`) so that it matches the
/// type of `Base`. If `Base` is of a vector type, a vector is constructed that
/// has:
///  - the same number of elements as `Base`,
///  - elements of the same type as `Offset` or `Count`,
///  - every element equal to the value of `Offset` or `Count`.
/// Afterwards the result is cast if its bit width differs from the bit width
/// of `Base`.
///
/// NOTE(review): `srcType` / `dstType` presumably describe `Base` before and
/// after type conversion — confirm against the callers.
static Value processCountOrOffset(Location loc, Value value, Type srcType,
                                  Type dstType, const TypeConverter &converter,
                                  ConversionPatternRewriter &rewriter) {}

/// Converts a SPIR-V struct whose member offsets are regular (according to
/// `VulkanLayoutUtils`) to an LLVM struct. Structs with any other offsets are
/// not supported by this conversion.
/// NOTE(review): the value returned for the unsupported case is not visible
/// here (presumably a null `Type`) — confirm.
static Type convertStructTypeWithOffset(spirv::StructType type,
                                        const TypeConverter &converter) {}

/// Converts a SPIR-V struct that carries no offset information to a packed
/// LLVM struct. `converter` is the type converter made available for the
/// member types.
static Type convertStructTypePacked(spirv::StructType type,
                                    const TypeConverter &converter) {}

/// Creates an LLVM dialect constant holding `value` and returns it. As the
/// name indicates, the constant is expected to be of type i32.
///   - `loc`: location for the created op.
///   - `rewriter`: rewriter used to create the op.
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
                                 unsigned value) {}

/// Shared utility for the `spirv.Load` and `spirv.Store` conversions.
///   - `op`: the SPIR-V operation being replaced.
///   - `operands`: the operands supplied for `op` by the conversion.
///   - `rewriter` / `typeConverter`: conversion infrastructure handed in by the
///     calling pattern.
///   - `alignment`, `isVolatile`, `isNonTemporal`: memory-access properties,
///     presumably forwarded to the created LLVM load/store — TODO confirm.
/// Returns a `LogicalResult` reporting whether the replacement succeeded.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
                                            ConversionPatternRewriter &rewriter,
                                            const TypeConverter &typeConverter,
                                            unsigned alignment, bool isVolatile,
                                            bool isNonTemporal) {}

//===----------------------------------------------------------------------===//
// Type conversion
//===----------------------------------------------------------------------===//

/// Converts a SPIR-V array type to an LLVM array type. An array with natural
/// stride (according to `VulkanLayoutUtils`) is also mapped to an LLVM array;
/// ops that manipulate array types have to respect this when being converted.
/// The result is optional (presumably `std::nullopt` signals an unsupported
/// stride — TODO confirm).
static std::optional<Type> convertArrayType(spirv::ArrayType type,
                                            TypeConverter &converter) {}

/// Converts a SPIR-V pointer type to an LLVM pointer type. The pointer's
/// storage class is not modelled at the moment.
/// NOTE(review): the `clientAPI` parameter suggests some client-API-dependent
/// handling (e.g. of the storage class); confirm whether the remark above is
/// still accurate.
static Type convertPointerType(spirv::PointerType type,
                               const TypeConverter &converter,
                               spirv::ClientAPI clientAPI) {}

/// Converts a SPIR-V runtime array to an LLVM array. Since LLVM allows
/// indexing past the bounds of an array, the runtime array is converted to a
/// 0-sized LLVM array. Array stride is not modelled at the moment.
static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
                                                   TypeConverter &converter) {}

/// Converts a SPIR-V struct to an LLVM struct. Structs with member decorations
/// are not supported, and only natural offsets are supported.
/// NOTE(review): presumably dispatches to `convertStructTypeWithOffset` or
/// `convertStructTypePacked` depending on whether offsets are present —
/// confirm.
static Type convertStructType(spirv::StructType type,
                              const TypeConverter &converter) {}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

namespace {

/// SPIR-V to LLVM conversion pattern for `spirv::AccessChainOp`.
class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {};

/// SPIR-V to LLVM conversion pattern for `spirv::AddressOfOp`.
class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {};

/// SPIR-V to LLVM conversion pattern for `spirv::BitFieldInsertOp`.
class BitFieldInsertPattern
    : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {};

/// Converts `spirv::ConstantOp` whose result is of scalar or vector type.
class ConstantScalarAndVectorPattern
    : public SPIRVToLLVMConversion<spirv::ConstantOp> {};

/// SPIR-V to LLVM conversion pattern for `spirv::BitFieldSExtractOp`.
class BitFieldSExtractPattern
    : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {};

/// SPIR-V to LLVM conversion pattern for `spirv::BitFieldUExtractOp`.
class BitFieldUExtractPattern
    : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {};

/// SPIR-V to LLVM conversion pattern for `spirv::BranchOp`.
class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {};

/// SPIR-V to LLVM conversion pattern for `spirv::BranchConditionalOp`.
class BranchConditionalConversionPattern
    : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {};

/// Converts `spirv.CompositeExtract` to `llvm.extractvalue` if the container
/// type is an aggregate type (struct or array). Otherwise, converts to
/// `llvm.extractelement`, which operates on vectors.
class CompositeExtractPattern
    : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {};

/// Converts `spirv.CompositeInsert` to `llvm.insertvalue` if the container
/// type is an aggregate type (struct or array). Otherwise, converts to
/// `llvm.insertelement`, which operates on vectors.
class CompositeInsertPattern
    : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {};

/// Converts SPIR-V operations that have a straightforward LLVM equivalent
/// into LLVM dialect operations.
///   - `SPIRVOp`: the SPIR-V op class being converted.
///   - `LLVMOp`: the LLVM dialect op class it maps to.
template <typename SPIRVOp, typename LLVMOp>
class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {};

/// Converts `spirv.ExecutionMode` into a global struct constant that holds the
/// execution mode information.
class ExecutionModePattern
    : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {};

/// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that a SPIR-V
/// global returns a pointer, whereas in the LLVM dialect the global holds an
/// actual value. This difference is handled by the `spirv.mlir.addressof` and
/// `llvm.mlir.addressof` ops, which both return a pointer.
class GlobalVariablePattern
    : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {};

/// Converts SPIR-V cast ops that do not have a straightforward equivalent in
/// the LLVM dialect.
///   - `SPIRVOp`: the SPIR-V cast op class being converted.
///   - `LLVMExtOp` / `LLVMTruncOp`: the LLVM extension and truncation op
///     classes to choose from (presumably selected by comparing source and
///     destination bit widths — TODO confirm).
template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {};

/// SPIR-V to LLVM conversion pattern for `spirv::FunctionCallOp`.
class FunctionCallPattern
    : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {};

/// Converts a SPIR-V floating-point comparison `SPIRVOp` to `llvm.fcmp` with
/// the compile-time `predicate`.
template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {};

/// Converts a SPIR-V integer comparison `SPIRVOp` to `llvm.icmp` with the
/// compile-time `predicate`.
template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {};

/// SPIR-V to LLVM conversion pattern for `spirv::GLInverseSqrtOp`.
class InverseSqrtPattern
    : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {};

/// Converts `spirv.Load` and `spirv.Store` to the LLVM dialect. `SPIRVOp`
/// selects which of the two ops an instantiation handles.
template <typename SPIRVOp>
class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {};

/// Converts `spirv.Not` and `spirv.LogicalNot` into the LLVM dialect.
/// `SPIRVOp` selects which of the two ops an instantiation handles.
template <typename SPIRVOp>
class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {};

/// A template pattern that erases the given `SPIRVOp` without producing a
/// replacement.
template <typename SPIRVOp>
class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {};

/// SPIR-V to LLVM conversion pattern for `spirv::ReturnOp`.
class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {};

/// SPIR-V to LLVM conversion pattern for `spirv::ReturnValueOp`.
class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {};

/// Converts `spirv.mlir.loop` to the LLVM dialect. All blocks within the loop
/// should be reachable for the conversion to succeed. The structure of the
/// loop in the LLVM dialect will be the following:
///
///      +------------------------------------+
///      | <code before spirv.mlir.loop>      |
///      | llvm.br ^header                    |
///      +------------------------------------+
///                           |
///   +----------------+      |
///   |                |      |
///   |                V      V
///   |  +------------------------------------+
///   |  | ^header:                           |
///   |  |   <header code>                    |
///   |  |   llvm.cond_br %cond, ^body, ^exit |
///   |  +------------------------------------+
///   |                    |
///   |                    |----------------------+
///   |                    |                      |
///   |                    V                      |
///   |  +------------------------------------+   |
///   |  | ^body:                             |   |
///   |  |   <body code>                      |   |
///   |  |   llvm.br ^continue                |   |
///   |  +------------------------------------+   |
///   |                    |                      |
///   |                    V                      |
///   |  +------------------------------------+   |
///   |  | ^continue:                         |   |
///   |  |   <continue code>                  |   |
///   |  |   llvm.br ^header                  |   |
///   |  +------------------------------------+   |
///   |               |                           |
///   +---------------+    +----------------------+
///                        |
///                        V
///      +------------------------------------+
///      | ^exit:                             |
///      |   llvm.br ^remaining               |
///      +------------------------------------+
///                        |
///                        V
///      +------------------------------------+
///      | ^remaining:                        |
///      |   <code after spirv.mlir.loop>     |
///      +------------------------------------+
///
class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {};

/// Converts a `spirv.mlir.selection` that has a `spirv.BranchConditional` in
/// its header block. All blocks within the selection should be reachable for
/// the conversion to succeed.
class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {};

/// Converts SPIR-V shift ops to LLVM shift ops. Since the LLVM dialect
/// requires `Shift` and `Base` to have the same bit width, `Shift` is zero or
/// sign extended to match this requirement. Cases where the bit width of
/// `Shift` is greater than the bit width of `Base` are considered illegal.
///   - `SPIRVOp`: the SPIR-V shift op class being converted.
///   - `LLVMOp`: the LLVM dialect shift op class it maps to.
template <typename SPIRVOp, typename LLVMOp>
class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {};

/// SPIR-V to LLVM conversion pattern for `spirv::GLTanOp`.
class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {};

/// Converts the SPIR-V tanh op (`spirv::GLTanhOp`) to the equivalent
/// expression in terms of `exp`:
///
///   exp(2x) - 1
///   -----------
///   exp(2x) + 1
///
class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {};

/// SPIR-V to LLVM conversion pattern for `spirv::VariableOp`.
class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {};

//===----------------------------------------------------------------------===//
// BitcastOp conversion
//===----------------------------------------------------------------------===//

/// SPIR-V to LLVM conversion pattern for `spirv::BitcastOp`.
class BitcastConversionPattern
    : public SPIRVToLLVMConversion<spirv::BitcastOp> {};

//===----------------------------------------------------------------------===//
// FuncOp conversion
//===----------------------------------------------------------------------===//

/// SPIR-V to LLVM conversion pattern for `spirv::FuncOp`.
class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {};

//===----------------------------------------------------------------------===//
// ModuleOp conversion
//===----------------------------------------------------------------------===//

/// SPIR-V to LLVM conversion pattern for `spirv::ModuleOp`.
class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {};

//===----------------------------------------------------------------------===//
// VectorShuffleOp conversion
//===----------------------------------------------------------------------===//

/// SPIR-V to LLVM conversion pattern for `spirv::VectorShuffleOp`.
class VectorShufflePattern
    : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {};
} // namespace

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

/// Entry point for registering the SPIR-V to LLVM type conversions on
/// `typeConverter` (per the function name).
/// NOTE(review): `clientAPI` is presumably forwarded to the pointer type
/// conversion, since `convertPointerType` takes a `spirv::ClientAPI` —
/// confirm.
void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
                                             spirv::ClientAPI clientAPI) {}

/// Entry point for adding the SPIR-V to LLVM operation conversion patterns to
/// `patterns` (per the function name).
///   - `typeConverter`: the LLVM type converter the patterns are built with.
///   - `clientAPI`: the SPIR-V client API the conversion targets (how it
///     affects individual patterns is not visible here — TODO confirm).
void mlir::populateSPIRVToLLVMConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    spirv::ClientAPI clientAPI) {}

/// Entry point for adding the SPIR-V function conversion pattern(s) to
/// `patterns` (per the function name; presumably `FuncConversionPattern` —
/// TODO confirm).
void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {}

/// Entry point for adding the SPIR-V module conversion pattern(s) to
/// `patterns` (per the function name; presumably `ModuleConversionPattern` —
/// TODO confirm).
void mlir::populateSPIRVToLLVMModuleConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {}

//===----------------------------------------------------------------------===//
// Pre-conversion hooks
//===----------------------------------------------------------------------===//

/// Hook for descriptor set and binding number encoding.
///
/// NOTE(review): `kBinding` and `kDescriptorSet` have no initializer here, so
/// these declarations do not compile as written. The attribute-name string
/// literals appear to have been dropped; restore them from the authoritative
/// version of this file rather than guessing.
static constexpr StringRef kBinding =;
static constexpr StringRef kDescriptorSet =;
// Pre-conversion hook operating on `module`; presumably uses the two attribute
// names above to look up descriptor set / binding information — TODO confirm.
void mlir::encodeBindAttribute(ModuleOp module) {}