llvm/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h

//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/EndomorphismSimplification.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>

namespace mlir {

class SymbolTableCollection;

namespace mesh {

// Adds to `patterns` a rewrite that moves an all-reduce past an algebraic op.
//
// If we have an algebraic op like "+" and a summing all-reduce,
// `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to
// `all_reduce_sum(x + y)`.
//
// Another example with `min`:
// `min(all_reduce_min(x), all_reduce_min(y))` will be transformed to
// `all_reduce_min(min(x, y))`.
//
// Template parameter:
//   AlgebraicOp - the op type forming the algebraic structure ("+", `min`,
//                 ...). It is the root op of the added pattern.
// Parameters:
//   patterns  - the set the new pattern is appended to.
//   reduction - the reduction kind an all-reduce must have to be matched.
//
// Works only with algebraic ops that have all their operands relevant
// to the all-reduce endomorphism.
// Will not work with some op `f(x, y, z)` where only `x` and `y` form
// the algebraic structure.
template <typename AlgebraicOp>
void populateAllReduceEndomorphismSimplificationPatterns(
    RewritePatternSet &patterns, ReductionKind reduction) {
  // Accessors describing the all-reduce (the endomorphism) to the generic
  // simplification: its single input operand and its single result.
  auto getEndomorphismOpOperand = [](Operation *op) {
    auto allReduceOp = llvm::cast<AllReduceOp>(op);
    return &allReduceOp.getInputMutable();
  };
  auto getEndomorphismOpResult = [](Operation *op) {
    auto allReduceOp = llvm::cast<AllReduceOp>(op);
    return allReduceOp->getResult(0);
  };

  // Accessors describing the algebraic op: every operand takes part in the
  // algebraic structure, and the op is treated as having one result.
  auto getAlgebraicOpOperands = [](Operation *op,
                                   SmallVector<OpOperand *> &operands) {
    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
    std::transform(algebraicOp->getOpOperands().begin(),
                   algebraicOp->getOpOperands().end(),
                   std::back_inserter(operands),
                   [](OpOperand &operand) { return &operand; });
  };
  auto getAlgebraicOpResult = [](Operation *op) {
    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
    return algebraicOp->getResult(0);
  };

  // Decides whether `op` is an all-reduce that may take part in the rewrite.
  // When `referenceOp` is provided, `op` must additionally be compatible with
  // it so that a single all-reduce can replace both.
  auto isEndomorphismOp = [reduction](Operation *op,
                                      std::optional<Operation *> referenceOp) {
    auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
    // Require the requested reduction kind and an all-reduce that keeps the
    // element type, otherwise it is not an endomorphism on the operand type.
    if (!allReduceOp ||
        allReduceOp.getInput().getType().getElementType() !=
            allReduceOp.getType().getElementType() ||
        allReduceOp.getReduction() != reduction) {
      return false;
    }

    // Do not simplify when the all-reduce has users other than the algebraic
    // op: the original all-reduce would have to stay, so nothing is saved.
    if (!allReduceOp->hasOneUse()) {
      return false;
    }

    if (!referenceOp) {
      return true;
    }

    // Both all-reduces must carry identical attributes (mesh, mesh axes,
    // reduction, ...) and reduce the same element type.
    auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
    if (!refAllReduceOp) {
      return false;
    }
    return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
           allReduceOp.getInput().getType().getElementType() ==
               refAllReduceOp.getInput().getType().getElementType();
  };
  auto isAlgebraicOp = [](Operation *op) {
    return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
  };

  // NOTE(review): the template parameter list and the constructor argument
  // order must match EndomorphismSimplification as declared in
  // mlir/Transforms/EndomorphismSimplification.h, which is not visible from
  // this header - confirm when applying.
  using ConcreteEndomorphismSimplification = EndomorphismSimplification<
      std::decay_t<decltype(getEndomorphismOpOperand)>,
      std::decay_t<decltype(getEndomorphismOpResult)>,
      std::decay_t<decltype(getAlgebraicOpOperands)>,
      std::decay_t<decltype(getAlgebraicOpResult)>,
      std::decay_t<decltype(isEndomorphismOp)>,
      std::decay_t<decltype(isAlgebraicOp)>>;
  patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
      std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
      std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
      std::move(isEndomorphismOp), std::move(isAlgebraicOp),
      AlgebraicOp::getOperationName(), /*benefit=*/1, patterns.getContext()));
}

// Declaration only; the definition lives outside this header. Presumably adds
// the mesh simplification patterns to `patterns` - confirm against the
// definition.
//
// `symbolTableCollection` is used by these patterns to cache ops that declare
// symbols. It is therefore invalid to change such ops during the application
// of these patterns.
void populateSimplificationPatterns(
    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
// Declaration only; the definition lives outside this header. Presumably adds
// folding patterns for mesh ops to `patterns` - confirm against the
// definition.
// NOTE(review): takes the same `symbolTableCollection` as
// populateSimplificationPatterns, so the restriction on changing
// symbol-declaring ops while patterns are applied presumably holds here too.
void populateFoldingPatterns(RewritePatternSet &patterns,
                             SymbolTableCollection &symbolTableCollection);

} // namespace mesh
} // namespace mlir

#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H