llvm/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

//===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities used to lower to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"

#include <functional>
#include <optional>

#define DEBUG_TYPE

usingnamespacemlir;

namespace {

// NOTE(review): every function body in this listing is an empty stub (`{}`);
// the comments below describe intent as stated by the signatures and the
// pre-existing documentation, not by visible implementation.

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns a shape for `vecType`, or std::nullopt when there is none.
/// NOTE(review): presumably the native shape used when unrolling vectors —
/// confirm against the implementation.
static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {}

/// Checks that the extension requirements in `candidates` can be satisfied
/// with the given `targetEnv`. `label` identifies what is being checked
/// (presumably used only for diagnostics/debug output — confirm).
///
/// `candidates` is a vector of vectors of extension requirements following
/// the ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
/// convention.
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
    LabelT label, const spirv::TargetEnv &targetEnv,
    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {}

/// Checks that the capability requirements in `candidates` can be satisfied
/// with the given `targetEnv`. `label` identifies what is being checked
/// (presumably used only for diagnostics/debug output — confirm).
///
/// `candidates` is a vector of vectors of capability requirements following
/// the ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
/// convention.
template <typename LabelT>
static LogicalResult checkCapabilityRequirements(
    LabelT label, const spirv::TargetEnv &targetEnv,
    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {}

/// Returns true if the given `storageClass` needs explicit layout when used in
/// Shader environments.
static bool needsExplicitLayout(spirv::StorageClass storageClass) {}

/// Wraps the given `elementType` in a struct and gets the pointer to the
/// struct. This is used to satisfy Vulkan interface requirements.
static spirv::PointerType
wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {}

//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//

/// Returns the SPIR-V scalar type that `index` is lowered to under `options`.
/// NOTE(review): presumably selected by an index-bitwidth option — confirm.
static spirv::ScalarType getIndexType(MLIRContext *ctx,
                                      const SPIRVConversionOptions &options) {}

/// Returns the size of `type` in bytes under `options`, or std::nullopt when
/// the size cannot be determined.
// TODO: This is a utility function that should probably be exposed by the
// SPIR-V dialect. Keeping it local till the use case arises.
static std::optional<int64_t>
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {}

/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
/// `storageClass`, when provided, is the storage class the value will live in.
/// NOTE(review): the parameter list below is never closed and the function
/// has no body in this listing.
static Type
convertScalarType(const spirv::TargetEnv &targetEnv,
                  const SPIRVConversionOptions &options, spirv::ScalarType type,
                  std::optional<spirv::StorageClass> storageClass = {}

/// Converts a sub-byte integer `type` to i32 regardless of target environment.
///
/// Note that sub-byte types are not represented by `spirv::ScalarType`, so they
/// are handled here instead of by `convertScalarType`: SPIR-V does not support
/// them at all, and unlike other integer types there are no compute/storage
/// capabilities for them.
static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
                                      IntegerType type) {}

/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
static ShapedType
convertIndexElementType(ShapedType type,
                        const SPIRVConversionOptions &options) {}

/// Converts a vector `type` to a suitable type under the given `targetEnv`.
/// NOTE(review): the parameter list below is never closed and the function
/// has no body in this listing.
static Type
convertVectorType(const spirv::TargetEnv &targetEnv,
                  const SPIRVConversionOptions &options, VectorType type,
                  std::optional<spirv::StorageClass> storageClass = {}

/// Converts a complex `type` to a suitable type under the given `targetEnv`.
/// NOTE(review): the parameter list below is never closed and the function
/// has no body in this listing.
static Type
convertComplexType(const spirv::TargetEnv &targetEnv,
                   const SPIRVConversionOptions &options, ComplexType type,
                   std::optional<spirv::StorageClass> storageClass = {}

/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
///
/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
/// create composite constants with OpConstantComposite to embed relatively
/// large constant values and use OpCompositeExtract and OpCompositeInsert to
/// manipulate them, like what we do for vectors.
static Type convertTensorType(const spirv::TargetEnv &targetEnv,
                              const SPIRVConversionOptions &options,
                              TensorType type) {}

/// Converts a memref `type` whose elements are booleans, placed in
/// `storageClass`. NOTE(review): the chosen element representation is not
/// visible here — confirm from the implementation.
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
                                  const SPIRVConversionOptions &options,
                                  MemRefType type,
                                  spirv::StorageClass storageClass) {}

/// Converts a memref `type` whose elements are sub-byte integers, placed in
/// `storageClass`. NOTE(review): presumably related to
/// `convertSubByteIntegerType` above — confirm from the implementation.
static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
                                     const SPIRVConversionOptions &options,
                                     MemRefType type,
                                     spirv::StorageClass storageClass) {}

/// Converts a memref `type` to a suitable type under the given `targetEnv`.
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
                              const SPIRVConversionOptions &options,
                              MemRefType type) {}

//===----------------------------------------------------------------------===//
// Type casting materialization
//===----------------------------------------------------------------------===//

/// Converts the given `inputs` to the original source `type` considering the
/// `targetEnv`'s capabilities.
///
/// This function is meant to be used for source materialization in type
/// converters. When the type converter needs to materialize a cast op back
/// to some original source type, we need to check whether the original source
/// type is supported in the target environment. If so, we can insert legal
/// SPIR-V cast ops accordingly.
///
/// Note that in SPIR-V the capabilities for storage and compute are separate.
/// This function is meant to handle the **compute** side; so it does not
/// involve storage classes in its logic. The storage side is expected to be
/// handled by MemRef conversion logic.
static std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
                                             OpBuilder &builder, Type type,
                                             ValueRange inputs, Location loc) {}

