//===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"

using namespace mlir;

namespace {

/// Conversion pattern for vector.transfer_read.
///
/// ---
///
/// Example 1: op with identity permutation map to horizontal
///            arm_sme.tile_load:
///
///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d0, d1)
///
/// is converted to:
///
///   arm_sme.tile_load ...
///
/// ---
///
/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
///            (in-flight transpose):
///
///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)
///
/// is converted to:
///
///   arm_sme.tile_load ... layout<vertical>
struct TransferReadToArmSMELowering
    : public OpRewritePattern<vector::TransferReadOp> {
  …
};

/// Conversion pattern for vector.transfer_write.
///
/// ---
///
/// Example 1: op with identity permutation map to horizontal
///            arm_sme.tile_store:
///
///   vector.transfer_write %vector, %source[%c0, %c0]
///     {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
///     vector<[16]x[16]xi8>
///
/// ---
///
/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
///            (in-flight transpose):
///
///   vector.transfer_write %vector, %source[%c0, %c0]
///     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
///      in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
///   arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
///     : memref<?x?xi8>, vector<[16]x[16]xi8>
struct TransferWriteToArmSMELowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  …
};

/// Conversion pattern for vector.load.
struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
  …
};

/// Conversion pattern for vector.store.
struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
  …
};

/// Conversion pattern for vector.broadcast.
///
/// Example:
///
///   %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
///
/// is converted to:
///
///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
///   {
///     %tile_update = arm_sme.insert_tile_slice
///       %broadcast_to_1d, %iter_tile[%tile_slice_index] :
///       vector<[4]xi32> into vector<[4]x[4]xi32>
///     scf.yield %tile_update : vector<[4]x[4]xi32>
///   }
///
/// Supports scalar, 0-d vector, and 1-d vector broadcasts.
struct BroadcastOpToArmSMELowering
    : public OpRewritePattern<vector::BroadcastOp> {
  …
};

/// Conversion pattern for vector.splat.
///
/// Example:
///
///   %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
///
/// is converted to:
///
///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
///   {
///     %tile_update = arm_sme.insert_tile_slice
///       %broadcast_to_1d, %iter_tile[%tile_slice_index] :
///       vector<[4]xi32> into vector<[4]x[4]xi32>
///     scf.yield %tile_update : vector<[4]x[4]xi32>
///   }
///
/// This is identical to vector.broadcast of a scalar.
struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
  …
};

/// Conversion pattern for vector.transpose.
///
/// Stores the input tile to memory and reloads vertically.
///
/// Example:
///
///   %transposed_src = vector.transpose %src, [1, 0]
///     : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
///
/// is converted to:
///
///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
///   arm_sme.tile_store %src, %alloca[%c0, %c0]
///     : memref<?x?xi32>, vector<[4]x[4]xi32>
///   %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
///     layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
///
/// NOTE: Transposing via memory is obviously expensive; the current intention
/// is to avoid the transpose if possible. This is therefore intended as a
/// fallback and to provide base support for Vector ops. If it turns out
/// transposes can't be avoided, this should be replaced with a more optimal
/// implementation, perhaps with tile <-> vector (MOVA) ops.
struct TransposeOpToArmSMELowering
    : public OpRewritePattern<vector::TransposeOp> {
  …
};

/// Conversion pattern for vector.outerproduct.
///
/// If the vector.outerproduct is masked (and the mask is from a
/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
/// operands.
///
/// Example:
///
///   %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
///   %result = vector.mask %mask {
///     vector.outerproduct %vecA, %vecB
///       : vector<[4]xf32>, vector<[4]xf32>
///   } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
///
/// is converted to:
///
///   %maskA = vector.create_mask %dimA : vector<[4]xi1>
///   %maskB = vector.create_mask %dimB : vector<[4]xi1>
///   %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
///     : vector<[4]xf32>, vector<[4]xf32>
///
/// Unmasked outerproducts can be directly replaced with the arm_sme op.
///
/// Example:
///
///   %result = vector.outerproduct %vecA, %vecB
///     : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   %result = arm_sme.outerproduct %vecA, %vecB
///     : vector<[4]xf32>, vector<[4]xf32>
struct VectorOuterProductToArmSMELowering
    : public OpRewritePattern<vector::OuterProductOp> {
  …
};

/// Lower `vector.extract` using `arm_sme.extract_tile_slice`.
///
/// Example:
/// ```
/// %el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32>
/// ```
/// Becomes:
/// ```
/// %slice = arm_sme.extract_tile_slice %tile[%row]
///   : vector<[4]xi32> from vector<[4]x[4]xi32>
/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
/// ```
struct VectorExtractToArmSMELowering
    : public OpRewritePattern<vector::ExtractOp> {
  …
};

/// Lower `vector.insert` using `arm_sme.insert_tile_slice` and
/// `arm_sme.extract_tile_slice`.
///
/// Example:
/// ```
/// %new_tile = vector.insert %el, %tile[%row, %col]
///   : i32 into vector<[4]x[4]xi32>
/// ```
/// Becomes:
/// ```
/// %slice = arm_sme.extract_tile_slice %tile[%row]
///   : vector<[4]xi32> from vector<[4]x[4]xi32>
/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
///   : vector<[4]xi32> into vector<[4]x[4]xi32>
/// ```
struct VectorInsertToArmSMELowering
    : public OpRewritePattern<vector::InsertOp> {
  …
};

/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
/// extracting them via `arm_sme.extract_tile_slice`, then printing with
/// a 1D `vector.print`.
///
/// BEFORE:
/// ```mlir
/// vector.print %tile : vector<[4]x[4]xf32>
/// ```
/// AFTER:
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
/// %c4 = arith.constant 4 : index
/// %vscale = vector.vscale
/// %svl_s = arith.muli %c4, %vscale : index
/// scf.for %i = %c0 to %svl_s step %c1 {
///   %tile_slice = arm_sme.extract_tile_slice %tile[%i]
///     : vector<[4]xf32> from vector<[4]x[4]xf32>
///   vector.print %tile_slice : vector<[4]xf32>
/// }
/// ```
struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
  …
};

/// Folds an ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp.
///
/// BEFORE:
/// ```mlir
/// %slice = arm_sme.extract_tile_slice %tile[%index]
///   : vector<[4]xf32> from vector<[4]x[4]xf32>
/// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
///   : vector<[4]xf32>, memref<?x?xf32>
/// ```
/// AFTER:
/// ```mlir
/// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
///   : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
/// ```
struct FoldTransferWriteOfExtractTileSlice
    : public OpRewritePattern<vector::TransferWriteOp> {
  …
};

/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
/// SVE 2.1), so this is currently the most logical place for this lowering.
///
/// Example:
/// ```mlir
/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
/// %slice = vector.extract %mask[%index]
///   : vector<[8]xi1> from vector<[4]x[8]xi1>
/// ```
/// Becomes:
/// ```mlir
/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
///   : vector<[8]xi1>, vector<[4]xi1>
/// ```
struct ExtractFromCreateMaskToPselLowering
    : public OpRewritePattern<vector::ExtractOp> {
  …
};

} // namespace

void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                          MLIRContext &ctx) {
  …
}
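
// Example usage (a minimal sketch, not part of this file): these patterns are
// typically applied via the greedy pattern rewrite driver from a conversion
// pass. The pass class `ConvertVectorToArmSMEPassSketch` below is
// hypothetical; `populateVectorToArmSMEPatterns`, `RewritePatternSet`, and
// `applyPatternsAndFoldGreedily` (declared in
// "mlir/Transforms/GreedyPatternRewriteDriver.h") are the assumed entry
// points:
//
//   void ConvertVectorToArmSMEPassSketch::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateVectorToArmSMEPatterns(patterns, getContext());
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }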