//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This transform allocates SME tiles at the 'func.func' op level for ArmSME // operations. It roughly implements a linear scan register allocator, similar // to the one outlined in [1], but with simplifications and assumptions made for // our use case. Note that this is a greedy allocator (so it may not always find // the most optimal allocation of tiles). // // The allocator operates at the CF dialect level. It is the responsibility of // users to ensure the IR has been lowered to CF before invoking the tile // allocator. // // The 128-bit tiles overlap with other element tiles as follows (see section // B2.3.2 of SME spec [2]): // // Tile Overlaps // --------------------------------------------------------------------------- // ZA0.B ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q, // ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q // ZA0.H ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q // ZA1.H ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q // ZA0.S ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q // ZA1.S ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q // ZA2.S ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q // ZA3.S ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q // ZA0.D ZA0.Q, ZA8.Q // ZA1.D ZA1.Q, ZA9.Q // ZA2.D ZA2.Q, ZA10.Q // ZA3.D ZA3.Q, ZA11.Q // ZA4.D ZA4.Q, ZA12.Q // ZA5.D ZA5.Q, ZA13.Q // ZA6.D ZA6.Q, ZA14.Q // ZA7.D ZA7.Q, ZA15.Q // // [1] "Linear Scan Register Allocation in the Context of SSA Form and Register // Constraints" (Hanspeter Mössenböck and Michael Pfeiffer) // https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf // [2] https://developer.arm.com/documentation/ddi0616/aa // 
//===----------------------------------------------------------------------===// #include "mlir/Analysis/Liveness.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Transforms/Passes.h" #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/TypeSwitch.h" #include <algorithm> namespace mlir::arm_sme { #define GEN_PASS_DEF_TESTTILEALLOCATION #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" } // namespace mlir::arm_sme using namespace mlir; using namespace mlir::arm_sme; namespace { enum class TileMask : unsigned { … }; /// Returns the set of masks relevant for the given type. static ArrayRef<TileMask> getMasks(ArmSMETileType type) { … } class TileAllocator { … }; /// Add new intermediate blocks for the true and false destinations of /// `cf.cond_br`s that contain tile operands. This prevents spurious liveness /// overlaps due to copies at branches. /// /// BEFORE: /// ```mlir /// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2 /// ``` /// /// AFTER: /// ```mlir /// cf.cond_br %cond, ^bb1_copy, ^bb2_copy /// ^bb1_copy: /// cf.br ^bb1(%tile: vector<[4]x[4]xf32>) /// ^bb2_copy: /// cf.br ^bb2 /// ``` void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) { … } /// Inserts tile copies at `cf.br` operations. /// /// BEFORE: /// ```mlir /// cf.br ^bb1(%tile: vector<[4]x[4]xf32>) /// ``` /// /// AFTER: /// ```mlir /// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32> /// cf.br ^bb1(%copy: vector<[4]x[4]xf32>) /// ``` void insertCopiesAtBranches(IRRewriter &rewriter, FunctionOpInterface function) { … } /// Prepares the IR for tile allocation. It does this by first 'splitting' /// conditional branches (see `splitCondBranches`), then inserting tile copies /// at branch operations. 
The conditional branches are split to prevent the /// copies needed for them overlapping between the true and false paths of the /// branch (see `tile-allocation-copies.mlir` and /// `tile-allocation-liveness.mlir` for examples). The copies break up live /// ranges and ensure when moving out of SSA the semantics of the program are /// preserved. void preprocessForTileAllocation(IRRewriter &rewriter, FunctionOpInterface function) { … } /// A live range for a (collection of) tile values. A live range is built up of /// non-overlapping intervals [start, end) which represent parts of the program /// where a value in the range needs to be live (i.e. in an SME virtual tile). /// Note that as the intervals are non-overlapping all values within a live /// range can be allocated to the same SME virtual tile. struct LiveRange { … }; /// Number operations within a function to allow computing live ranges. /// Operations are numbered consecutively within blocks, and the blocks are /// topologically sorted (using forward edges). This function is only correct if /// all ArmSME ops have been converted to CF (which is asserted). DenseMap<Operation *, unsigned> generateOperationNumbering(FunctionOpInterface function) { … } /// Gather live ranges for SME tiles from the MLIR liveness analysis. DenseMap<Value, LiveRange> gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap, LiveRange::Allocator &liveRangeAllocator, Liveness &liveness, FunctionOpInterface function) { … } /// Iterate over all predecessor tile values to a (tile) block argument. static void forEachPredecessorTileValue(BlockArgument blockArg, function_ref<void(Value)> callback) { … } /// Coalesce live ranges where it would prevent unnecessary tile moves. SmallVector<LiveRange *> coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) { … } /// Choose a live range to spill (via some heuristics). This picks either a live /// range from `overlappingRanges`, or the new live range `newRange`. 
template <typename OverlappingRangesIterator> LiveRange * chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges, LiveRange *newRange) { … } /// Greedily allocate tile IDs to live ranges. Spill using simple heuristics. void allocateTilesToLiveRanges( ArrayRef<LiveRange *> liveRangesSortedByStartPoint) { … } /// Assigns a tile ID to an MLIR value. void assignTileIdToValue(IRRewriter &rewriter, Value value, IntegerAttr tileIdAttr) { … } /// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts. LogicalResult assignTileIdsAndResolveTrivialConflicts( IRRewriter &rewriter, FunctionOpInterface function, ArrayRef<LiveRange *> allocatedLiveRanges) { … } /// Prints live ranges alongside operation names for debugging. void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap, ArrayRef<LiveRange const *> liveRanges, FunctionOpInterface function) { … } struct TestTileAllocationPass : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> { … }; } // namespace LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function, bool dumpRanges) { … }