//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include <optional> usingnamespacemlir; usingnamespacemlir::linalg; namespace { /// Pattern to decompose a GenericOp that has more than two statements /// into one GenericOp with the first statement (i.e. peeled operation), and /// a second GenericOp with the remaining statements (i.e. residual operations). /// - The result of the first GenericOp has the same shape as the iteration /// space of the GenericOp. The body of the op yields as many values as the /// original op plus all the results of the peeled operation. /// - The second GenericOp has as many operands as the original operation plus /// all the results of the first GenericOp. It has the same number of yields as /// the original op. /// - If the result of the peeled operation was yielded by the original /// GenericOp, the uses of the corresponding results will be replaced with the /// result of the first GenericOp created. /// /// Example /// /// ```mlir /// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...) /// outs(%init0, %init1 : ...) { /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...): /// %0 = <s0> %b0, %b1 : ... /// %1 = <s1> %0, %b2 : ... /// linalg.yield %0, %1 : ... /// } -> (..., ...) /// return %result#0, %result#1 /// ``` /// /// gets split into /// /// ```mlir /// %init = tensor.empty ... /// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...) /// outs(%init0, %init1, %init : ...) { /// ^bb0(%b0: ... , %b1: ... , %b2: ... 
, %b3: ..., %b4: ..., %b5: ...): /// %0 = <s0> %b0, %b1 : ... /// linalg.yield %0, %..., %0 : ... /// } -> (..., ..., ...) /// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...) /// outs(%init0, %init1 : ...) { /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...): /// %1 = <s1> %b3, %b2 : ... /// linalg.yield %..., %1 : ... /// } -> (..., ...) /// return %op0#0, %op1#1 /// ``` /// /// After canonicalization this is expected to be /// /// ```mlir /// %init = tensor.empty ... /// %op0 = linalg.generic ... ins(%arg0, %arg1 : ...) /// outs(%init : ...) { /// ^bb0(%b0: ... , %b1: ... , %b2: ...): /// %0 = <s0> %b0, %b1 : ... /// linalg.yield %0 : ... /// } -> ... /// %op1 = linalg.generic ... ins(%arg2, %op0 : ...) /// outs(%init1 : ...) { /// ^bb0(%b0: ... , %b1: ... , %b2: ...): /// %1 = <s1> %b1, %b0 : ... /// linalg.yield %1 : ... /// } -> ... /// return %op0, %op1 /// ``` struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> { … }; } // namespace /// Helper function to compute the loop range of a generic op. static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b, GenericOp op) { … } /// Helper function to permute the list of `values` based on the `map`. SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values, AffineMap map) { … } /// Get zero value for an element type. static Value getZero(OpBuilder &b, Location loc, Type elementType) { … } GenericOp DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, PatternRewriter &rewriter) const { … } GenericOp DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp, GenericOp peeledGenericOp, PatternRewriter &rewriter) const { … } LogicalResult DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const { … } void mlir::linalg::populateDecomposeLinalgOpsPattern( RewritePatternSet &patterns, bool removeDeadArgsAndResults) { … }