// llvm/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/OneToNTypeConversion.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons.
// Common match failure reasons.
// NOTE(review): these literals are shared across the rewrite patterns below —
// presumably passed to `rewriter.notifyMatchFailure`; confirm at use sites.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");

/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example for an [8]x[8] tile with four [4]x[4]
/// sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {};

/// Adds a constant elementwise scalable offset to `indices` (which are of equal
/// length). For example, in the 2D case this would return:
// { indices[0] + offset[0] * vscale, indices[1] + offset[1] *  vscale }
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {}

/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
/// indices for one of the SME sub-tiles it will decompose into.
///
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
/// indices for each tile would need to be adjusted as follows:
///
/// initial indices = [a,b], inital size = 8x8, target size = 4x4
/// ┌─────────────┬─────────────┐
/// │[a,b]        │[a,b+4]      │
/// │             │             │
/// ├─────────────┼─────────────┤
/// │[a+4,b]      │[a+4,b+4]    │
/// │             │             │
/// └─────────────┴─────────────┘
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {}

/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is just no mask, or vector.create_mask.
/// TODO: Add support for vector.constant_mask once required for SME.
bool isSupportedMaskOp(Value mask) {}

/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {}

/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, if decomposing an [8]x[8] into
/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
/// (4, 4).
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {}

/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
/// `type`.
int getNumberOfSMETilesForVectorType(VectorType type) {}

/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
struct LegalizeArithConstantOpsByDecomposition
    : public OneToNOpConversionPattern<arith::ConstantOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    // Only vector splat constants are handled here; anything else is left
    // for other patterns/lowerings.
    if (!vectorType || !denseAttr || !denseAttr.isSplat())
      return failure();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(constantOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
    // A splat of a multiple of SME tiles is simply `tileCount` copies of a
    // single tile-sized splat constant.
    auto tileSplat = rewriter.create<arith::ConstantOp>(
        constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
    rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
                       adaptor.getResultMapping());

    return success();
  }
};

/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OneToNOpConversionPattern<vector::OuterProductOp> {};
// NOTE(review): the pattern implementation (constructor + matchAndRewrite) is
// missing here — as written this struct is an empty stub that cannot match or
// rewrite anything. Restore the body before relying on this pattern.

// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
// get the help of the type conversion), but doing so results in the type
// conversion adding target materializations in the `vector.mask` region
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OneToNOpConversionPattern<vector::MaskOp> {};
// NOTE(review): the matchAndRewrite body (which would delegate to the
// outerproduct pattern above) is missing — empty stub as written.

/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferReadOpsByDecomposition
    : public OneToNOpConversionPattern<vector::TransferReadOp> {};
// NOTE(review): the matchAndRewrite body is missing — empty stub as written;
// no transfer_read decomposition will occur until it is restored.

/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferWriteOpsByDecomposition
    : public OneToNOpConversionPattern<vector::TransferWriteOp> {};
// NOTE(review): the matchAndRewrite body is missing — empty stub as written;
// no transfer_write decomposition will occur until it is restored.

/// Legalize a multi-tile transfer_write as a single store loop. This is done as
/// part of type decomposition as at this level we know each tile write is
/// disjoint, but that information is lost after decomposition (without analysis
/// to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
///   : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
///   %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
///     : vector<[8]xi1> from vector<[16]x[8]xi1>           |
///   %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
///     : vector<[8]xi16> from vector<[8]x[8]xi16>          |
///   vector.transfer_write %upper_slice,                   |
///     %dest[%slice_idx + %y, %x], %upper_slice_mask       |
///     : vector<[8]xi16>, memref<?x?xi16>                  ┘
///   %lower_slice_idx = %slice_idx + %c8_vscale                 ─┐
///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]  |
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
///   %lower_slice = vector.extract %lower_tile[%slice_idx]       |- Store lower
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                |  tile
///   vector.transfer_write %lower_slice,                         |
///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask       |
///     : vector<[8]xi16>, memref<?x?xi16>                        ┘
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OneToNOpConversionPattern<vector::TransferWriteOp> {};
// NOTE(review): the matchAndRewrite body (which would build the scf.for store
// loop documented above) is missing — empty stub as written.

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
///  %subMask = vector.extract %mask[2]
///          : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
///  ```
///
///  AFTER:
///  ```mlir
///  %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
///  %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
///  %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
///  ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {};
// NOTE(review): the matchAndRewrite body implementing the fold documented
// above is missing — empty stub as written.

/// A vector type where no fixed dimension comes after a scalable dimension.
/// Scans the scalable-dims flags from the trailing (minor) dimension to the
/// leading one: once a fixed dim has been seen, any further (more-major)
/// scalable dim makes the type illegal.
bool isLegalVectorType(VectorType vType) {
  bool seenFixedDim = false;
  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
    seenFixedDim |= !scalableFlag;
    if (seenFixedDim && scalableFlag)
      return false;
  }
  return true;
}

/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
/// 'Illegal' here means a leading scalable dimension and a fixed trailing
/// dimension, which has no valid lowering.
///
/// The memref.transpose is metadata-only transpose that produces a strided
/// memref, which eventually becomes a loop reading individual elements.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %illegalRead = vector.transfer_read %memref[%a, %b]
///                  : memref<?x?xf32>, vector<[8]x4xf32>
///  %legalType = vector.transpose %illegalRead, [1, 0]
///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %legalType = vector.transfer_read %transpose[%c0, %c0]
///                  : memref<?x?xf32>, vector<4x[8]xf32>
///  ```
struct LiftIllegalVectorTransposeToMemory
    : public OpRewritePattern<vector::TransposeOp> {};
// NOTE(review): the matchAndRewrite body implementing the rewrite documented
// above is missing — empty stub as written.

/// A rewrite to turn unit dim transpose-like vector.shape_casts into
/// vector.transposes. The shape_cast has to be from an illegal vector type to a
/// legal one (as defined by isLegalVectorType).
///
/// The reasoning for this is if we've got to this pass and we still have
/// shape_casts of illegal types, then they likely will not cancel out. Turning
/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
/// eliminate them.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
///  ```
struct ConvertIllegalShapeCastOpsToTransposes
    : public OpRewritePattern<vector::ShapeCastOp> {};
// NOTE(review): the matchAndRewrite body is missing — empty stub as written.

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
/// the ZA state. This is a workaround to support these transposes when ZA is
/// available.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %transpose = vector.transpose %vec, [1, 0]
///     : vector<2x[4]xf32> to vector<[4]x2xf32>
///  vector.transfer_write %transpose, %dest[%y, %x]
///     : vector<[4]x2xf32>,  memref<?x?xf32>
///  ```
///
///  AFTER:
///  ```mlir
///   %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
///   %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
///   %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
///   %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
///   %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
///   %c4_vscale = arith.muli %vscale, %c4 : index
///   %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
///   vector.transfer_write %4, %dest[%y, %x], %mask
///      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
///      : vector<[4]x[4]xf32>, memref<?x?xf32>
///  ```
///
/// Values larger than a single tile are supported via decomposition.
struct LowerIllegalTransposeStoreViaZA
    : public OpRewritePattern<vector::TransferWriteOp> {};
// NOTE(review): the matchAndRewrite body implementing the ZA-based rewrite
// documented above is missing — empty stub as written.

/// The vector-legalization pass. NOTE(review): runOnOperation — which would
/// presumably register the patterns above and drive the conversion — is
/// missing; this is an empty stub as written, so the pass currently does
/// nothing. Restore the implementation before use.
struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {};

} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {}