llvm/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp

//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
// operations. It roughly implements a linear scan register allocator, similar
// to the one outlined in [1], but with simplifications and assumptions made for
// our use case. Note that this is a greedy allocator (so it may not always find
// an optimal allocation of tiles).
//
// The allocator operates at the CF dialect level. It is the responsibility of
// users to ensure the IR has been lowered to CF before invoking the tile
// allocator.
//
// The 128-bit tiles overlap with other element tiles as follows (see section
// B2.3.2 of SME spec [2]):
//
//   Tile    Overlaps
//   ---------------------------------------------------------------------------
//   ZA0.B   ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
//           ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
//   ZA0.H   ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
//   ZA1.H   ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
//   ZA0.S   ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
//   ZA1.S   ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
//   ZA2.S   ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
//   ZA3.S   ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
//   ZA0.D   ZA0.Q, ZA8.Q
//   ZA1.D   ZA1.Q, ZA9.Q
//   ZA2.D   ZA2.Q, ZA10.Q
//   ZA3.D   ZA3.Q, ZA11.Q
//   ZA4.D   ZA4.Q, ZA12.Q
//   ZA5.D   ZA5.Q, ZA13.Q
//   ZA6.D   ZA6.Q, ZA14.Q
//   ZA7.D   ZA7.Q, ZA15.Q
//
// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
//      Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
//     https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
// [2] https://developer.arm.com/documentation/ddi0616/aa
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>

