//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Implement Winograd Conv2D algorithm. The implementation is based on the // paper: Fast Algorithms for Convolutional Neural Networks // (https://arxiv.org/abs/1509.09308) // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/MathExtras.h" namespace mlir { namespace linalg { namespace { // clang-format off /// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its /// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r), /// where m is the output dimension and r is the filter dimension, is /// /// Y = A^T x [ (G x g x G^T) * (B^T x d x B) ] x A /// /// where g is the r x r filter, d is an (m + r - 1) x (m + r - 1) input tile, /// `x` denotes matrix multiplication and `*` denotes element-wise /// multiplication. We need to prepare 6 constant transformation matrices, /// G, G^T, B^T, B, A^T, and A, for this formula.
/// /// The following tables define these constant transformation matrices for /// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5). constexpr float G_2x2_3x3[] = …; constexpr float GT_2x2_3x3[] = …; constexpr float BT_2x2_3x3[] = …; constexpr float B_2x2_3x3[] = …; constexpr float AT_2x2_3x3[] = …; constexpr float A_2x2_3x3[] = …; constexpr float G_4x4_3x3[] = …; constexpr float GT_4x4_3x3[] = …; constexpr float BT_4x4_3x3[] = …; constexpr float B_4x4_3x3[] = …; constexpr float AT_4x4_3x3[] = …; constexpr float A_4x4_3x3[] = …; constexpr float G_2x2_5x5[] = …; constexpr float GT_2x2_5x5[] = …; constexpr float BT_2x2_5x5[] = …; constexpr float B_2x2_5x5[] = …; constexpr float AT_2x2_5x5[] = …; constexpr float A_2x2_5x5[] = …; // clang-format on /// NOTE(review): the declaration below looks truncated; a type alias for the /// (m, r) key used by F_2_3 / F_4_3 / F_2_5 is presumably intended (e.g. /// `using TransformMapKeyTy = std::pair<int, int>;`) — confirm. TransformMapKeyTy; /// We use F(m, r) to define the size of minimal filtering algorithms. /// m is the output dimension and r is the filter dimension. We can get /// the input dimension, alpha, from the formula, alpha = m + r - 1. /// /// For example, when m = 2 and r = 3, we know its input size is 4. /// The Conv2D will operate on 4x4 input data with 3x3 filter and get /// 2x2 output result. constexpr TransformMapKeyTy F_2_3{ … }; constexpr TransformMapKeyTy F_4_3{ … }; constexpr TransformMapKeyTy F_2_5{ … }; /// Structure to keep information of constant transform matrices. struct TransformMatrix { … }; /// Utility function to convert the constant array held by `transform` into an /// arith.constant Value. NOTE(review): `type` is presumably the element type /// of the generated constant — confirm. Value create2DTransformMatrix(OpBuilder &builder, Location loc, TransformMatrix transform, Type type) { … } /// Extract extractHeight x extractWidth data from the 4D tensor `source`. /// NOTE(review): judging by the parameter names, `loopNorFIndex` / /// `loopCorFIndex` are the induction values of the N-or-F and C-or-F loops /// (placed at dimensions `loopNorFIdx` / `loopCorFIdx`) and `heightOffset` / /// `widthOffset` are the slice offsets along dimensions `heightIdx` / /// `widthIdx` — confirm against the body. Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source, Value loopNorFIndex, Value loopCorFIndex, Value heightOffset, Value widthOffset, int64_t extractHeight, int64_t extractWidth, int64_t loopNorFIdx, int64_t loopCorFIdx, int64_t heightIdx, int64_t widthIdx) { … } /// Extract height x width data from 6D tensors.
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source, Value tileHIndex, Value tileWIndex, Value loopNorFIndex, Value loopCorFIndex, int64_t tileHIdx, int64_t tileWIdx, int64_t loopNorFIdx, int64_t loopCorFIdx, int64_t heightIdx, int64_t widthIdx) { … } /// Insert the transformed height x width data `source` back into the 4D /// tensor `dest` from which it was extracted. Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source, Value dest, Value loopNorFIndex, Value loopCorFIndex, Value heightOffset, Value widthOffset, int64_t height, int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx, int64_t heightIdx, int64_t widthIdx) { … } /// Insert the transformed height x width data `source` back into the 6D /// tensor `dest` from which it was extracted. Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source, Value dest, Value tileHIndex, Value tileWIndex, Value loopNorFIndex, Value loopCorFIndex, int64_t height, int64_t width, int64_t tileHIdx, int64_t tileWIdx, int64_t loopNorFIdx, int64_t loopCorFIdx, int64_t heightIdx, int64_t widthIdx) { … } /// This function transforms the filter. The data layout of the filter is FHWC. /// The transformation matrices are 2-dimensional, so we need to extract H x W /// slices from FHWC first. We need to generate 2 levels of loops to iterate on /// F and C. After the transformation, we get /// /// scf.for %f = lo_f to hi_f step 1 /// scf.for %c = lo_c to hi_c step 1 /// %extracted = extract filter<h x w> from filter<f x h x w x c> /// %ret = linalg.matmul G, %extracted /// %ret = linalg.matmul %ret, GT /// %inserted = insert %ret<alphaH x alphaW> into /// filter<alphaH x alphaW x c x f> /// /// G is alpha x r, so G x g x G^T turns each r x r filter slice into an /// alpha x alpha tile (alpha = m + r - 1). NOTE(review): `leftTransform` / /// `rightTransform` presumably enable the G x ... and ... x GT multiplications /// respectively — confirm against the body. Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, Value retValue, int64_t m, int64_t r, bool leftTransform = true, bool rightTransform = true) { … } /// This function transforms the input. The data layout of the input is NHWC. /// The transformation matrices are 2-dimensional, so we need to extract /// alphaH x alphaW slices from NHWC first. We need to generate 4 levels of /// loops: 2 over the tile indices (tileH, tileW) and 2 over N and C.
/// After the transformation, we get /// /// scf.for %h = 0 to tileH step 1 /// scf.for %w = 0 to tileW step 1 /// scf.for %n = 0 to N step 1 /// scf.for %c = 0 to C step 1 /// %extracted = extract <alphaH x alphaW> from /// %input<N x H x W x C> /// at [%n, (%h x m), (%w x m), %c] /// %ret = linalg.matmul BT, %extracted /// %ret = linalg.matmul %ret, B /// %inserted = insert %ret<alphaH x alphaW> into /// %output<alphaH x alphaW x tileH x tileW x N x C> /// at [0, 0, %h, %w, %n, %c] /// /// Consecutive input tiles start m elements apart but are alpha = m + r - 1 /// elements wide, so neighboring tiles overlap by r - 1 elements. /// NOTE(review): `leftTransform` / `rightTransform` presumably enable the /// BT x ... and ... x B multiplications respectively — confirm against the /// body. Value inputTransform(RewriterBase &rewriter, Location loc, Value input, Value retValue, int64_t m, int64_t r, bool leftTransform = true, bool rightTransform = true) { … } /// This function generates linalg.batch_matmul to multiply the transformed /// input with the transformed filter. linalg.batch_matmul only supports /// 3-dimensional operands, so the leading alphaH x alphaW dimensions are /// flattened into a single batch dimension: /// /// input: [alphaH, alphaW, tileH, tileW, N, C] -> /// [alphaH x alphaW, tileH x tileW x N, C] /// filter: [alphaH, alphaW, C, F] -> [alphaH x alphaW, C, F] /// /// For every (alphaH, alphaW) position, the batched matmul multiplies the /// transformed input with the transformed filter and reduces over the channel /// dimension C. This realizes the element-wise product of the Winograd formula /// together with the summation over input channels. /// /// We get /// /// %collapsed_input = tensor.collapse_shape %input /// %collapsed_filter = tensor.collapse_shape %filter /// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter /// %expanded_ret = tensor.expand_shape %ret /// /// After this function, we get return value with data layout /// (alphaH, alphaW, tileH, tileW, N, F). /// /// NOTE(review): these layouts follow the documented result of /// inputTransform / filterTransform and the documented operand of /// outputTransform in this file — confirm against the collapse/expand /// reassociation indices used in the body. static Value matrixMultiply(RewriterBase &rewriter, Location loc, Value transformedFilter, Value transformedInput, Type outputElementType) { … } /// This function transforms the output. The value to transform has data /// layout (alphaH, alphaW, tileH, tileW, N, F) and the result is written into /// `output`, whose data layout is NHWF. The transformation matrices are /// 2-dimensional, so we need to extract alphaH x alphaW slices first. We need /// to generate 4 levels of loops: 2 over the tile indices (tileH, tileW) and /// 2 over N and F.
/// After the transformation, we get /// /// scf.for %h = 0 to tileH step 1 /// scf.for %w = 0 to tileW step 1 /// scf.for %n = 0 to N step 1 /// scf.for %f = 0 to F step 1 /// %extracted = extract <alphaH x alphaW> from /// %value<alphaH x alphaW x tileH x tileW x N x F> /// at [0, 0, %h, %w, %n, %f] /// %ret = linalg.matmul AT, %extracted /// %ret = linalg.matmul %ret, A /// %inserted = insert %ret<m x m> into /// %output<N x H x W x F> /// at [%n, (%h x m), (%w x m), %f] /// /// A^T is m x alpha, so AT x %extracted x A yields an m x m tile. Tiles are /// placed m elements apart, so output tiles do not overlap. NOTE(review): /// `leftTransform` / `rightTransform` presumably enable the AT x ... and /// ... x A multiplications respectively — confirm against the body. Value outputTransform(RewriterBase &rewriter, Location loc, Value value, Value output, int64_t m, int64_t r, bool leftTransform = true, bool rightTransform = true) { … } /// Create an empty tensor with shape `alignedShape` and insert `value` into /// it, producing a tensor of the aligned size. static Value padToAlignedTensor(RewriterBase &rewriter, Location loc, Value value, ArrayRef<int64_t> alignedShape) { … } /// Extract a sub-tensor of type `extractedType` from `value`. static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc, Value value, RankedTensorType extractedType) { … } /// Utility function to check whether all values in the attribute are 1. static bool hasAllOneValues(DenseIntElementsAttr attr) { … } /// A helper function to convert linalg.conv_2d_nhwc_fhwc to /// linalg.winograd_*_transform ops using F(m x m, r x r). Returns failure if /// the convolution cannot be converted. static FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp, int64_t m, int64_t r) { … } /// A helper function to decompose linalg.winograd_filter_transform. FailureOr<Operation *> decomposeWinogradFilterTransformHelper(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op) { … } /// A helper function to decompose linalg.winograd_input_transform. FailureOr<Operation *> decomposeWinogradInputTransformHelper(RewriterBase &rewriter, linalg::WinogradInputTransformOp op) { … } /// A helper function to decompose linalg.winograd_output_transform.
FailureOr<Operation *> decomposeWinogradOutputTransformHelper(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op) { … } /// A rewrite pattern to decompose linalg.winograd_filter_transform operations. class DecomposeWinogradFilterTransform final : public OpRewritePattern<linalg::WinogradFilterTransformOp> { … }; /// A rewrite pattern to decompose linalg.winograd_input_transform operations. class DecomposeWinogradInputTransform final : public OpRewritePattern<linalg::WinogradInputTransformOp> { … }; /// A rewrite pattern to decompose linalg.winograd_output_transform operations. class DecomposeWinogradOutputTransform final : public OpRewritePattern<linalg::WinogradOutputTransformOp> { … }; /// A rewrite pattern that applies the Winograd Conv2D algorithm to /// linalg.conv_2d_nhwc_fhwc operations. class WinogradConv2DNhwcFhwc final : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> { … }; } // end anonymous namespace //===----------------------------------------------------------------------===// /// Convert the linalg.conv_2d_nhwc_fhwc `op` using the Winograd algorithm /// F(m x m, r x r). NOTE(review): presumably forwards to winogradConv2DHelper /// — confirm. FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r) { … } /// Decompose a linalg.winograd_filter_transform op. NOTE(review): presumably /// forwards to decomposeWinogradFilterTransformHelper — confirm. FailureOr<Operation *> decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op) { … } /// Decompose a linalg.winograd_input_transform op. NOTE(review): presumably /// forwards to decomposeWinogradInputTransformHelper — confirm. FailureOr<Operation *> decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op) { … } /// Decompose a linalg.winograd_output_transform op. NOTE(review): presumably /// forwards to decomposeWinogradOutputTransformHelper — confirm. FailureOr<Operation *> decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op) { … } /// Add the Winograd Conv2D rewrite for F(m x m, r x r) to `patterns`. /// NOTE(review): presumably registers WinogradConv2DNhwcFhwc — confirm. void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, int64_t r) { … } /// Add the patterns that decompose linalg.winograd_*_transform ops to /// `patterns`. NOTE(review): presumably registers the three /// DecomposeWinograd*Transform patterns — confirm. void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) { … } } // end namespace linalg } // end namespace mlir