//===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements rewrites that fuse 'arm_sme.outerproduct' operations // into the 2-way or 4-way widening outerproduct operations. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Transforms/Passes.h" #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" #define DEBUG_TYPE … namespace mlir::arm_sme { // Pull in the tablegen-generated pass base class // (impl::OuterProductFusionBase) that OuterProductFusionPass derives from. #define GEN_PASS_DEF_OUTERPRODUCTFUSION #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" } // namespace mlir::arm_sme using namespace mlir; using namespace mlir::arm_sme; namespace { // Common match failure reasons. static constexpr StringLiteral kMatchFailureNoAccumulator("no accumulator operand"); static constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp( "defining op of accumulator must be 'arm_sme.outerproduct'"); static constexpr StringLiteral kMatchFailureInconsistentCombiningKind( "combining kind (add or sub) of outer products must match"); static constexpr StringLiteral kMatchFailureInconsistentMasking( "unsupported masking, either both outerproducts are masked " "or neither"); static constexpr StringLiteral kMatchFailureOuterProductNotSingleUse( "outer product(s) not single use and cannot be removed, no benefit to " "fusing"); // An outer product is compatible if all of the following are true: // - the result type matches `resultType`. // - the defining operation of LHS is of the type `LhsExtOp`.
// - the defining operation of RHS is of the type `RhsExtOp`. // - the input types of the defining operations are identical and match // `inputType`. // `RhsExtOp` defaults to `LhsExtOp`, so both operands must be produced by // the same kind of extension op unless a different `RhsExtOp` is given. template <typename LhsExtOp, typename RhsExtOp = LhsExtOp> static LogicalResult isCompatible(PatternRewriter &rewriter, arm_sme::OuterProductOp op, VectorType resultType, VectorType inputType) { … } // Fuse two 'arm_sme.outerproduct' operations that are chained via the // accumulator into a 2-way outer product operation. // // For example: // // %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> // %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> // %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, // vector<[4]xf32> // // %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> // %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> // %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, // vector<[4]xf32> // // Becomes: // // %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16> // %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16> // %0 = arm_sme.fmopa_2way %a_packed, %b_packed // : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> class OuterProductFusion2Way : public OpRewritePattern<arm_sme::OuterProductOp> { … }; // Fuse four 'arm_sme.outerproduct' operations that are chained via the // accumulator into a 4-way outer product operation. class OuterProductFusion4Way : public OpRewritePattern<arm_sme::OuterProductOp> { … }; // Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract).
// // This transforms IR like: // %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32> // %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32> // Into: // %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8> // %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32> // // This enables outer product fusion in the `-arm-sme-outer-product-fusion` // pass when the result is the input to an outer product. struct SwapVectorExtractOfArithExtend : public OpRewritePattern<vector::ExtractOp> { … }; // Same as SwapVectorExtractOfArithExtend, but for vector.scalable.extract. // // This transforms IR like: // %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32> // %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32> // Into: // %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8> // %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32> // // This enables outer product fusion in the `-arm-sme-outer-product-fusion` // pass when the result is the input to an outer product. struct SwapVectorScalableExtractOfArithExtend : public OpRewritePattern<vector::ScalableExtractOp> { … }; // The pass itself, built on the generated impl::OuterProductFusionBase. // NOTE(review): body not visible here; presumably it applies the patterns // from populateOuterProductFusionPatterns — confirm. struct OuterProductFusionPass : public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> { … }; } // namespace // Adds rewrite patterns to `patterns`. NOTE(review): body not visible here; // presumably registers the fusion and extract/extend swap patterns defined // above — confirm against the implementation. void mlir::arm_sme::populateOuterProductFusionPatterns( RewritePatternSet &patterns) { … } // Returns an owning std::unique_ptr<Pass>; presumably a new // OuterProductFusionPass (body not visible here — confirm). std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() { … }