llvm/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of ArmSME operations to LLVM intrinsics.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ScopeExit.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

usingnamespacemlir;

namespace {

/// Attribute name "arm_sme.in_memory_tile_id".
/// NOTE(review): presumably attached to ops whose tile lives in memory rather
/// than in a hardware tile -- confirm against the users of this constant.
static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");

/// Helper to create an 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
///
/// Judging by the parameter names, `type` and `layout` select the intrinsic
/// variant, while `maskOp`, `ptr`, `tileId` and `tileSliceI32` become its
/// operands/attributes (assumption from the signature -- TODO confirm).
///
/// NOTE(review): the body is empty in this view even though the function
/// returns `Operation *`; confirm the implementation is present.
static Operation *createLoadTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {}

/// Helper to create an 'arm_sme.intr.st1*.(horiz|vert)' intrinsic.
///
/// Judging by the parameter names, `type` and `layout` select the intrinsic
/// variant, while `maskOp`, `ptr`, `tileId` and `tileSliceI32` become its
/// operands/attributes (assumption from the signature -- TODO confirm).
///
/// NOTE(review): the body is empty in this view even though the function
/// returns `Operation *`; confirm the implementation is present.
static Operation *createStoreTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {}

/// Presumably returns the tile ID attribute of `op`, reporting an error when
/// the op has none (inferred from the name only -- TODO confirm).
///
/// NOTE(review): the body is empty in this view even though the function
/// returns `IntegerAttr`; confirm the implementation is present.
IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {}

/// Creates an alloca matching the size of the tile used by `tileOp`. The
/// alloca is placed in the first block of the function `func`.
///
/// NOTE(review): the body is empty in this view even though the function
/// returns `memref::AllocaOp`; confirm the implementation is present.
static memref::AllocaOp
createAllocaForTile(RewriterBase &rewriter, Location loc,
                    FunctionOpInterface func,
                    arm_sme::ArmSMETileOpInterface tileOp) {}

/// Finds or creates an alloca for a spill of a tile.
///
/// `tileId` presumably identifies which spill slot to look up within `func`
/// (assumption from the signature -- TODO confirm how existing allocas are
/// matched).
///
/// NOTE(review): the body is empty in this view even though the function
/// returns `memref::AllocaOp`; confirm the implementation is present.
static memref::AllocaOp getOrCreateAllocaForTile(
    RewriterBase &rewriter, Location loc, FunctionOpInterface func,
    arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {}

/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
/// the op to tile 0, then emitting a full tile swap between ZA and memory
/// before + after the tile op.
///
/// Example:
///
///    // Note: <IN MEMORY TILE> = tile ID >= 16.
///    arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
///
/// is converted to:
///     // At function entry:
///     %spill = memref.alloca ... : memref<?x?xty>
///
///     // Around op:
///     scf.for %slice_idx {
///       %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///       "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx)  <{tile_id = 0 : i32}>
///       vector.store %slice_to_save, %spill[%slice_idx, %c0]
///     }
///     arm_sme.tile_op { tile_id = 0 }
///     scf.for %slice_idx {
///       %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///       "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx)  <{tile_id = 0 : i32}>
///       vector.store %slice_to_save, %spill[%slice_idx, %c0]
///     }
///
/// Note that these spills/fills are not inserted earlier because the concept
/// of a register, and the need to swap its contents, can't really be
/// represented correctly at a high level in MLIR.
///
/// TODO: Reduce the spills/reloads to single slices where possible (and omit
/// redundant reloads). This could be done via a method on the
/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
///
/// `tileOp.getZaUsage()` could return:
///
/// struct ArmSMEOpZAUsage {
///   enum class Kind {
///     TileRead,        // Omit store after tile operation.
///     TileWrite,       // Omit load before tile operation.
///     TileReadWrite,   // Needs both tile load and store.
///     SliceRead,       // Spill single slice and omit store after operation.
///     SliceWrite,      // Spill single slice and omit load before operation.
///     SliceReadWrite   // Spill single slice.
///   };
///   Value sliceIndex {};
///   TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
/// };
///
/// NOTE(review): the struct body is empty in this view; the swap logic
/// described above is not visible here.
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {};

/// Selects whether a `ConvertArmSMEOpToLLVMPattern` adds spills and fills
/// around ArmSME ops that use in-memory tile IDs.
///   - `Yes`: add the spills and fills (the pattern's default).
///   - `No`:  do not add them.
enum class RequiresSpillsAndFills { Yes, No };

/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
/// disabled by setting the `requiresSpillsAndFills` template parameter to
/// `RequiresSpillsAndFills::No`.
///
/// Template parameters:
///   - `SourceOp`: the op type handled by the pattern; forwarded to the
///     `ConvertOpToLLVMPattern<SourceOp>` base.
///   - `requiresSpillsAndFills`: defaults to `RequiresSpillsAndFills::Yes`.
template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
                                 RequiresSpillsAndFills::Yes>
struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {};

/// Helper to register a single conversion pattern of type `Pattern`.
/// NOTE(review): presumably constructs `Pattern` with `typeConverter` and adds
/// it to `patterns` -- the body is empty in this view, so confirm.
template <typename Pattern>
static void addArmSMEConversionPattern(RewritePatternSet &patterns,
                                       LLVMTypeConverter const &typeConverter) {}

/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
///
/// Variadic counterpart of `addArmSMEConversionPattern`: takes any number of
/// pattern types in `Patterns`.
/// NOTE(review): presumably registers each of them in turn -- the body is
/// empty in this view, so confirm.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) {}

/// Lower 'arm_sme.zero' to SME intrinsics.
///
///  BEFORE:
///  ```mlir
///     %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
///  ```
///
///  AFTER:
///  ```mlir
///     "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
///     %v = arm_sme.get_tile : vector<[4]x[4]xi32>
///  ```
///
///  The 'arm_sme.get_tile' (which models the return) will fold away once all
///  ArmSME ops have been converted to LLVM intrinsics.
///
///  Uses the base pattern's default `RequiresSpillsAndFills::Yes`.
struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {};

/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
///
/// Uses the base pattern's default `RequiresSpillsAndFills::Yes`.
/// NOTE(review): presumably emits the intrinsic via
/// `createLoadTileSliceIntrinsic` -- the body is empty in this view, so
/// confirm.
struct LoadTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {};

/// Lower `arm_sme.store_tile_slice` to SME intrinsics.
///
/// Uses the base pattern's default `RequiresSpillsAndFills::Yes`.
/// NOTE(review): presumably emits the intrinsic via
/// `createStoreTileSliceIntrinsic` -- the body is empty in this view, so
/// confirm.
struct StoreTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {};

/// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
///
/// Uses the base pattern's default `RequiresSpillsAndFills::Yes`.
struct InsertTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {};

/// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
///
/// Uses the base pattern's default `RequiresSpillsAndFills::Yes`.
struct ExtractTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {};

/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
///   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
///     : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
///     : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
///        vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
///
/// Uses the base pattern's default `RequiresSpillsAndFills::Yes`.
struct OuterProductOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {};

/// Lower 2-way and 4-way widening outer products to intrinsics.
///
/// Template parameters:
///   - `OuterProductWideningOp`: the ArmSME widening outer product op to match
///     (forwarded to the base pattern).
///   - `OuterProductWideningIntrOp`: presumably the intrinsic op that replaces
///     it (inferred from the name -- TODO confirm).
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
struct OuterProductWideningOpConversion
    : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {};

/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
///
/// Example:
///
///   %0 = arm_sme.streaming_vl <half>
///
/// is converted to:
///
///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
///   %0 = arith.index_cast %cnt : i64 to index
///
/// Unlike the other patterns in this file, this one opts out of spills and
/// fills by passing `RequiresSpillsAndFills::No` to the base pattern.
struct StreamingVLOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
                                          RequiresSpillsAndFills::No> {};

/// Merges consecutive `arm_sme.intr.zero` operations in `block` by bitwise
/// or-ing the zero masks. Note: In the future the backend _should_ handle
/// this.
static void mergeConsecutiveTileZerosInBlock(Block *block) {}

} // namespace

namespace {

/// Pass definition built on the generated `impl::ConvertArmSMEToLLVMBase`
/// base class (enabled by `GEN_PASS_DEF_CONVERTARMSMETOLLVM` above).
/// NOTE(review): the struct body is empty in this view; the pass logic is not
/// visible here.
struct ConvertArmSMEToLLVMPass
    : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {};

} // namespace

/// Configures `target` for the ArmSME to LLVM conversion.
/// NOTE(review): presumably marks which ops/dialects are legal or illegal --
/// the body is empty in this view, so confirm.
void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {}

/// Populates `patterns` with the ArmSME to LLVM conversion patterns, using
/// `converter` as the type converter.
/// NOTE(review): presumably registers the `*Conversion` patterns defined in
/// this file -- the body is empty in this view, so confirm.
void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {}

/// Factory for the ArmSME to LLVM conversion pass.
///
/// `dumpTileLiveRanges` is presumably forwarded to the pass as an option
/// (inferred from the name -- TODO confirm).
///
/// NOTE(review): the body is empty in this view even though the function
/// returns `std::unique_ptr<Pass>`; confirm the implementation is present.
std::unique_ptr<Pass>
mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {}