//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // This is intended to be a simple high-level (target-agnostic) matmul // transposition transformation. //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE … usingnamespacemlir; usingnamespacemlir::linalg; /// Pattern to replace /// /// linalg.matmul(a, b) /// /// with /// /// linalg.matmul_transpose_a(linalg.transpose(a), b) /// /// By default the LHS is transposed. Set `transposeLHS=false` to /// transpose RHS instead. FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp matmulOp, bool transposeLHS) { … } /// Pattern to replace /// /// linalg.batch_matmul(a, b) /// /// with /// /// linalg.batch_matmul_transpose_a(linalg.transpose(a), b) /// /// Only the non-batch dimensions are transposed. By default the LHS is /// transposed. Set `transposeLHS=false` to transpose RHS instead. FailureOr<Operation *> mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp, bool transposeLHS) { … } namespace { struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> { … }; struct TransposeBatchMatmul final : public OpRewritePattern<linalg::BatchMatmulOp> { … }; } // namespace void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns, bool transposeLHS) { … }