//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements lowering of ArmSME operations to LLVM intrinsics. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ScopeExit.h" namespace mlir { #define GEN_PASS_DEF_CONVERTARMSMETOLLVM #include "mlir/Conversion/Passes.h.inc" } // namespace mlir usingnamespacemlir; namespace { static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id"); /// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic. static Operation *createLoadTileSliceIntrinsic( RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type, arm_sme::TileSliceLayout layout, Value maskOp, Value ptr, IntegerAttr tileId, Value tileSliceI32) { … } /// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic. 
static Operation *createStoreTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) { … }

/// Returns the tile ID attribute of `op`, presumably diagnosing the case where
/// no tile ID has been assigned (behavior inferred from the name; body elided
/// in this view — confirm against the full source).
IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) { … }

/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
/// placed in the first block of the function.
static memref::AllocaOp
createAllocaForTile(RewriterBase &rewriter, Location loc,
                    FunctionOpInterface func,
                    arm_sme::ArmSMETileOpInterface tileOp) { … }

/// Finds or creates an alloca for a spill of a tile.
static memref::AllocaOp getOrCreateAllocaForTile(
    RewriterBase &rewriter, Location loc, FunctionOpInterface func,
    arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) { … }

/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
/// the op to tile 0, then emitting a full tile swap between ZA and memory
/// before + after the tile op.
///
/// Example:
///
///    // Note: <IN MEMORY TILE> = tile ID >= 16.
///    arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
///
/// is converted to:
///    // At function entry:
///    %spill = memref.alloca ... : memref<?x?xty>
///
///    // Around op:
///    scf.for %slice_idx {
///      %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///      "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
///      vector.store %slice_to_save, %spill[%slice_idx, %c0]
///    }
///    arm_sme.tile_op { tile_id = 0 }
///    scf.for %slice_idx {
///      %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///      "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
///      vector.store %slice_to_save, %spill[%slice_idx, %c0]
///    }
///
/// Note that these spills/fills are not inserted earlier as concept of a
/// register, and the need to swap the contents, can't really be represented
/// correctly at a high level in MLIR.
///
/// TODO: Reduce the spills/reloads to single slices where possible (and omit
/// redundant reloads). This could be done via a method on the
/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
///
/// `tileOp.getZaUsage()` could return:
///
/// struct ArmSMEOpZAUsage {
///   enum class Kind {
///     TileRead,        // Omit store after tile operation.
///     TileWrite,       // Omit load before tile operation.
///     TileReadWrite,   // Needs both tile load and store.
///     SliceRead,       // Spill single slice and omit store after operation.
///     SliceWrite,      // Spill single slice and omit load before operation.
///     SliceReadWrite   // Spill single slice.
///   };
///   Value sliceIndex {};
///   TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
/// };
///
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { … };

// Selects whether a conversion pattern needs the spill/fill machinery above
// wrapped around the ops it converts (members elided in this view).
enum class RequiresSpillsAndFills { … };

/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
/// disabled by setting the `requiresSpillsAndFills` template parameter to
/// `RequiresSpillsAndFills::No`.
template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
                                 RequiresSpillsAndFills::Yes>
struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> { … };

// Registers a single conversion pattern with `patterns` (body elided).
template <typename Pattern>
static void addArmSMEConversionPattern(RewritePatternSet &patterns,
                                       LLVMTypeConverter const &typeConverter) { … }

/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) { … }

/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
/// ```mlir
///    %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
///    "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
///    %v = arm_sme.get_tile : vector<[4]x[4]xi32>
/// ```
///
/// The 'arm_sme.get_tile' (which models the return) will fold away once all
/// ArmSME ops have been converted to LLVM intrinsics.
struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> { … };

/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
struct LoadTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> { … };

/// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
struct StoreTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> { … };

/// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
struct InsertTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> { … };

/// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
struct ExtractTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> { … };

/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
///   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
///     : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
///     : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
///        vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
struct OuterProductOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> { … };

/// Lower 2-way and 4-way widening outer products to intrinsics.
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
struct OuterProductWideningOpConversion
    : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> { … };

/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
///
/// Example:
///
///   %0 = arm_sme.streaming_vl <half>
///
/// is converted to:
///
///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
///   %0 = arith.index_cast %cnt : i64 to index
///
struct StreamingVLOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
                                          RequiresSpillsAndFills::No> { … };

/// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
/// or-ing the zero masks. Note: In future the backend _should_ handle this.
static void mergeConsecutiveTileZerosInBlock(Block *block) { … }

} // namespace

namespace {

// Pass driver generated via GEN_PASS_DEF_CONVERTARMSMETOLLVM above; runs the
// ArmSME -> LLVM conversion on a function/module (body elided in this view).
struct ConvertArmSMEToLLVMPass
    : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> { … };

} // namespace

// Marks ArmSME ops illegal / LLVM intrinsics legal on `target` so the dialect
// conversion framework drives the rewrite (body elided).
void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) { … }

// Populates `patterns` with the conversion patterns defined in this file
// (body elided).
void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) { … }

// Public factory for the pass; `dumpTileLiveRanges` is forwarded as a pass
// option (exact use elided in this view — confirm against the full source).
std::unique_ptr<Pass>
mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) { … }