//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//

/// Returns the global variable for `builtin` declared in `body`.
/// NOTE(review): presumably returns a null op when none exists (as
/// getPushConstantVariable does) — confirm.
static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
                                                  spirv::BuiltIn builtin) {}

/// Gets name of global variable for a builtin. `prefix` and `suffix` are
/// presumably prepended/appended to the builtin's name — confirm.
std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
                              StringRef suffix) {}

/// Gets or inserts a global variable for a builtin within `body` block.
static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
                           Type integerType, OpBuilder &builder,
                           StringRef prefix, StringRef suffix) {}

//===----------------------------------------------------------------------===//
// Push constant storage
//===----------------------------------------------------------------------===//

/// Returns the pointer type for the push constant storage containing
/// `elementCount` 32-bit integer values.
static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
                                                     Builder &builder,
                                                     Type indexType) {}

/// Returns the push constant variable containing `elementCount` 32-bit integer
/// values in `body`. Returns a null op if no such variable exists.
static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
                                                       unsigned elementCount) {}

/// Gets or inserts a global variable for push constant storage containing
/// `elementCount` 32-bit integer values in `block`.
static spirv::GlobalVariableOp
getOrInsertPushConstantVariable(Location loc, Block &block,
                                unsigned elementCount, OpBuilder &b,
                                Type indexType) {}

//===----------------------------------------------------------------------===//
// func::FuncOp Conversion Patterns
//===----------------------------------------------------------------------===//

/// A pattern for rewriting function signatures to convert arguments of
/// functions to be of valid SPIR-V types.
struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {};

/// A pattern for rewriting function signatures to convert vector arguments of
/// functions to be of valid types.
struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {};

//===----------------------------------------------------------------------===//
// func::ReturnOp Conversion Patterns
//===----------------------------------------------------------------------===//

/// A pattern for rewriting the function signature and the return op to
/// convert vectors to be of valid types.
struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {};

} // namespace

//===----------------------------------------------------------------------===//
// Public function for builtin variables
//===----------------------------------------------------------------------===//

/// Returns a Value for the SPIR-V `builtin`, materialized for use at `op`.
/// `integerType` types the result; `prefix`/`suffix` presumably decorate the
/// builtin variable's symbol name (see getBuiltinVarName).
/// NOTE(review): the body is an empty stub in this listing — confirm the
/// contract against the declaration in SPIRVConversion.h.
Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                           spirv::BuiltIn builtin,
                                           Type integerType, OpBuilder &builder,
                                           StringRef prefix, StringRef suffix) {}

//===----------------------------------------------------------------------===//
// Public function for push constant storage
//===----------------------------------------------------------------------===//

/// Returns the push constant value at `offset` within push constant storage
/// holding `elementCount` 32-bit integer values, typed as `integerType`.
/// NOTE(review): presumably built on getOrInsertPushConstantVariable — the
/// body is an empty stub in this listing; confirm against SPIRVConversion.h.
///
/// Qualified as `mlir::spirv::` to match every other public definition in
/// this file.
Value mlir::spirv::getPushConstantValue(Operation *op, unsigned elementCount,
                                        unsigned offset, Type integerType,
                                        OpBuilder &builder) {}

//===----------------------------------------------------------------------===//
// Public functions for index calculation
//===----------------------------------------------------------------------===//

/// Linearizes multi-dimensional `indices` into a single index using `strides`
/// and a constant `offset`, producing a value of `integerType` at `loc`.
/// NOTE(review): presumably offset + sum(indices[i] * strides[i]); the body is
/// an empty stub in this listing — confirm.
Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
                                  int64_t offset, Type integerType,
                                  Location loc, OpBuilder &builder) {}

