//===- TransposeConv2D.cpp - Convolution transposition -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/ValueRange.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/RWMutex.h" #include <memory> #include <numeric> namespace mlir { namespace linalg { namespace { // clang-format off /// Convolution converter that applies the following rewrite: /// /// Before: /// /// %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, /// strides = dense<2> : tensor<2xi64>} /// ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>) /// outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> /// /// After: /// /// %cst = arith.constant 0.000000e+00 : f32 /// %0 = tensor.empty() : tensor<2x2x6x8xf32> /// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32> /// %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>) /// permutation = [1, 2, 3, 0] /// %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} /// ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>) /// -> tensor<1x2x2x8xf32> /// /// with an analogous example for the quantized case. // clang-format on template <typename FHWCConvOp, typename HWCFConvOp> FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter, FHWCConvOp op) { … } template <typename FHWCConvOp, typename HWCFConvOp> class ConvConverter : public OpRewritePattern<FHWCConvOp> { … }; } // namespace FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op) { … } FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcQOp op) { … } void populateTranposeConv2DPatterns(RewritePatternSet &patterns) { … } } // namespace linalg } // namespace mlir