llvm/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp

//===- TransposeConv2D.cpp - Convolution transposition  -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/RWMutex.h"
#include <memory>
#include <numeric>

namespace mlir {
namespace linalg {
namespace {
// clang-format off
/// Rewrites a 2-D convolution whose filter is laid out as FHWC into the
/// equivalent convolution whose filter is laid out as HWCF, by materialising a
/// transposed copy of the filter:
///
/// Before:
///
///   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
///                                  strides = dense<2> : tensor<2xi64>}
///      ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
///     outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
///
/// After:
///
///    %0 = tensor.empty() : tensor<2x2x6x8xf32>
///    %transposed = linalg.transpose ins(%filter : tensor<8x2x2x6xf32>) outs(%0 : tensor<2x2x6x8xf32>)
///                  permutation = [1, 2, 3, 0]
///    %1 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
///         ins(%input, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%init : tensor<1x2x2x8xf32>)
///         -> tensor<1x2x2x8xf32>
///
/// No fill is needed for the transpose destination: linalg.transpose writes
/// every element of it. With buffer semantics the destination is a
/// memref.alloc instead of a tensor.empty, and the new convolution has no
/// results. The quantized variant is rewritten analogously; every input other
/// than the filter (operand #1) is forwarded unchanged.
// clang-format on
template <typename FHWCConvOp, typename HWCFConvOp>
FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
                                             FHWCConvOp op) {
  Value filter = op->getOperand(1);
  // Only ranked tensors and memrefs have a shape we can permute.
  if (!isa<RankedTensorType, MemRefType>(filter.getType()))
    return rewriter.notifyMatchFailure(
        op, "expected a ranked tensor or memref filter");

  // All new ops must be created before the convolution they replace,
  // regardless of where the caller left the insertion point.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op.getOperation());

  // FHWC -> HWCF: dimension i of the transposed filter is dimension
  // filterPerm[i] of the original filter.
  const SmallVector<int64_t> filterPerm = {1, 2, 3, 0};

  auto filterTy = cast<ShapedType>(filter.getType());
  // Use the filter's own element type (not the input's): for the quantized
  // op the two may differ, and linalg.transpose needs matching element types.
  Type elementTy = filterTy.getElementType();
  const bool isTensorOp = isa<RankedTensorType>(filter.getType());
  Location loc = op.getLoc();

  // Build the transposed shape. For dynamic dimensions also collect the
  // runtime sizes, in result-dimension order as tensor.empty / memref.alloc
  // expect them.
  SmallVector<int64_t> newFilterShape;
  SmallVector<Value> dynamicSizes;
  for (int64_t srcDim : filterPerm) {
    newFilterShape.push_back(filterTy.getDimSize(srcDim));
    if (!filterTy.isDynamicDim(srcDim))
      continue;
    Value size;
    if (isTensorOp)
      size = rewriter.create<tensor::DimOp>(loc, filter, srcDim).getResult();
    else
      size = rewriter.create<memref::DimOp>(loc, filter, srcDim).getResult();
    dynamicSizes.push_back(size);
  }

  // Destination ("outs") operand for linalg.transpose.
  Value init;
  if (isTensorOp) {
    init = rewriter
               .create<tensor::EmptyOp>(loc, newFilterShape, elementTy,
                                        dynamicSizes)
               .getResult();
  } else {
    // Keep the original filter's memory space; the layout is the identity.
    auto newFilterTy = MemRefType::get(
        newFilterShape, elementTy, MemRefLayoutAttrInterface{},
        cast<MemRefType>(filterTy).getMemorySpace());
    // NOTE(review): no dealloc is inserted here; this assumes a later
    // buffer-deallocation pass takes care of it — confirm for your pipeline.
    init = rewriter.create<memref::AllocOp>(loc, newFilterTy, dynamicSizes)
               .getResult();
  }

  auto transpose =
      rewriter.create<linalg::TransposeOp>(loc, filter, init, filterPerm);

  // On tensors the transposed value is the op's result; on buffers the
  // transpose writes into `init` in place and has no results.
  Value newFilter;
  if (isTensorOp)
    newFilter = transpose->getResult(0);
  else
    newFilter = init;

  // The filter is always the second input; all other inputs are kept as-is.
  SmallVector<Value> newInputs(op.getInputs().begin(), op.getInputs().end());
  newInputs[1] = newFilter;

  // Result types are forwarded verbatim: none for buffer semantics, the
  // output tensor type for tensor semantics.
  auto newConv = rewriter.create<HWCFConvOp>(
      loc, op->getResultTypes(), newInputs, op.getOutputs(), op.getStrides(),
      op.getDilations());
  rewriter.replaceOp(op.getOperation(), newConv.getOperation());
  return newConv.getOperation();
}

/// Rewrite pattern that applies transposeConv2DHelper to every matching
/// FHWCConvOp, replacing it with the corresponding HWCFConvOp.
template <typename FHWCConvOp, typename HWCFConvOp>
class ConvConverter : public OpRewritePattern<FHWCConvOp> {
public:
  using OpRewritePattern<FHWCConvOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FHWCConvOp op,
                                PatternRewriter &rewriter) const final {
    if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op)))
      return failure();
    return success();
  }
};
} // namespace

/// Replaces `op` (conv_2d_nhwc_fhwc) with a conv_2d_nhwc_hwcf fed by a
/// transposed copy of the filter. Returns the new convolution, or failure if
/// the filter is not a ranked tensor/memref.
FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
                                       linalg::Conv2DNhwcFhwcOp op) {
  return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp,
                               linalg::Conv2DNhwcHwcfOp>(rewriter, op);
}

/// Quantized counterpart of the overload above: replaces
/// conv_2d_nhwc_fhwc_q with conv_2d_nhwc_hwcf_q.
FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
                                       linalg::Conv2DNhwcFhwcQOp op) {
  return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp,
                               linalg::Conv2DNhwcHwcfQOp>(rewriter, op);
}

/// Registers the FHWC -> HWCF convolution rewrites (plain and quantized) in
/// `patterns`.
/// NOTE: the function name misspells "Transpose". It is kept unchanged,
/// presumably because a header and callers refer to it by this name.
void populateTranposeConv2DPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<
      ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
      ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
      context);
}
} // namespace linalg
} // namespace mlir