//===- Specialize.cpp - linalg generic ops to named ops ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a method to specialize generic operations to named // operations. Conceptually it is the opposite of generalize.cpp. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/TypeID.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" namespace mlir { #define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS #include "mlir/Dialect/Linalg/Passes.h.inc" } // namespace mlir #define DEBUG_TYPE … #define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) … #define REPLACE_UNARY_OP(NEWOP) … usingnamespacemlir; usingnamespacemlir::linalg; // Given a elementwise single binary linalg generic op, checks whether the // binary op accesses operands as swapped. e.g. // this differentiates between a linalg-generic body that contains: // ^bb0(%a: f32, %b: f32, %c : f32): // %0 = arith.subf %a, %b : f32 // linalg.yield %0: f32 // against: // ^bb0(%a: f32, %b: f32, %c : f32): // %0 = arith.subf %b, %a : f32 // linalg.yield %0: f32 // Former is linalg.sub(a,b), latter is linalg.sub(b,a). static bool areBinOpsSwapped(GenericOp genericOp) { … } //===----------------------------------------------------------------------===// // Specialize linalg generic to matmul variants. //===----------------------------------------------------------------------===// /// Identifies linalg.generic that is essentially named op of the form: // ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? ` // // It is possible that a linalg.generic may be implementing a matmul but not // in a straight-forward way e.g. below is matrix multiply over some slice // ``` // %0 = linalg.generic { // indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>, // affine_map<(d0, d1, d2) -> (d0, 5, d2)>, // affine_map<(d0, d1, d2) -> (d2, d1, 13)>], // iterator_types = ["parallel", "parallel", "parallel"]} // ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>) // outs(%C : tensor<20x20x20xf32>) { // ^bb0(%a: f32, %b: f32, %c : f32): // %mul = arith.mulf %a, %b : f32 // %add = arith.addf %mul, %c : f32 // linalg.yield %add : f32 // } -> tensor<20x20x20xf32> // ``` // It is not possible to represent above as named op. // e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is // not the same as linalg.generic above. namespace { enum class IndexMatchResult { … }; // Checks whether the input Affine `map` contains two consecutive dims that // can be interpreted as accessing a 2D matrix. It is assumed that the row // column dimension are adjacent axis (in this order) and start at // `rowDimIdx` in the input map. // // e.g. consider A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check // whether the map of A is identity (match), transposed, or something // completely different (mis-match). Similar for B and C. static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx, unsigned expectedPosOfRowDim, unsigned expectedPosOfColDim) { … } // Replaces genericOp with `NamedOpTy` op, supplied as a template arg. // All the variants expressed as pseudo regular expression: // `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?` // have same number of ins/out, so its easy to stamp different versions. template <typename NamedOpTy> static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) { … } // Converts linalg.generic to named linalg.*matmul* where possible. static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter, GenericOp genericOp) { … } } // namespace //===----------------------------------------------------------------------===// // Categorize linalg generic to named op where possible. //===----------------------------------------------------------------------===// FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp) { … } namespace { struct LinalgSpecializeGenericOpsPass : public impl::LinalgSpecializeGenericOpsPassBase< LinalgSpecializeGenericOpsPass> { … }; } // namespace void LinalgSpecializeGenericOpsPass::runOnOperation() { … } void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns( RewritePatternSet &patterns) { … }