//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE …

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array treating as having elements of
/// `targetBits`, multiple values are loaded at the same time. The method
/// returns the offset where the `srcIdx` locates in the value. For example, if
/// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
/// located at (x % 4) * 8. Because there are four elements in one i32, and one
/// element has 8 bits.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) { … }

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion if a memref of an unsupported type is used,
/// load/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and extracting the required bits. For accessing a
/// 1-D array (spirv.array or spirv.rtarray), the last index is modified to
/// load the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) { … }

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) { … }

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked.
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) { … }

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) { … }

/// Returns the scope to use for atomic operations used for emulating store
/// operations of unsupported integer bitwidths, based on the memref type.
/// Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) { … }

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt,
                            OpBuilder &builder) { … }

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert type along the way, which requires ConversionPattern. DRR generates
// normal RewritePattern.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final
    : public OpConversionPattern<memref::AllocaOp> { … };

/// Converts an allocation operation to SPIR-V. Currently only supports
/// lowering to Workgroup memory when the size is constant. Note that this
/// pattern needs to be applied in a pass that runs at least at spirv.module
/// scope since it will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> { … };

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> { … };

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocation if the memory space is workgroup memory.
class DeallocOpPattern final
    : public OpConversionPattern<memref::DeallocOp> { … };

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> { … };

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> { … };

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final
    : public OpConversionPattern<memref::StoreOp> { … };

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> { … };

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> { … };

/// Conversion pattern for memref.reinterpret_cast.
class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> { … };

/// Conversion pattern for memref.cast.
class CastPattern final : public OpConversionPattern<memref::CastOp> { … };

} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const { … }

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const { … }

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const { … }

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const { … }

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

/// Bundle of memory-access requirements computed for a load/store access.
// NOTE(review): members elided from this view — presumably carries alignment
// and memory-operand attributes; confirm against the full source.
struct MemoryRequirements { … };

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { … }

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) { … }

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const { … }

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const { … }

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const { … }

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const { … }

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const { … }

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const { … }

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
/// Appends to `patterns` the MemRef-to-SPIR-V conversion patterns defined in
/// this file, using `typeConverter` to convert memref types.
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) { … }
} // namespace mlir