llvm/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h

//===- SPIRVConversion.h - SPIR-V Conversion Utilities ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines utilities to use while converting to the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H
#define MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H

#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/LogicalResult.h"

namespace mlir {

//===----------------------------------------------------------------------===//
// Type Converter
//===----------------------------------------------------------------------===//

/// How sub-byte values are stored in memory.
enum class SPIRVSubByteTypeStorage {};

/// Options for conversion to the SPIR-V dialect.
///
/// NOTE(review): description inferred from the type name; no fields are
/// declared in the definition shown here — confirm against the full
/// definition and document each option where it is declared.
struct SPIRVConversionOptions {};

/// Type conversion from builtin types to SPIR-V types for the shader
/// interface.
///
/// For memref types, this converter additionally performs type wrapping to
/// satisfy shader interface requirements: shader interface types must be
/// pointers to structs.
class SPIRVTypeConverter : public TypeConverter {};

//===----------------------------------------------------------------------===//
// Conversion Target
//===----------------------------------------------------------------------===//

/// The default SPIR-V conversion target.
///
/// It takes a SPIR-V target environment and controls operation legality based
/// on each operation's availability in that target environment.
class SPIRVConversionTarget : public ConversionTarget {};

//===----------------------------------------------------------------------===//
// Patterns and Utility Functions
//===----------------------------------------------------------------------===//

/// Appends to `patterns` additional patterns for translating the builtin
/// `func` op to the SPIR-V dialect, using `typeConverter` for type conversion.
/// These patterns do not handle the shader interface/ABI; they convert
/// function parameters to SPIR-V allowed types.
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                        RewritePatternSet &patterns);

/// Appends to `patterns` vector rewrite patterns that apply to function ops.
///
/// NOTE(review): description inferred from the function name; confirm the
/// exact rewrites against the implementation.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);

/// Appends to `patterns` vector rewrite patterns that apply to return ops.
///
/// NOTE(review): description inferred from the function name; confirm the
/// exact rewrites against the implementation.
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns);

namespace spirv {
// Forward declaration of the SPIR-V access chain op.
// NOTE(review): no declaration in this header refers to it, and
// "mlir/Dialect/SPIRV/IR/SPIRVOps.h" is already included, so this may be
// redundant.
class AccessChainOp;

/// Returns the value for the given `builtin` variable. This function gets or
/// inserts the global variable associated with the builtin within the nearest
/// symbol table enclosing `op`. Returns a null Value on error.
///
/// The generated global name is mangled using `prefix` (default
/// "__builtin__") and `suffix` (default "__").
///
/// NOTE(review): `integerType` presumably selects the integer type used for
/// the builtin's value, and `builder` is used to create the IR — confirm
/// against the implementation.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType,
                              OpBuilder &builder,
                              StringRef prefix = "__builtin__",
                              StringRef suffix = "__");

/// Gets the value at the given `offset` of the push constant storage with a
/// total of `elementCount` `integerType` integers. A global variable for the
/// push constant storage is created in the nearest symbol table enclosing `op`
/// if it does not already exist. Load ops are created via the given `builder`
/// to load values from the push constant. Returns a null Value on error.
Value getPushConstantValue(Operation *op, unsigned elementCount,
                           unsigned offset, Type integerType,
                           OpBuilder &builder);

/// Generates IR to perform index linearization with the given `indices` and
/// their corresponding `strides`, adding an initial `offset`.
///
/// NOTE(review): presumably `indices` and `strides` must have equal length,
/// the result has type `integerType`, and ops are created at `loc` via
/// `builder` — confirm against the implementation.
Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
                     int64_t offset, Type integerType, Location loc,
                     OpBuilder &builder);

/// Performs the index computation to get to the element at `indices` of the
/// memory pointed to by `basePtr`, using the layout map of `baseType`.
/// Returns null if the index computation cannot be performed.
///
/// TODO: This method assumes that the `baseType` is a MemRefType with an
/// AffineMap that has static strides. Extend to handle dynamic strides.
Value getElementPtr(const SPIRVTypeConverter &typeConverter,
                    MemRefType baseType, Value basePtr, ValueRange indices,
                    Location loc, OpBuilder &builder);

/// GetElementPtr implementation for Kernel/OpenCL-flavored SPIR-V.
///
/// Takes the same parameters as `getElementPtr`.
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
                          MemRefType baseType, Value basePtr,
                          ValueRange indices, Location loc, OpBuilder &builder);

/// GetElementPtr implementation for Vulkan/Shader-flavored SPIR-V.
///
/// Takes the same parameters as `getElementPtr`.
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
                          MemRefType baseType, Value basePtr,
                          ValueRange indices, Location loc, OpBuilder &builder);

/// Finds the largest factor of `size` among {2, 3, 4}, for the lowest
/// dimension of the target shape.
///
/// NOTE(review): the return value when none of 2, 3, 4 divides `size` is not
/// specified here — confirm against the implementation.
int getComputeVectorSize(int64_t size);

/// GetNativeVectorShape implementation for `vector::ReductionOp`.
SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);

/// GetNativeVectorShape implementation for `vector::TransposeOp`.
SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op);

/// Returns the native vector shape for a general `op`.
///
/// NOTE(review): presumably returns std::nullopt when no native shape applies
/// to `op` — confirm against the implementation.
std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op);

/// Unrolls vectors in function signatures to native size.
///
/// NOTE(review): presumably operates on functions nested under `op` and
/// returns failure if the rewrite does not succeed — confirm.
LogicalResult unrollVectorsInSignatures(Operation *op);

/// Unrolls vectors in function bodies to native size.
///
/// NOTE(review): presumably operates on functions nested under `op` and
/// returns failure if the rewrite does not succeed — confirm.
LogicalResult unrollVectorsInFuncBodies(Operation *op);

} // namespace spirv
} // namespace mlir

#endif // MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H