namespace mlir::arm_sme {
#define GEN_PASS_DEF_TESTTILEALLOCATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

usingnamespacemlir;
usingnamespacemlir::arm_sme;

namespace {

/// Scoped enumeration with `unsigned` as its underlying type.
///
/// NOTE(review): no enumerators are declared here. Presumably each value is a
/// bitmask describing a tile from the overlap table in the file header --
/// confirm against the full definition.
enum class TileMask : unsigned {};

/// Returns the set of masks relevant for the given type.
///
/// Takes the ArmSME tile type to look up and returns an `ArrayRef` of
/// `TileMask` values.
///
/// NOTE(review): the body is empty, so control flows off the end of a function
/// with a non-void return type -- confirm the implementation was not dropped.
static ArrayRef<TileMask> getMasks(ArmSMETileType type) {}

/// NOTE(review): this class declares no members in this view. Presumably it
/// tracks which SME tiles are in use (likely via `TileMask`) -- confirm
/// against the full definition.
class TileAllocator {};

/// Add new intermediate blocks for the true and false destinations of
/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
/// overlaps due to copies at branches.
///
///  BEFORE:
///  ```mlir
///  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
///  ```
///
///  AFTER:
///  ```mlir
///    cf.cond_br %cond, ^bb1_copy, ^bb2_copy
///  ^bb1_copy:
///    cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
///  ^bb2_copy:
///    cf.br ^bb2
///  ```
///
/// Takes the rewriter by reference and the function to process by value;
/// returns nothing.
///
/// NOTE(review): the body is empty here, so this is currently a no-op --
/// confirm the implementation was not dropped.
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {}

/// Inserts tile copies at `cf.br` operations.
///
///  BEFORE:
///  ```mlir
///  cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
///  ```
///
///  AFTER:
///  ```mlir
///  %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
///  cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
///  ```
///
/// Takes the rewriter by reference and the function to process by value;
/// returns nothing.
///
/// NOTE(review): the body is empty here, so this is currently a no-op --
/// confirm the implementation was not dropped.
void insertCopiesAtBranches(IRRewriter &rewriter,
                            FunctionOpInterface function) {}

/// Prepares the IR for tile allocation. It does this by first 'splitting'
/// conditional branches (see `splitCondBranches`), then inserting tile copies
/// at branch operations. The conditional branches are split to prevent the
/// copies needed for them overlapping between the true and false paths of the
/// branch (see `tile-allocation-copies.mlir` and
/// `tile-allocation-liveness.mlir` for examples). The copies break up live
/// ranges and ensure when moving out of SSA the semantics of the program are
/// preserved.
///
/// NOTE(review): the body is empty here, so neither of the steps described
/// above is performed -- confirm the implementation was not dropped.
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {}

/// A live range for a (collection of) tile values. A live range is built up of
/// non-overlapping intervals [start, end) which represent parts of the program
/// where a value in the range needs to be live (i.e. in an SME virtual tile).
/// Note that as the intervals are non-overlapping all values within a live
/// range can be allocated to the same SME virtual tile.
///
/// NOTE(review): the struct declares no members in this view, yet
/// `gatherTileLiveRanges` names `LiveRange::Allocator` in its signature --
/// confirm the member declarations were not dropped.
struct LiveRange {};

/// Number operations within a function to allow computing live ranges.
/// Operations are numbered consecutively within blocks, and the blocks are
/// topologically sorted (using forward edges). This function is only correct if
/// all ArmSME have been converted to CF (which is asserted).
///
/// Returns a map from each operation to its assigned index.
///
/// NOTE(review): the body is empty, so control flows off the end of a function
/// with a non-void return type (and nothing is asserted) -- confirm the
/// implementation was not dropped.
DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {}

/// Gather live ranges for SME tiles from the MLIR liveness analysis.
///
/// Takes the operation numbering (read-only), the allocator for live ranges
/// and the liveness analysis (both by mutable reference), and the function to
/// inspect. Returns a map from each value to its live range.
///
/// NOTE(review): the body is empty, so control flows off the end of a function
/// with a non-void return type -- confirm the implementation was not dropped.
DenseMap<Value, LiveRange>
gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                     LiveRange::Allocator &liveRangeAllocator,
                     Liveness &liveness, FunctionOpInterface function) {}

/// Iterate over all predecessor tile values to a (tile) block argument.
///
/// `callback` is a non-owning callable reference taking a single `Value`.
///
/// NOTE(review): the body is empty here, so `callback` is never invoked --
/// confirm the implementation was not dropped.
static void forEachPredecessorTileValue(BlockArgument blockArg,
                                        function_ref<void(Value)> callback) {}

/// Coalesce live ranges where it would prevent unnecessary tile moves.
///
/// Takes the initial per-value live ranges by mutable reference and returns a
/// vector of pointers to live ranges.
///
/// NOTE(review): the body is empty, so control flows off the end of a function
/// with a non-void return type -- confirm the implementation was not dropped.
/// Presumably the returned pointers refer into `initialLiveRanges`; verify
/// before relying on their lifetime.
SmallVector<LiveRange *>
coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {}

/// Choose a live range to spill (via some heuristics). This picks either a live
/// range from `overlappingRanges`, or the new live range `newRange`.
///
/// `OverlappingRangesIterator` is an unconstrained template parameter; the
/// requirements on it are not visible from this declaration.
///
/// NOTE(review): the body is empty, so control flows off the end of a function
/// with a non-void return type -- confirm the implementation was not dropped.
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {}

/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
///
/// The parameter name indicates that callers are expected to pass the live
/// ranges already sorted by start point.
///
/// NOTE(review): the body is empty here, so this is currently a no-op --
/// confirm the implementation was not dropped.
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {}

/// Assigns a tile ID to an MLIR value.
///
/// Takes the rewriter by reference, the value to annotate, and the tile ID as
/// an `IntegerAttr`; returns nothing.
///
/// NOTE(review): the body is empty here, so this is currently a no-op --
/// confirm the implementation was not dropped.
void assignTileIdToValue(IRRewriter &rewriter, Value value,
                         IntegerAttr tileIdAttr) {}

/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
///
/// Returns a `LogicalResult`; the conditions under which it reports failure
/// are not visible from this declaration.
///
/// NOTE(review): the body is empty, so control flows off the end of a function
/// with a non-void return type -- confirm the implementation was not dropped.
LogicalResult assignTileIdsAndResolveTrivialConflicts(
    IRRewriter &rewriter, FunctionOpInterface function,
    ArrayRef<LiveRange *> allocatedLiveRanges) {}

/// Prints live ranges alongside operation names for debugging.
///
/// All inputs are read-only: the operation numbering is taken by const
/// reference and the live ranges as pointers to const.
///
/// NOTE(review): the body is empty here, so nothing is printed -- confirm the
/// implementation was not dropped.
void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                    ArrayRef<LiveRange const *> liveRanges,
                    FunctionOpInterface function) {}

/// Pass type deriving from `arm_sme::impl::TestTileAllocationBase`, which is
/// parameterized on this struct itself. The base is presumably the definition
/// pulled in from `Passes.h.inc` under `GEN_PASS_DEF_TESTTILEALLOCATION` near
/// the top of this file -- confirm.
///
/// NOTE(review): no members or overrides are declared in this view;
/// presumably the pass should invoke `allocateSMETiles` -- confirm the
/// implementation was not dropped.
struct TestTileAllocationPass
    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {};
} // namespace

/// Out-of-line definition of `mlir::arm_sme::allocateSMETiles`, the only
/// definition in this file placed outside the anonymous namespace.
///
/// Takes the function to allocate SME tiles for and a `dumpRanges` flag, and
/// returns a `LogicalResult`.
///
/// NOTE(review): `dumpRanges` presumably enables the debug output produced by
/// `dumpLiveRanges` -- confirm. The body is empty, so control flows off the
/// end of a function with a non-void return type; confirm the implementation
/// was not dropped.
LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {}