//===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"

using namespace mlir;

namespace {

/// Conversion pattern for vector.transfer_read.
///
/// ---
///
/// Example 1: op with identity permutation map to horizontal
///            arm_sme.tile_load:
///
///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d0, d1)
///
/// is converted to:
///
///   arm_sme.tile_load ...
///
/// ---
///
/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
///            (in-flight transpose):
///
///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)
///
/// is converted to:
///
///   arm_sme.tile_load ... layout<vertical>
struct TransferReadToArmSMELowering
    : public OpRewritePattern<vector::TransferReadOp> {};
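
// A minimal sketch of the layout selection both transfer lowerings rely on
// (illustrative only; the pattern bodies are elided in this listing, and
// `getTileSliceLayoutSketch` is a hypothetical helper, not upstream API):
static FailureOr<arm_sme::TileSliceLayout>
getTileSliceLayoutSketch(AffineMap permutationMap) {
  // (d0, d1) -> (d0, d1): rows map to tile slices; use a horizontal load.
  if (permutationMap.isIdentity())
    return arm_sme::TileSliceLayout::Horizontal;
  // (d0, d1) -> (d1, d0): fold the transpose into a vertical load.
  if (permutationMap == AffineMap::getPermutationMap(
                            ArrayRef<unsigned>{1, 0},
                            permutationMap.getContext()))
    return arm_sme::TileSliceLayout::Vertical;
  // Any other permutation can't be expressed by a single arm_sme.tile_load.
  return failure();
}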

/// Conversion pattern for vector.transfer_write.
///
/// ---
///
/// Example 1: op with identity permutation map to horizontal
///            arm_sme.tile_store:
///
///   vector.transfer_write %vector, %source[%c0, %c0]
///     {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
///                                                   vector<[16]x[16]xi8>
/// ---
///
/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
///            (in-flight transpose):
///
///   vector.transfer_write %vector, %source[%c0, %c0]
///     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
///      in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
///   arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
///     : memref<?x?xi8>, vector<[16]x[16]xi8>
struct TransferWriteToArmSMELowering
    : public OpRewritePattern<vector::TransferWriteOp> {};
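
// A hedged sketch of the kind of legality guard such a store lowering needs
// (illustrative; `isTileStoreCandidateSketch` is hypothetical and the elided
// body performs more checks, e.g. on the permutation map as above):
static bool isTileStoreCandidateSketch(vector::TransferWriteOp writeOp) {
  // Only 2-D scalable vectors that exactly match an SME tile qualify.
  if (!arm_sme::isValidSMETileVectorType(writeOp.getVectorType()))
    return false;
  // arm_sme.tile_store writes to memory, so tensor destinations are out.
  return isa<MemRefType>(writeOp.getShapedType());
}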

/// Conversion pattern for vector.load.
struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {};

/// Conversion pattern for vector.store.
struct VectorStoreToArmSMELowering
    : public OpRewritePattern<vector::StoreOp> {};
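
// For whole-tile loads the mapping is 1:1. A hedged sketch of the rewrite
// (illustrative; `lowerVectorLoadSketch` is hypothetical and the elided
// pattern body may differ):
static LogicalResult lowerVectorLoadSketch(vector::LoadOp load,
                                           PatternRewriter &rewriter) {
  if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
    return failure();
  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
      load, load.getVectorType(), load.getBase(), load.getIndices());
  return success();
}
// vector.store is symmetric: it becomes an arm_sme.tile_store of the value
// to store, at the same base and indices.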

/// Conversion pattern for vector.broadcast.
///
/// Example:
///
///   %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
///
/// is converted to:
///
///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
///   {
///     %tile_update = arm_sme.move_vector_to_tile_slice
///        %broadcast_to_1d, %iter_tile, %tile_slice_index :
///        vector<[4]xi32> into vector<[4]x[4]xi32>
///     scf.yield %tile_update : vector<[4]x[4]xi32>
///   }
///
/// Supports scalar, 0-d vector, and 1-d vector broadcasts.
struct BroadcastOpToArmSMELowering
    : public OpRewritePattern<vector::BroadcastOp> {};
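
// A simplified sketch of the loop construction described above, for the
// scalar case only (illustrative; `broadcastScalarToTileSketch` is
// hypothetical and the ArmSME op builder shapes are assumptions, so consult
// the generated op classes for the exact signatures):
static Value broadcastScalarToTileSketch(Location loc, VectorType tileType,
                                         Value scalar,
                                         PatternRewriter &rewriter) {
  // %broadcast_to_1d: broadcast the scalar to one tile slice, vector<[N]xT>.
  auto sliceType = VectorType::get({tileType.getDimSize(1)},
                                   tileType.getElementType(), {true});
  Value slice = rewriter.create<vector::BroadcastOp>(loc, sliceType, scalar);
  // %init_tile in the example above.
  Value initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
  // %num_tile_slices = <min number of rows> * vscale.
  Value minRows =
      rewriter.create<arith::ConstantIndexOp>(loc, tileType.getDimSize(0));
  Value vscale =
      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
  Value numSlices = rewriter.create<arith::MulIOp>(loc, minRows, vscale);
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  // Loop over the tile slices, moving the broadcast slice into each one.
  auto forOp = rewriter.create<scf::ForOp>(loc, c0, numSlices, c1,
                                           ValueRange{initTile});
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(forOp.getBody());
  auto update = rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
      loc, tileType, slice, forOp.getRegionIterArg(0),
      forOp.getInductionVar());
  rewriter.create<scf::YieldOp>(loc, update.getResult());
  return forOp.getResult(0);
}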

/// Conversion pattern for vector.splat.
///
/// Example:
///
///   %splat_to_tile = vector.splat %src : vector<[4]x[4]xi32>
///
/// is converted to:
///
///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
///   {
///     %tile_update = arm_sme.move_vector_to_tile_slice
///        %broadcast_to_1d, %iter_tile, %tile_slice_index :
///        vector<[4]xi32> into vector<[4]x[4]xi32>
///     scf.yield %tile_update : vector<[4]x[4]xi32>
///   }
///
/// This is identical to vector.broadcast of a scalar.
struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {};
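
// Since the lowering is identical, a hedged alternative sketch (not
// necessarily the elided body) is to rewrite the splat into the equivalent
// scalar broadcast and let BroadcastOpToArmSMELowering do the rest:
static void lowerSplatViaBroadcastSketch(vector::SplatOp splatOp,
                                         PatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
                                                   splatOp.getInput());
}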

/// Conversion pattern for vector.transpose.
///
/// Stores the input tile to memory and reloads vertically.
///
/// Example:
///
///   %transposed_src = vector.transpose %src, [1, 0]
///     : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
///
/// is converted to:
///
///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
///   arm_sme.tile_store %src, %alloca[%c0, %c0]
///     : memref<?x?xi32>, vector<[4]x[4]xi32>
///   %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
///     layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
///
/// NOTE: Transposing via memory is obviously expensive; the current intention
/// is to avoid the transpose if possible. This pattern is therefore intended
/// as a fallback that provides base support for the Vector ops. If it turns
/// out transposes can't be avoided, it should be replaced with a more optimal
/// implementation, perhaps with tile <-> vector (MOVA) ops.
struct TransposeOpToArmSMELowering
    : public OpRewritePattern<vector::TransposeOp> {};
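
// A hedged sketch of the scratch-buffer setup described above (illustrative;
// `createTransposeBufferSketch` is hypothetical, and the tile_store/tile_load
// creation is left in comments since the exact ArmSME builder signatures are
// not shown in this listing):
static Value createTransposeBufferSketch(Location loc, VectorType tileType,
                                         PatternRewriter &rewriter) {
  // %svl_s = <min rows> * vscale: the runtime extent of one tile dimension.
  Value minDim =
      rewriter.create<arith::ConstantIndexOp>(loc, tileType.getDimSize(0));
  Value vscale =
      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
  Value svl = rewriter.create<arith::MulIOp>(loc, minDim, vscale);
  // %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xT>
  auto bufType = MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
                                 tileType.getElementType());
  Value buffer =
      rewriter.create<memref::AllocaOp>(loc, bufType, ValueRange{svl, svl});
  // The pattern then creates an arm_sme.tile_store of the source into this
  // buffer (default horizontal layout) and replaces the vector.transpose
  // with an arm_sme.tile_load ... layout<vertical> of the same buffer.
  return buffer;
}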

/// Conversion pattern for vector.outerproduct.
///
/// If the vector.outerproduct is masked (and the mask is from a
/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
/// operands.
///
/// Example:
///
///   %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
///   %result = vector.mask %mask {
///                vector.outerproduct %vecA, %vecB
///                 : vector<[4]xf32>, vector<[4]xf32>
///             } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
///
/// is converted to:
///
///    %maskA = vector.create_mask %dimA : vector<[4]xi1>
///    %maskB = vector.create_mask %dimB : vector<[4]xi1>
///    %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
///                : vector<[4]xf32>, vector<[4]xf32>
///
/// Unmasked outerproducts can be directly replaced with the arm_sme op.
///
/// Example:
///
///   %result = vector.outerproduct %vecA, %vecB
///              : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   %result = arm_sme.outerproduct %vecA, %vecB
///              : vector<[4]xf32>, vector<[4]xf32>
///
struct VectorOuterProductToArmSMELowering
    : public OpRewritePattern<vector::OuterProductOp> {};
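
// A hedged sketch of the mask decomposition described above (illustrative;
// `decomposeTileMaskSketch` is a hypothetical helper):
static FailureOr<std::pair<Value, Value>>
decomposeTileMaskSketch(Value mask, Location loc, PatternRewriter &rewriter) {
  // Only masks produced directly by a 2-D vector.create_mask are handled.
  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
  if (!createMask)
    return failure();
  auto maskType = cast<VectorType>(mask.getType());
  if (maskType.getRank() != 2)
    return failure();
  auto i1 = rewriter.getI1Type();
  auto rowType = VectorType::get({maskType.getDimSize(0)}, i1,
                                 {maskType.getScalableDims()[0]});
  auto colType = VectorType::get({maskType.getDimSize(1)}, i1,
                                 {maskType.getScalableDims()[1]});
  // %maskA / %maskB in the example above.
  Value rowMask = rewriter.create<vector::CreateMaskOp>(
      loc, rowType, createMask->getOperand(0));
  Value colMask = rewriter.create<vector::CreateMaskOp>(
      loc, colType, createMask->getOperand(1));
  return std::make_pair(rowMask, colMask);
}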

/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
///
/// Example:
/// ```
/// %el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32>
/// ```
/// Becomes:
/// ```
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
///            : vector<[4]xi32> from vector<[4]x[4]xi32>
/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
/// ```
struct VectorExtractToArmSMELowering
    : public OpRewritePattern<vector::ExtractOp> {};
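
// A hedged sketch of the rewrite described above for dynamic %row/%col
// positions (illustrative; assumes MoveTileSliceToVectorOp has a
// (tile, index) builder that infers the slice type):
static LogicalResult lowerTileExtractSketch(vector::ExtractOp extractOp,
                                            Value row, OpFoldResult col,
                                            PatternRewriter &rewriter) {
  if (!arm_sme::isValidSMETileVectorType(extractOp.getSourceVectorType()))
    return failure();
  auto loc = extractOp.getLoc();
  // %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
  Value slice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
      loc, extractOp.getVector(), row);
  // %el = vector.extract %slice[%col]
  rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, slice, col);
  return success();
}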

/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
/// `arm_sme.move_tile_slice_to_vector`.
///
/// Example:
/// ```
/// %new_tile = vector.insert %el, %tile[%row, %col]
///                     : i32 into vector<[4]x[4]xi32>
/// ```
/// Becomes:
/// ```
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
///            : vector<[4]xi32> from vector<[4]x[4]xi32>
/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
///               : vector<[4]xi32> into vector<[4]x[4]xi32>
/// ```
struct VectorInsertToArmSMELowering
    : public OpRewritePattern<vector::InsertOp> {};
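
// The matching hedged sketch for insert: a read-modify-write of a single
// tile slice. %el, %tile, %row and %col are assumed to have already been
// pulled out of the vector.insert; builder shapes are assumptions as above.
static void lowerTileInsertSketch(vector::InsertOp insertOp, Value el,
                                  Value tile, Value row, OpFoldResult col,
                                  PatternRewriter &rewriter) {
  auto loc = insertOp.getLoc();
  // %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
  Value slice =
      rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, tile, row);
  // %new_slice = vector.insert %el, %slice[%col]
  Value newSlice = rewriter.create<vector::InsertOp>(loc, el, slice, col);
  // %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
  rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
      insertOp, cast<VectorType>(tile.getType()), newSlice, tile, row);
}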

/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
/// a 1D `vector.print`.
///
///  BEFORE:
///  ```mlir
///  vector.print %tile : vector<[4]x[4]xf32>
///  ```
///  AFTER:
///  ```mlir
///  %c0 = arith.constant 0 : index
///  %c1 = arith.constant 1 : index
///  %c4 = arith.constant 4 : index
///  %vscale = vector.vscale
///  %svl_s = arith.muli %c4, %vscale : index
///  scf.for %i = %c0 to %svl_s step %c1 {
///    %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
///                     : vector<[4]xf32> from vector<[4]x[4]xf32>
///    vector.print %tile_slice : vector<[4]xf32>
///  }
///  ```
struct VectorPrintToArmSMELowering
    : public OpRewritePattern<vector::PrintOp> {};
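
// A hedged sketch of the loop in the AFTER example (illustrative; reuses the
// (tile, index) MoveTileSliceToVectorOp builder assumed above):
static void printTileSketch(Value tile, VectorType tileType, Location loc,
                            PatternRewriter &rewriter) {
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  Value minRows =
      rewriter.create<arith::ConstantIndexOp>(loc, tileType.getDimSize(0));
  Value vscale =
      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
  // %svl_s: the runtime number of rows in the tile.
  Value numRows = rewriter.create<arith::MulIOp>(loc, minRows, vscale);
  auto forOp = rewriter.create<scf::ForOp>(loc, c0, numRows, c1);
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(forOp.getBody());
  Value slice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
      loc, tile, forOp.getInductionVar());
  rewriter.create<vector::PrintOp>(loc, slice);
}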

/// Folds a MoveTileSliceToVectorOp + TransferWriteOp to a StoreTileSliceOp.
///
///  BEFORE:
///  ```mlir
///  %slice = arm_sme.move_tile_slice_to_vector %tile[%index]
///             : vector<[4]xf32> from vector<[4]x[4]xf32>
///  vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
///             : vector<[4]xf32>, memref<?x?xf32>
///  ```
///  AFTER:
///  ```mlir
///  arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
///             : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
///  ```
struct FoldTransferWriteOfExtractTileSlice
    : public OpRewritePattern<vector::TransferWriteOp> {};
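
// A hedged sketch of the conditions guarding this fold (illustrative; the
// elided body may check more, e.g. that the slice has no other users):
static bool canFoldToStoreTileSliceSketch(vector::TransferWriteOp writeOp) {
  // The written value must come straight from a tile-slice extraction.
  if (!writeOp.getVector()
           .getDefiningOp<arm_sme::MoveTileSliceToVectorOp>())
    return false;
  // It must be a plain 1-D write to a memref for arm_sme.store_tile_slice
  // to be equivalent.
  return isa<MemRefType>(writeOp.getShapedType()) &&
         writeOp.getVectorType().getRank() == 1 &&
         writeOp.getPermutationMap().isMinorIdentity();
}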

/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
/// `arm_sve.psel`. Note: while `psel` lives under ArmSVE, it requires SME
/// (or SVE 2.1), so this is currently the most logical place for this
/// lowering.
///
/// Example:
/// ```mlir
/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
/// %slice = vector.extract %mask[%index]
///            : vector<[8]xi1> from vector<[4]x[8]xi1>
/// ```
/// Becomes:
/// ```mlir
/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
///            : vector<[8]xi1>, vector<[4]xi1>
/// ```
struct ExtractFromCreateMaskToPselLowering
    : public OpRewritePattern<vector::ExtractOp> {};
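
// A hedged sketch of the match side of this pattern (illustrative):
static bool isPselCandidateSketch(vector::ExtractOp extractOp) {
  // The source must be a 2-D scalable mask produced by vector.create_mask...
  auto createMask =
      extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
  if (!createMask)
    return false;
  auto maskType = cast<VectorType>(createMask.getType());
  // ...and the extract must peel off exactly one (row) index, yielding the
  // 1-D mask slice that arm_sve.psel selects.
  return maskType.getRank() == 2 && maskType.allDimsScalable() &&
         extractOp.getMixedPosition().size() == 1;
}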

} // namespace

void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                          MLIRContext &ctx) {
  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
               TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
               TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
               VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
               VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
               ExtractFromCreateMaskToPselLowering>(&ctx);
}