llvm/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

//===- Specialize.cpp - linalg generic ops to named ops  ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a method to specialize generic operations to named
// operations. Conceptually it is the opposite of generalize.cpp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

// Tag used by LLVM_DEBUG output emitted from this file (selectable with
// `-debug-only=linalg-specialization`). LLVM requires this to be a string
// literal; an empty definition breaks every LLVM_DEBUG expansion.
#define DEBUG_TYPE "linalg-specialization"

// Replaces `genericOp` with a freshly built `NEWOP` whose two inputs are the
// generic's DPS inputs (in reverse order when `OPERANDS_SWAP` is true) and
// whose init is the generic's first DPS init. Evaluates to the new op.
// Must be expanded where `rewriter` and `genericOp` are in scope.
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

// Replaces `genericOp` with a freshly built `NEWOP` taking the generic's
// first DPS input and first DPS init. Evaluates to the new op.
// Must be expanded where `rewriter` and `genericOp` are in scope.
#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))

usingnamespacemlir;
usingnamespacemlir::linalg;

// Given an elementwise linalg.generic op whose body holds a single binary op,
// checks whether that binary op consumes the block arguments in swapped
// order. E.g. this differentiates between a linalg.generic body containing:
//    ^bb0(%a: f32, %b: f32, %c : f32):
//         %0 = arith.subf %a, %b : f32
//         linalg.yield %0: f32
// and one containing:
//    ^bb0(%a: f32, %b: f32, %c : f32):
//         %0 = arith.subf %b, %a : f32
//         linalg.yield %0: f32
// The former is linalg.sub(a, b); the latter is linalg.sub(b, a).
//
// Judging by the name and `bool` return, `true` means "swapped" (the second
// form above) - confirm once the implementation is restored.
// NOTE(review): the body is empty here, so a call falls off the end of a
// non-void function (undefined behavior). Restore the implementation before
// this helper is used.
static bool areBinOpsSwapped(GenericOp genericOp) {}

//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//
// Identifies a linalg.generic that is essentially a named op of the form:
//    ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
//
// It is possible that a linalg.generic may be implementing a matmul but not
// in a straightforward way, e.g. below is a matrix multiply over some slice:
// ```
//  %0 = linalg.generic {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
//                           affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
//          iterator_types = ["parallel", "parallel", "parallel"]}
//          ins(%A, %B : tensor<20x20x20xf32>,  tensor<20x20x20xf32>)
//          outs(%C : tensor<20x20x20xf32>) {
//             ^bb0(%a: f32, %b: f32, %c : f32):
//                %mul = arith.mulf %a, %b : f32
//                %add = arith.addf %mul, %c : f32
//                linalg.yield %add : f32
//       } -> tensor<20x20x20xf32>
// ```
// The above cannot be represented as a named op:
// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
// not the same as the linalg.generic above.
namespace {
// Outcome of comparing an operand's indexing map against the row/column
// layout expected for that operand of a (batch) matmul.
enum class IndexMatchResult {
  Match = 0,  // Identity: row and column appear in the expected order.
  Transposed, // Row and column positions are swapped.
  Mismatch    // Neither of the above.
};

// Checks whether the input affine `map` contains two consecutive results that
// can be interpreted as accessing a 2D matrix. It is assumed that the row and
// column dimensions are adjacent axes (in this order) and start at result
// index `rowDimIdx` of the input map.
//
//  e.g. consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. We check
//  whether the map of A is identity (match), transposed, or something
//  completely different (mismatch). The same applies to B and C.
//
// `expectedPosOfRowDim` / `expectedPosOfColDim` are presumably the loop
// dimension positions the row and column results should refer to - confirm
// against the restored implementation.
// NOTE(review): the body is empty here, so a call falls off the end of a
// non-void function (undefined behavior).
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {}

// Replaces `op` with a `NamedOpTy` op, supplied as a template argument; the
// LinalgOp return type suggests the new op is returned - confirm.
//  All the variants, expressed as a pseudo regular expression:
//      `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
//  have the same number of ins/outs, so it's easy to stamp out the different
//  versions from one template.
// NOTE(review): the body is empty here, so a call falls off the end of a
// non-void function (undefined behavior).
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {}

// Converts a linalg.generic to a named linalg.*matmul* op where possible.
// Judging by the FailureOr return type, failure() signals that `genericOp`
// matched no matmul variant - confirm once the body is restored.
// NOTE(review): the body is empty here, so a call falls off the end of a
// non-void function (undefined behavior).
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                        GenericOp genericOp) {}

} // namespace

//===----------------------------------------------------------------------===//
// Categorize linalg generic to named op where possible.
//===----------------------------------------------------------------------===//
// Attempts to rewrite `genericOp` as an equivalent named Linalg op using
// `rewriter`. Judging by the FailureOr<LinalgOp> return type, success yields
// the replacement op and failure() means no named form was found - confirm
// once the body is restored.
// NOTE(review): the body is empty here, so a call falls off the end of a
// non-void function (undefined behavior).
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                      GenericOp genericOp) {}

namespace {
/// Pass that rewrites linalg.generic ops into equivalent named Linalg ops
/// where possible. The pass body is the out-of-line runOnOperation()
/// definition that follows this namespace.
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {
  // Inherit the TableGen-generated base constructors so the generated
  // create*() factories can construct this pass.
  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;

  // Must be declared here: C++ only accepts an out-of-line member definition
  // (`LinalgSpecializeGenericOpsPass::runOnOperation`) for a member declared
  // in the class itself, not one merely inherited from the base.
  void runOnOperation() override;
};
} // namespace

// Pass entry point.
// NOTE(review): the body is empty here, so running the pass performs no
// rewrites. Presumably it should collect patterns via
// populateLinalgGenericOpsSpecializationPatterns and apply them with the
// greedy rewrite driver (GreedyPatternRewriteDriver.h is included) - confirm.
void LinalgSpecializeGenericOpsPass::runOnOperation() {}

// Public hook for adding the generic-to-named-op specialization patterns to
// `patterns`.
// NOTE(review): the body is empty here, so no patterns are added and any
// client relying on this hook gets no specialization.
void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
    RewritePatternSet &patterns) {}