/// Returns a pointer to the element of `basePtr` (a converted memref of type
/// `baseType`) addressed by `indices`, following Vulkan/shader addressing.
/// NOTE(review): body is an empty stub in this listing — confirm details.
Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
                                       MemRefType baseType, Value basePtr,
                                       ValueRange indices, Location loc,
                                       OpBuilder &builder) {}

/// Returns a pointer to the element of `basePtr` (a converted memref of type
/// `baseType`) addressed by `indices`, following OpenCL/kernel addressing.
/// NOTE(review): body is an empty stub in this listing — confirm details.
Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
                                       MemRefType baseType, Value basePtr,
                                       ValueRange indices, Location loc,
                                       OpBuilder &builder) {}

/// Returns a pointer to the element of `basePtr` addressed by `indices`.
/// NOTE(review): presumably dispatches to getVulkanElementPtr or
/// getOpenCLElementPtr depending on the target environment of
/// `typeConverter` — the body is an empty stub in this listing; confirm.
Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
                                 MemRefType baseType, Value basePtr,
                                 ValueRange indices, Location loc,
                                 OpBuilder &builder) {}

//===----------------------------------------------------------------------===//
// Public functions for vector unrolling
//===----------------------------------------------------------------------===//

/// Returns the vector size to use for computation given an original `size`.
/// NOTE(review): presumably a vector width legal in SPIR-V — the body is an
/// empty stub in this listing; confirm.
int mlir::spirv::getComputeVectorSize(int64_t size) {}

/// Returns the native vector shape used when unrolling the reduction `op`.
/// NOTE(review): body is an empty stub in this listing — confirm details.
SmallVector<int64_t>
mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {}

/// Returns the native vector shape used when unrolling the transpose `op`.
/// NOTE(review): body is an empty stub in this listing — confirm details.
SmallVector<int64_t>
mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {}

/// Returns the native vector shape for unrolling `op`, or std::nullopt when
/// `op` has none. NOTE(review): presumably dispatches to the
/// getNativeVectorShapeImpl overloads above — confirm.
std::optional<SmallVector<int64_t>>
mlir::spirv::getNativeVectorShape(Operation *op) {}

/// Unrolls vector types in function signatures nested under `op`; returns
/// failure if the rewrite fails. NOTE(review): presumably applies the
/// FuncOpVectorUnroll/ReturnOpVectorUnroll patterns — confirm.
LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {}

/// Unrolls vector operations inside function bodies nested under `op` to
/// native sizes; returns failure if the rewrite fails.
/// NOTE(review): body is an empty stub in this listing — confirm details.
LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {}

//===----------------------------------------------------------------------===//
// SPIR-V TypeConverter
//===----------------------------------------------------------------------===//

/// Constructs a type converter for the target described by `targetAttr`,
/// configured by `options`.
/// NOTE(review): the member-initializer list below is empty (`:` followed
/// directly by the body), which is not valid C++; the initializers and the
/// conversion registrations appear to be missing from this listing.
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
                                       const SPIRVConversionOptions &options)
    :{}

/// Returns the integer type that `index` is converted to by this converter.
/// NOTE(review): body is an empty stub in this listing — confirm.
Type SPIRVTypeConverter::getIndexType() const {}

/// Returns the MLIRContext this converter operates in.
MLIRContext *SPIRVTypeConverter::getContext() const {}

/// Returns true if `capability` is available in this converter's target
/// environment. NOTE(review): body is an empty stub in this listing — confirm.
bool SPIRVTypeConverter::allows(spirv::Capability capability) const {}

//===----------------------------------------------------------------------===//
// SPIR-V ConversionTarget
//===----------------------------------------------------------------------===//

/// Creates a heap-allocated conversion target for the environment described
/// by `targetAttr`; the caller owns the returned object.
std::unique_ptr<SPIRVConversionTarget>
SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {}

/// Constructs a conversion target for the environment described by
/// `targetAttr`.
/// NOTE(review): the member-initializer list below is empty (`:` followed
/// directly by the body), which is not valid C++; the initializers appear to
/// be missing from this listing.
SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
    :{}

/// Returns true if `op` is legal for this target.
/// NOTE(review): presumably checks `op`'s version/extension/capability
/// requirements against the target environment — confirm.
bool SPIRVConversionTarget::isLegalOp(Operation *op) {}

//===----------------------------------------------------------------------===//
// Public functions for populating patterns
//===----------------------------------------------------------------------===//

/// Adds the patterns that convert builtin func ops to SPIR-V into `patterns`,
/// using `typeConverter` for type conversion (presumably FuncOpConversion).
void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                              RewritePatternSet &patterns) {}

/// Adds the func-signature vector unrolling patterns (presumably
/// FuncOpVectorUnroll) into `patterns`.
void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {}

/// Adds the return-op vector unrolling patterns (presumably
/// ReturnOpVectorUnroll) into `patterns`.
void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {}