// llvm/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp

//===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrites that fuse 'arm_sme.outerproduct' operations
// into the 2-way or 4-way widening outerproduct operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

// Tag used by LLVM's debug-output machinery (LLVM_DEBUG / -debug-only) to
// select messages from this file. The macro must expand to a string literal;
// an empty definition breaks any use of LLVM_DEBUG in this translation unit.
// NOTE(review): the value was missing; the string below follows the file's
// naming — confirm it matches the project's convention.
#define DEBUG_TYPE "arm-sme-outerproduct-fusion"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_OUTERPRODUCTFUSION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

usingnamespacemlir;
usingnamespacemlir::arm_sme;

namespace {

// Diagnostic messages shared by the patterns in this file when a match fails.

// The outer product has no accumulator operand.
constexpr StringLiteral kMatchFailureNoAccumulator{"no accumulator operand"};

// The accumulator is not produced by another 'arm_sme.outerproduct'.
constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp{
    "defining op of accumulator must be 'arm_sme.outerproduct'"};

// The outer products do not share the same combining kind.
constexpr StringLiteral kMatchFailureInconsistentCombiningKind{
    "combining kind (add or sub) of outer products must match"};

// Only some of the outer products are masked.
constexpr StringLiteral kMatchFailureInconsistentMasking{
    "unsupported masking, either both outerproducts are masked "
    "or neither"};

// An outer product has additional users.
constexpr StringLiteral kMatchFailureOuterProductNotSingleUse{
    "outer product(s) not single use and cannot be removed, no benefit to "
    "fusing"};

// Checks whether `op` satisfies the type requirements for being fused into a
// widening outer product.
//
// An outer product is compatible if all of the following are true:
// - the result type matches `resultType`.
// - the defining operation of LHS is of the type `LhsExtOp`.
// - the defining operation of RHS is of the type `RhsExtOp`.
// - the input types of the defining operations are identical and match
//   `inputType`.
//
// Returns success() when every condition holds. Otherwise the reason is
// reported through `rewriter.notifyMatchFailure` and failure is returned.
template <typename LhsExtOp, typename RhsExtOp = LhsExtOp>
static LogicalResult isCompatible(PatternRewriter &rewriter,
                                  arm_sme::OuterProductOp op,
                                  VectorType resultType, VectorType inputType) {
  if (op.getResultType() != resultType)
    return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
      diag << "unsupported result type, expected " << resultType;
    });

  // getDefiningOp<T>() yields a null op when the operand is a block argument
  // or is produced by an operation of a different type.
  auto lhsDefOp = op.getLhs().template getDefiningOp<LhsExtOp>();
  auto rhsDefOp = op.getRhs().template getDefiningOp<RhsExtOp>();
  if (!lhsDefOp || !rhsDefOp)
    return rewriter.notifyMatchFailure(
        op, "defining op of outerproduct operands must be one of: "
            "'arith.extf' or 'arith.extsi' or 'arith.extui'");

  // Assumes the extension ops expose their source through getIn(), as the
  // arith extension ops do. Comparing against `inputType` on both sides also
  // guarantees the two inputs are identical to each other.
  if (lhsDefOp.getIn().getType() != inputType ||
      rhsDefOp.getIn().getType() != inputType)
    return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
      diag << "unsupported input type, expected " << inputType;
    });

  return success();
}

// Fuse two 'arm_sme.outerproduct' operations that are chained via the
// accumulator into a 2-way outer product operation.
//
// For example:
//
//  %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
//  %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
//  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>,
//                                               vector<[4]xf32>
//
//  %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
//  %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
//  %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>,
//                                                   vector<[4]xf32>
//
// Becomes:
//
//  %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
//  %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
//  %0 = arm_sme.fmopa_2way %a_packed, %b_packed
//    : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
//
// NOTE(review): the class body is empty here — no constructor is inherited and
// no matchAndRewrite is defined, so the rewrite described above is not
// implemented in this view of the file. TODO confirm the body was not dropped.
class OuterProductFusion2Way
    : public OpRewritePattern<arm_sme::OuterProductOp> {};

// Fuse four 'arm_sme.outerproduct' operations that are chained via the
// accumulator into a 4-way outer product operation.
//
// NOTE(review): the class body is empty here — no constructor is inherited and
// no matchAndRewrite is defined, so the fusion is not implemented in this view
// of the file. TODO confirm the body was not dropped.
class OuterProductFusion4Way
    : public OpRewritePattern<arm_sme::OuterProductOp> {};

// Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract).
//
// This transforms IR like:
//   %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
//   %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
// Into:
//   %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
//   %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
//
// This enables outer product fusion in the `-arm-sme-outer-product-fusion`
// pass when the result is the input to an outer product.
//
// NOTE(review): the struct body is empty here, so the swap described above is
// not implemented in this view of the file. TODO confirm the body was not
// dropped.
struct SwapVectorExtractOfArithExtend
    : public OpRewritePattern<vector::ExtractOp> {};

// Rewrites: vector.scalable.extract(arith.extend) ->
// arith.extend(vector.scalable.extract).
//
// This transforms IR like:
//   %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
//   %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
// Into:
//   %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
//   %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
//
// This enables outer product fusion in the `-arm-sme-outer-product-fusion`
// pass when the result is the input to an outer product.
//
// NOTE(review): the struct body is empty here, so the swap described above is
// not implemented in this view of the file. TODO confirm the body was not
// dropped.
struct SwapVectorScalableExtractOfArithExtend
    : public OpRewritePattern<vector::ScalableExtractOp> {};

// Pass type for outer product fusion, derived (CRTP) from the base class
// generated via GEN_PASS_DEF_OUTERPRODUCTFUSION.
//
// NOTE(review): the struct body is empty here, so no pass logic is defined in
// this view of the file — presumably it should apply the patterns registered
// by populateOuterProductFusionPatterns; TODO confirm.
struct OuterProductFusionPass
    : public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> {};

} // namespace

// Public entry point taking the pattern set `patterns` by reference.
//
// NOTE(review): the body is empty, so `patterns` is left untouched and no
// rewrite pattern from this file is registered — presumably the pattern
// registrations belong here; TODO confirm.
void mlir::arm_sme::populateOuterProductFusionPatterns(
    RewritePatternSet &patterns) {}

// Factory function for the outer product fusion pass.
//
// NOTE(review): the body is empty, so control flows off the end of a function
// declared to return std::unique_ptr<Pass>, which is undefined behaviour.
// Presumably this should return a newly created OuterProductFusionPass once
// that pass type is fully defined; TODO confirm.
std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() {}