//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/OneToNTypeConversion.h"

#define DEBUG_TYPE …

namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");

/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example for an [8]x[8] tile with four [4]x[4]
/// sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile { … };

/// Adds a constant elementwise scalable offset to `indices` (which must have
/// the same length as `scalableOffsets`). For example, in the 2D case this
/// would return:
// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {
  …
}
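
// An illustrative sketch (not part of the original file) of the IR the helper
// above produces. Offsetting indices [%i, %j] by scalable offsets [4, 0]
// yields IR along these lines (the exact op choice is an assumption):
//
// ```mlir
// %vscale = vector.vscale
// %c4 = arith.constant 4 : index
// %rowOffset = arith.muli %c4, %vscale : index
// %newRow = arith.addi %i, %rowOffset : index
// // The helper then returns { %newRow, %j }.
// ```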
/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
/// indices for one of the SME sub-tiles it will decompose into.
///
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
/// indices for each tile would need to be adjusted as follows:
///
/// initial indices = [a,b], initial size = 8x8, target size = 4x4
/// ┌─────────────┬─────────────┐
/// │[a,b]        │[a,b+4]      │
/// │             │             │
/// ├─────────────┼─────────────┤
/// │[a+4,b]      │[a+4,b+4]    │
/// │             │             │
/// └─────────────┴─────────────┘
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {
  …
}

/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is just no mask, or vector.create_mask.
/// TODO: Add support for vector.constant_mask once required for SME.
bool isSupportedMaskOp(Value mask) { … }

/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {
  …
}

/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, if decomposing an [8]x[8] into
/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
/// (4, 4).
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
  …
}

/// Returns the number of SME tiles that fit into the (2D-scalable) vector
/// type `type`.
int getNumberOfSMETilesForVectorType(VectorType type) { … }

/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
struct LegalizeArithConstantOpsByDecomposition
    : public OneToNOpConversionPattern<arith::ConstantOp> { … };

/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OneToNOpConversionPattern<vector::OuterProductOp> { … };

// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
// get the help of the type conversion), but doing so results in the type
// conversion adding target materializations in the `vector.mask` region
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OneToNOpConversionPattern<vector::MaskOp> { … };

/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferReadOpsByDecomposition
    : public OneToNOpConversionPattern<vector::TransferReadOp> { … };

/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferWriteOpsByDecomposition
    : public OneToNOpConversionPattern<vector::TransferWriteOp> { … };
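
// An illustrative sketch of the decomposition the patterns above perform
// (hypothetical, simplified pseudo-MLIR; padding and index arithmetic are
// elided). For f32 an SME tile is vector<[4]x[4]xf32>, so a read of
// vector<[8]x[8]xf32> decomposes into four tile-sized reads at the sub-tile
// coordinates (0,0), (0,4), (4,0), (4,4) documented earlier:
//
// ```mlir
// // BEFORE:
// %vec = vector.transfer_read %src[%i, %j]
//   : memref<?x?xf32>, vector<[8]x[8]xf32>
//
// // AFTER (sketch): %i4 = %i + 4 * vscale, %j4 = %j + 4 * vscale
// %tile_0_0 = vector.transfer_read %src[%i, %j]
//   : memref<?x?xf32>, vector<[4]x[4]xf32>
// %tile_0_4 = vector.transfer_read %src[%i, %j4]
//   : memref<?x?xf32>, vector<[4]x[4]xf32>
// %tile_4_0 = vector.transfer_read %src[%i4, %j]
//   : memref<?x?xf32>, vector<[4]x[4]xf32>
// %tile_4_4 = vector.transfer_read %src[%i4, %j4]
//   : memref<?x?xf32>, vector<[4]x[4]xf32>
// ```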
/// Legalize a multi-tile transfer_write as a single store loop. This is done
/// as part of type decomposition because at this level we know each tile
/// write is disjoint, but that information is lost after decomposition
/// (without analysis to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
///   : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
///   %upper_slice_mask = vector.extract %mask[%slice_idx]             ─┐
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                       |
///   %upper_slice = vector.extract %upper_tile[%slice_idx]            |- Store upper tile
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                      |
///   vector.transfer_write %upper_slice,                               |
///     %dest[%slice_idx + %y, %x], %upper_slice_mask                   |
///     : vector<[8]xi16>, memref<?x?xi16>                             ─┘
///   %lower_slice_idx = %slice_idx + %c8_vscale                       ─┐
///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]        |
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                       |
///   %lower_slice = vector.extract %lower_tile[%slice_idx]            |- Store lower tile
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                      |
///   vector.transfer_write %lower_slice,                               |
///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask             |
///     : vector<[8]xi16>, memref<?x?xi16>                             ─┘
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OneToNOpConversionPattern<vector::TransferWriteOp> { … };

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks) into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
/// %subMask = vector.extract %mask[2]
///   : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
/// ```
///
/// AFTER:
/// ```mlir
/// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
/// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
/// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
/// ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> { … };

/// A vector type where no fixed dimension comes after a scalable dimension.
bool isLegalVectorType(VectorType vType) { … }

/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
/// 'Illegal' here means a leading scalable dimension and a fixed trailing
/// dimension, which has no valid lowering.
///
/// The memref.transpose is a metadata-only transpose that produces a strided
/// memref, which eventually becomes a loop reading individual elements.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %illegalRead = vector.transfer_read %memref[%a, %b]
///   : memref<?x?xf32>, vector<[8]x4xf32>
/// %legalType = vector.transpose %illegalRead, [1, 0]
///   : vector<[8]x4xf32> to vector<4x[8]xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
///   : memref<?x?xf32> to memref<?x?xf32>
/// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
///   : memref<?x?xf32> to memref<?x?xf32>
/// %legalType = vector.transfer_read %transpose[%c0, %c0]
///   : memref<?x?xf32>, vector<4x[8]xf32>
/// ```
struct LiftIllegalVectorTransposeToMemory
    : public OpRewritePattern<vector::TransposeOp> { … };
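
// For reference, an illustrative sketch of the `isLegalVectorType` predicate
// defined above (an assumption about its logic, not the exact implementation;
// the name `isLegalVectorTypeSketch` is hypothetical). A type is legal when
// no fixed dimension follows a scalable one:
//
//   bool isLegalVectorTypeSketch(VectorType vType) {
//     bool seenFixedDim = false;
//     // Walk dims right-to-left: a scalable dim is only legal if no fixed
//     // dim appears after (i.e. to the right of) it.
//     for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
//       seenFixedDim |= !scalableFlag;
//       if (seenFixedDim && scalableFlag)
//         return false; // fixed dim after a scalable dim -> illegal
//     }
//     return true; // e.g. vector<4x[4]xf32> legal, vector<[4]x4xf32> not
//   }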
/// A rewrite to turn unit dim transpose-like vector.shape_casts into
/// vector.transposes. The shape_cast has to be from an illegal vector type to
/// a legal one (as defined by isLegalVectorType).
///
/// The reasoning for this is if we've got to this pass and we still have
/// shape_casts of illegal types, then they likely will not cancel out. Turning
/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
/// eliminate them.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
/// ```
struct ConvertIllegalShapeCastOpsToTransposes
    : public OpRewritePattern<vector::ShapeCastOp> { … };

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
/// use the ZA state. This is a workaround to support these transposes when ZA
/// is available.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %transpose = vector.transpose %vec, [1, 0]
///   : vector<2x[4]xf32> to vector<[4]x2xf32>
/// vector.transfer_write %transpose, %dest[%y, %x]
///   : vector<[4]x2xf32>, memref<?x?xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
/// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
/// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
/// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
/// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
/// vector.transfer_write %4, %dest[%y, %x], %mask
///   {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
///   : vector<[4]x[4]xf32>, memref<?x?xf32>
/// ```
///
/// Values larger than a single tile are supported via decomposition.
struct LowerIllegalTransposeStoreViaZA
    : public OpRewritePattern<vector::TransferWriteOp> { … };

struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  …
};

} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() { … }
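
// Illustrative usage sketch (an assumption, not part of the original file):
// the pass is typically run before lowering vector ops to ArmSME, e.g. from a
// standard MLIR pass pipeline in C++:
//
//   mlir::MLIRContext context;
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::arm_sme::createVectorLegalizationPass());
//   if (mlir::failed(pm.run(module)))
//     llvm::errs() << "vector legalization failed\n";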