llvm/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implement the Winograd Conv2D algorithm. The implementation is based on the
// paper: Fast Algorithms for Convolutional Neural Networks
// (https://arxiv.org/abs/1509.09308)
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace linalg {

namespace {

// clang-format off
/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
/// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r),
/// where m is the output dimension and r is the filter dimension, is
///
/// Y = A^T x [ (G x g x G^T) o (B^T x d x B) ] x A
///
/// where g is the filter, d is the input data, and o denotes element-wise
/// (Hadamard) multiplication. We need to prepare 6 constant transformation
/// matrices, G, G^T, B^T, B, A^T, and A, for this formula.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
constexpr float G_2x2_3x3[] =;

constexpr float GT_2x2_3x3[] =;

constexpr float BT_2x2_3x3[] =;

constexpr float B_2x2_3x3[] =;

constexpr float AT_2x2_3x3[] =;

constexpr float A_2x2_3x3[] =;

constexpr float G_4x4_3x3[] =;

constexpr float GT_4x4_3x3[] =;

constexpr float BT_4x4_3x3[] =;

constexpr float B_4x4_3x3[] =;

constexpr float AT_4x4_3x3[] =;

constexpr float A_4x4_3x3[] =;

constexpr float G_2x2_5x5[] =;

constexpr float GT_2x2_5x5[] =;

constexpr float BT_2x2_5x5[] =;

constexpr float B_2x2_5x5[] =;

constexpr float AT_2x2_5x5[] =;

constexpr float A_2x2_5x5[] =;
// clang-format on
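
// For reference, the 1-D transform matrices of F(2, 3) given in the Lavin &
// Gray paper cited above are shown below; the 2-D F(2 x 2, 3 x 3) algorithm
// nests this 1-D algorithm with itself, as in the formula above. (Reproduced
// from the paper for illustration; the elided constants above hold the actual
// values used by the implementation.)
//
//   B^T = [ 1  0 -1  0 ]      G = [  1     0     0  ]     A^T = [ 1  1  1  0 ]
//         [ 0  1  1  0 ]          [ 1/2   1/2   1/2 ]           [ 0  1 -1 -1 ]
//         [ 0 -1  1  0 ]          [ 1/2  -1/2   1/2 ]
//         [ 0  1  0 -1 ]          [  0     0     1  ]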

TransformMapKeyTy;

/// We use F(m, r) to define the size of a minimal filtering algorithm, where
/// m is the output dimension and r is the filter dimension. The input
/// dimension, alpha, follows from the formula alpha = m + r - 1.
///
/// For example, when m = 2 and r = 3, the input size is 4. The Conv2D will
/// operate on 4x4 input data with a 3x3 filter and produce a 2x2 output.
constexpr TransformMapKeyTy F_2_3{};
constexpr TransformMapKeyTy F_4_3{};
constexpr TransformMapKeyTy F_2_5{};
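
// Illustrative sketch (not used by the patterns below): numerically checks the
// 1-D F(2, 3) identity Y = A^T x [ (G x g) o (B^T x d) ], with o denoting
// element-wise multiplication and the G, B^T, A^T matrices taken from the
// paper cited in the header. Plain C++ only; the helper name and values are
// for exposition and are not part of the Winograd rewrite itself.
[[maybe_unused]] bool verifyWinogradF23(const float (&g)[3],
                                        const float (&d)[4]) {
  // Direct 1-D convolution: 4-element input, 3-element filter, 2 outputs.
  float direct[2];
  for (int i = 0; i < 2; ++i)
    direct[i] = g[0] * d[i] + g[1] * d[i + 1] + g[2] * d[i + 2];

  // Filter transform U = G x g (4 values).
  float u[4] = {g[0], (g[0] + g[1] + g[2]) / 2, (g[0] - g[1] + g[2]) / 2, g[2]};
  // Input transform V = B^T x d (4 values).
  float v[4] = {d[0] - d[2], d[1] + d[2], d[2] - d[1], d[1] - d[3]};
  // Element-wise product M = U o V, then output transform Y = A^T x M.
  float m[4] = {u[0] * v[0], u[1] * v[1], u[2] * v[2], u[3] * v[3]};
  float winograd[2] = {m[0] + m[1] + m[2], m[1] - m[2] - m[3]};

  auto near = [](float a, float b) {
    float diff = a - b;
    return diff < 1e-4f && diff > -1e-4f;
  };
  return near(direct[0], winograd[0]) && near(direct[1], winograd[1]);
}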

/// Structure that holds the information of a constant transform matrix.
struct TransformMatrix {};

/// Utility function to convert a constant array into an arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                              TransformMatrix transform, Type type) {}

/// Extract height x width data from 4D tensors.
Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
                          Value loopNorFIndex, Value loopCorFIndex,
                          Value heightOffset, Value widthOffset,
                          int64_t extractHeight, int64_t extractWidth,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {}

/// Extract height x width data from 6D tensors.
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
                          Value tileHIndex, Value tileWIndex,
                          Value loopNorFIndex, Value loopCorFIndex,
                          int64_t tileHIdx, int64_t tileWIdx,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {}

/// Insert transformed height x width data into the 4D tensor from which it
/// was extracted.
Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
                       Value heightOffset, Value widthOffset, int64_t height,
                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {}

/// Insert transformed height x width data into the 6D tensor from which it
/// was extracted.
Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value tileHIndex, Value tileWIndex,
                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
                       int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {}

/// This function transforms the filter. The data layout of the filter is FHWC.
/// The transformation matrices are 2-dimensional, so we first extract an
/// H x W slice from FHWC and generate 2 levels of loops to iterate over F and
/// C. After the transformation, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {}
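
// As a concrete (hypothetical) instance of the transform above: with F(4, 3),
// F = 64, C = 32, and a 3 x 3 filter, alpha = 4 + 3 - 1 = 6 and the shapes
// would be roughly
//
//   filter<64 x 3 x 3 x 32> (F x H x W x C)
//     -> transformed filter<6 x 6 x 32 x 64> (alpha x alpha x C x F)
//
// since G is (alpha x r) and each 3 x 3 slice expands to 6 x 6 via
// G x g x G^T.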

/// This function transforms the input. The data layout of the input is NHWC.
/// The transformation matrices are 2-dimensional, so we first extract an
/// H x W slice from NHWC and generate 4 levels of loops to iterate over the
/// tile positions, N, and C. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %c = 0 to C step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                              %input<N x H x W x C>
///                              at [%n, (%h x m), (%w x m), %c]
///         %ret = linalg.matmul BT, %extracted
///         %ret = linalg.matmul %ret, B
///         %inserted = insert %ret<alphaH x alphaW> into
///                            %output<alphaH x alphaW x tileH x tileW x N x C>
///                            at [0, 0, %h, %w, %n, %c]
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                     Value retValue, int64_t m, int64_t r,
                     bool leftTransform = true, bool rightTransform = true) {}
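
// As a concrete (hypothetical) instance of the transform above: with F(4, 3),
// N = 2, C = 32, and a 10 x 10 input, alpha = 6 and the valid output is
// 8 x 8, so tileH = tileW = 8 / 4 = 2. Tile (%h, %w) reads the overlapping
// 6 x 6 window that starts at row %h * 4 and column %w * 4, and the shapes
// would be roughly
//
//   input<2 x 10 x 10 x 32> (N x H x W x C)
//     -> transformed input<6 x 6 x 2 x 2 x 2 x 32>
//        (alpha x alpha x tileH x tileW x N x C)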

/// This function generates linalg.batch_matmul to multiply the input with the
/// filter. linalg.batch_matmul only supports 3-dimensional inputs, so we
/// collapse the transformed input from [alphaH, alphaW, tileH, tileW, N, C]
/// to [alphaH x alphaW, tileH x tileW x N, C] and the transformed filter from
/// [alphaH, alphaW, C, F] to [alphaH x alphaW, C, F]. In this way, the
/// 6-dimensional input becomes a 3-dimensional representation that is
/// suitable for linalg.batch_matmul.
///
/// The batched matmul performs the matrix multiplication with the reduction
/// over the channel dimension C.
///
/// We get
///
/// %collapsed_input = tensor.collapse_shape %input
/// %collapsed_filter = tensor.collapse_shape %filter
/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
/// %expanded_ret = tensor.expand_shape %ret
///
/// After this function, the return value has data layout
/// (alphaH, alphaW, tileH, tileW, N, F).
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {}
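
// Continuing the hypothetical F(4, 3) example (alpha = 6, tileH = tileW = 2,
// N = 2, C = 32, F = 64), the collapse / batch_matmul / expand sequence would
// have roughly these types:
//
//   %collapsed_input  : tensor<36x8x32xf32>   (alpha*alpha, tileH*tileW*N, C)
//   %collapsed_filter : tensor<36x32x64xf32>  (alpha*alpha, C, F)
//   %ret              : tensor<36x8x64xf32>   (alpha*alpha, tileH*tileW*N, F)
//   %expanded_ret     : tensor<6x6x2x2x2x64xf32>
//                       (alpha x alpha x tileH x tileW x N x F)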

/// This function transforms the output. The data layout of the value to be
/// transformed is (alphaH, alphaW, tileH, tileW, N, F). The transformation
/// matrices are 2-dimensional, so we first extract an H x W slice and
/// generate 4 levels of loops to iterate over the tile positions, N, and F.
/// After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %f = 0 to F step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                              %input<alphaH x alphaW x tileH x tileW x N x F>
///                              at [0, 0, %h, %w, %n, %f]
///         %ret = linalg.matmul AT, %extracted
///         %ret = linalg.matmul %ret, A
///         %inserted = insert %ret<alphaH x alphaW> into
///                            %output<N x H x W x F>
///                            at [%n, (%h x m), (%w x m), %f]
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                      Value output, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {}
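
// Concluding the hypothetical F(4, 3) example: each 6 x 6 slice is reduced to
// a 4 x 4 output tile by A^T x m x A (A^T is m x alpha = 4 x 6), and tile
// (%h, %w) writes rows %h * 4 .. %h * 4 + 3 and columns %w * 4 .. %w * 4 + 3
// of the final output, so the shapes would be roughly
//
//   value<6 x 6 x 2 x 2 x 2 x 64> (alpha x alpha x tileH x tileW x N x F)
//     -> output<2 x 8 x 8 x 64> (N x H x W x F)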

/// Create an empty tensor with the aligned shape and insert the value into
/// the newly created tensor.
static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
                                Value value, ArrayRef<int64_t> alignedShape) {}

/// Extract a sub-tensor of type extractedType from value.
static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
                                      Value value,
                                      RankedTensorType extractedType) {}

/// Utility function to check that all values in the attribute are 1.
static bool hasAllOneValues(DenseIntElementsAttr attr) {}

/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
/// linalg.winograd_*_transform ops.
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     int64_t m, int64_t r) {}

/// A helper function to decompose linalg.winograd_filter_transform.
FailureOr<Operation *>
decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradFilterTransformOp op) {}

/// A helper function to decompose linalg.winograd_input_transform.
FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                      linalg::WinogradInputTransformOp op) {}

/// A helper function to decompose linalg.winograd_output_transform.
FailureOr<Operation *>
decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradOutputTransformOp op) {}

/// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
class DecomposeWinogradFilterTransform final
    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {};

/// A rewrite pattern to decompose linalg.winograd_input_transform operations.
class DecomposeWinogradInputTransform final
    : public OpRewritePattern<linalg::WinogradInputTransformOp> {};

/// A rewrite pattern to decompose linalg.winograd_output_transform operations.
class DecomposeWinogradOutputTransform final
    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {};

/// A rewrite pattern for Winograd Conv2D algorithm.
class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                      int64_t r) {}

FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op) {}

FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op) {}

FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op) {}

void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                    int64_t r) {}

void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {}
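
// Hypothetical usage sketch (comment only, not taken from this file): a pass
// would typically combine the two populate functions above with the greedy
// rewrite driver from GreedyPatternRewriteDriver.h. The pass name below is
// illustrative; the only APIs assumed are RewritePatternSet,
// applyPatternsAndFoldGreedily, and the populate functions defined here.
//
//   void MyWinogradConv2DPass::runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     RewritePatternSet patterns(ctx);
//     // Rewrite linalg.conv_2d_nhwc_fhwc into Winograd transform ops with
//     // F(4, 3), then decompose the transform ops into loops and matmuls.
//     linalg::populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
//     linalg::populateDecomposeWinogradOpsPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }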

} // end namespace linalg
} // end namespace mlir