//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H #define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/EndomorphismSimplification.h" #include "llvm/Support/Casting.h" #include <algorithm> #include <iterator> #include <memory> #include <utility> namespace mlir { class SymbolTableCollection; namespace mesh { // If we have an algebraic op like "+" and a summing all-reduce, // `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to // `all_reduce_sum(x + y)`. // // Another example with `min`. // `min(all_reduce_min(x), all_reduce_min(y))` will be transformed to // `all_reduce_min(min(x, y))`. // // Works only with algebraic ops that have all their operands relevant // to the all-reduce endomorphism. // Will not work with some op `f(x, y, z)` where only `x` and `y` form // the algebraic structure. template <typename AlgebraicOp> void populateAllReduceEndomorphismSimplificationPatterns( RewritePatternSet &patterns, ReductionKind reduction) { … } // It is invalid to change ops that declare symbols during the application of // these patterns, because symbolTableCollection is used to cache them. void populateSimplificationPatterns( RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection); void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection); } // namespace mesh } // namespace mlir #endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H