//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/IR/Dominance.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" usingnamespacemlir; // Determine whether the value is defined to be zero. static bool isDefinedAsZero(Value val) { … } /// Replace a linalg.add with one operand the single user of a contraction, /// which has a zero-filled, "identity-mapped" destination and is dominated by /// the `other` operand, by the contraction with `other` as its dest. /// /// As an example, the following pseudo-code will be rewritten /// %cst = arith.constant 0.000000e+00 /// %empty = tensor.empty() /// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type /// %C = linalg.matmul ins(%A, %B) outs(%zeroed) /// %empty2 = tensor.empty() /// %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type /// %F = linalg.matmul ins(%D, %E) outs(%zeroed2) /// %out = linalg.add ins(%C, %F) outs(%empty) /// to: /// %cst = arith.constant 0.000000e+00 /// %empty = tensor.empty() /// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type /// %C = linalg.matmul ins(%A, %B) outs(%zeroed) /// %out = linalg.matmul ins(%D, %E) outs(%C) /// struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> { … }; void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) { … }