llvm/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE

usingnamespacemlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array whose elements each have `sourceBits`
/// bits. It is assumed to be non-negative.
///
/// When the array is accessed as if its elements had `targetBits` bits,
/// multiple source values are loaded at the same time. This function returns
/// the offset at which the element addressed by `srcIdx` sits inside the loaded
/// value. For example, if `sourceBits` equals 8 and `targetBits` equals 32, the
/// x-th element is located at (x % 4) * 8, because one i32 holds four elements
/// and each element has 8 bits.
///
/// `loc` and `builder` are presumably the location and builder used for any
/// ops created while computing the offset -- TODO confirm against the body.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns Value; confirm the implementation was not
// dropped.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {}

/// Returns an adjusted spirv::AccessChainOp.
///
/// Depending on the available extensions/capabilities, some integer bitwidths
/// `sourceBits` might not be supported. When a memref of such an unsupported
/// type is used during conversion, loads/stores to it need to be modified to
/// use a supported, wider bitwidth `targetBits` and then extract the required
/// bits. When accessing a 1-D array (spirv.array or spirv.rtarray), the last
/// index is modified so that the needed bits are loaded. Extracting the actual
/// bits is handled separately. Note that this only works for a 1-D tensor.
///
/// `op` is the access chain to adjust. `typeConverter` and `builder` are
/// presumably used to derive types and to create the replacement ops -- TODO
/// confirm against the body.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns Value; confirm the implementation was not
// dropped.
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {}

/// Casts the given boolean value `srcBool` into an integer of type `dstType`.
///
/// `loc` and `builder` are presumably the location and builder for the ops
/// that perform the cast -- TODO confirm against the body.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns Value; confirm the implementation was not
// dropped.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {}

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked with `mask`.
///
/// NOTE(review): neither `targetBits` nor a destination type is a parameter of
/// this function; presumably both are implied by the types of `value`,
/// `offset` and `mask` -- TODO confirm against the body.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns Value; confirm the implementation was not
// dropped.
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {}

/// Returns true if an allocation of memref type `type`, produced by the
/// operation `allocOp`, can be lowered to SPIR-V.
///
/// The criteria that make an allocation "supported" are not visible here.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns bool; confirm the implementation was not
// dropped.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {}

/// Returns the scope to use for the atomic operations that emulate store
/// operations of unsupported integer bitwidths, based on the memref `type`.
/// Returns std::nullopt on failure.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns std::optional<spirv::Scope>; confirm the
// implementation was not dropped.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {}

/// Casts the given integer value `srcInt` into a boolean value.
///
/// `loc` and `builder` are presumably the location and builder for the ops
/// that perform the cast -- TODO confirm against the body.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns Value; confirm the implementation was not
// dropped.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert type along the way, which requires ConversionPattern. DRR generates
// normal RewritePattern.

// File-local conversion patterns, one per handled memref operation. Each class
// derives from OpConversionPattern specialized on the memref op it matches.
// NOTE(review): the class bodies are empty in this view, although out-of-line
// matchAndRewrite definitions for most of them follow later in the file;
// confirm the member declarations were not dropped.
namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {};

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope, since it
/// will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {};

/// Removes a deallocation if it belongs to a supported allocation. Currently
/// only removes the deallocation if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {};

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {};

/// Conversion pattern matching memref::ReinterpretCastOp. What it lowers the
/// op to is not visible here -- TODO document once confirmed.
class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {};

/// Conversion pattern matching memref::CastOp. What it lowers the op to is not
/// visible here -- TODO document once confirmed.
class CastPattern final : public OpConversionPattern<memref::CastOp> {};

} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

/// Rewrite hook of AllocaOpPattern, the pattern documented as converting
/// memref.alloca to SPIR-V Function variables. `allocaOp` is the matched op.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns LogicalResult; confirm the implementation was
// not dropped.
LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

/// Rewrite hook of AllocOpPattern, the pattern documented as lowering
/// allocations to Workgroup memory when the size is constant. `operation` is
/// the matched memref::AllocOp.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns LogicalResult; confirm the implementation was
// not dropped.
LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

/// Rewrite hook of AtomicRMWOpPattern, the pattern documented as converting
/// memref atomic read-modify-write operations to SPIR-V atomic operations.
/// `atomicOp` is the matched memref::AtomicRMWOp.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns LogicalResult; confirm the implementation was
// not dropped.
LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

/// Rewrite hook of DeallocOpPattern, the pattern documented as removing a
/// deallocation when it belongs to a supported (workgroup memory) allocation.
/// `operation` is the matched memref::DeallocOp.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns LogicalResult; confirm the implementation was
// not dropped.
LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

/// Result type of the calculateMemoryRequirements() helpers below, which
/// return it wrapped in FailureOr.
// NOTE(review): the struct has no members in this view; confirm its fields
// were not dropped.
struct MemoryRequirements {};

/// Given an accessed SPIR-V pointer `accessedPtr`, calculates its alignment
/// requirements, if any.
///
/// `isNontemporal` presumably says whether the access is a nontemporal one and
/// is reflected in the result -- TODO confirm against the body.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns FailureOr<MemoryRequirements>; confirm the
// implementation was not dropped.
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {}

/// Given an accessed SPIR-V pointer `accessedPtr` and the original memref
/// load/store op `loadOrStoreOp`, calculates the alignment requirements, if
/// any. Takes into account the alignment attributes applied to the load/store
/// op.
///
/// `LoadOrStoreOp` is the type of the memref access op; judging by its name it
/// is meant to be a memref load or store op type -- TODO confirm which
/// accessors the body requires of it.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns FailureOr<MemoryRequirements>; confirm the
// implementation was not dropped.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {}

/// Rewrite hook of IntLoadOpPattern, the pattern documented as converting
/// memref.load to spirv.Load + spirv.AccessChain on integers. `loadOp` is the
/// matched op.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns LogicalResult; confirm the implementation was
// not dropped.
LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {}

/// Rewrite hook of LoadOpPattern, the pattern documented as converting
/// memref.load to spirv.Load + spirv.AccessChain. `loadOp` is the matched op.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns LogicalResult; confirm the implementation was
// not dropped.
LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {}

/// Rewrite hook of IntStoreOpPattern, the pattern documented as converting
/// memref.store to spirv.Store on integers. `storeOp` is the matched op.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns LogicalResult; confirm the implementation was
// not dropped.
LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

/// Rewrite hook of MemorySpaceCastOpPattern, the pattern documented as
/// converting memref.memory_space_cast to the appropriate spirv cast
/// operations. `addrCastOp` is the matched op.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns LogicalResult; confirm the implementation was
// not dropped.
LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {}

/// Rewrite hook of StoreOpPattern, the pattern documented as converting
/// memref.store to spirv.Store. `storeOp` is the matched op.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns LogicalResult; confirm the implementation was
// not dropped.
LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {}

/// Rewrite hook of ReinterpretCastPattern; `op` is the matched
/// memref::ReinterpretCastOp. What the op is rewritten to is not visible here
/// -- TODO document once confirmed.
// NOTE(review): the body shown here is empty and has no return statement even
// though the function returns LogicalResult; confirm the implementation was
// not dropped.
LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
/// Externally visible entry point of this file. Judging by its name and the
/// "Pattern population" section it lives in, it is meant to add the
/// MemRef-to-SPIR-V conversion patterns to `patterns`, with `typeConverter`
/// supplying the type conversion -- TODO confirm against the body.
// NOTE(review): the body is empty in this view, so the code shown registers no
// pattern; confirm the implementation was not dropped.
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {}
} // namespace mlir