#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <vector>
// Expand the generated definition of the ShardingPropagation pass base
// (selected by GEN_PASS_DEF_SHARDINGPROPAGATION) inside namespace mlir::mesh.
namespace mlir {
namespace mesh {
#define GEN_PASS_DEF_SHARDINGPROPAGATION
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
} // namespace mesh
} // namespace mlir
// Debug-output configuration for this file.
// NOTE(review): macro bodies are elided in this view; presumably DEBUG_TYPE
// names the LLVM debug category consumed by LLVM_DEBUG and DBGS() yields a
// prefixed debug stream — confirm against the full file.
#define DEBUG_TYPE …
#define DBGS() …
usingnamespacemlir;
usingnamespacemlir::mesh;
/// Result type of getReshardingRquirementKind(): presumably classifies how much
/// resharding an op would need under a candidate set of operand/result
/// shardings. NOTE(review): enumerators are elided in this view. The spelling
/// "Rquirement" is kept as-is because other code refers to this name.
enum class ReshardingRquirementKind { … };
// Stream printers used only for debug output.
// NOTE(review): llvm/Support/Debug.h appears to define LLVM_DEBUG in every
// build mode (as a no-op under NDEBUG), so this guard is likely always true;
// confirm whether `#ifndef NDEBUG` was intended.
#ifdef LLVM_DEBUG
// Forward declarations so the printers defined below can refer to one another
// regardless of definition order (presumably e.g. a vector of tuples).
template <typename T>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
const SmallVector<T> &vec);
template <typename... Ts>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
const std::tuple<Ts...> &t);
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
ReshardingRquirementKind v);
// Writes the elements of `range` to `stream` and returns `stream`.
// NOTE(review): body elided; separator/bracket format not visible here.
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) { … }
// Prints a SmallVector (presumably via printRange).
template <typename T>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
const SmallVector<T> &vec) { … }
// Prints a ShardingOption for debugging.
[[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
const ShardingOption &v) { … }
// Prints the elements of `tuple` selected by the index pack `Is...`.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, std::tuple<Ts...> tuple,
std::index_sequence<Is...>) { … }
// Prints a std::tuple (presumably by forwarding to printTuple).
template <typename... Ts>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
const std::tuple<Ts...> &t) { … }
// Prints a ReshardingRquirementKind value.
[[maybe_unused]] static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) { … }
#endif // LLVM_DEBUG
/// Builds the list of candidate sharding combinations from the shardings that
/// must be used (`mustShardings`) and those that may optionally be adopted
/// (`optionalShardings`). Each returned entry is one candidate
/// std::vector<MeshSharding>.
/// NOTE(review): the name suggests the result is ordered by preference; the
/// ordering criterion is not visible in this view — confirm.
static SmallVector<std::vector<MeshSharding>>
getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
ArrayRef<MeshSharding> optionalShardings) { … }
/// Returns the ReshardingRquirementKind for `op` when it is given the candidate
/// shardings in `operandAndResultShardings` (a single combined list; presumably
/// operands first, then results — confirm against callers).
///
/// Declared `static` like the other helpers in this file: its return type is a
/// file-local enum, so no other translation unit can meaningfully call it, and
/// internal linkage avoids -Wmissing-prototypes.
static ReshardingRquirementKind getReshardingRquirementKind(
Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) { … }
/// Chooses a ShardingOption for `shardingOp` from the candidate operand
/// shardings (`possibleOperandShardingAttrs`) and candidate result shardings
/// (`possibleResultShardingAttrs`). Returns a failure when no option is
/// selected. NOTE(review): selection criteria are elided in this view.
static FailureOr<ShardingOption> selectShardingOption(
ShardingInterface shardingOp,
ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) { … }
/// Runs sharding propagation on the single operation `op`.
/// NOTE(review): body elided; `builder` is presumably used to materialize new
/// sharding IR. Confirm which ops are skipped and what constitutes failure.
static LogicalResult visitOp(Operation *op, OpBuilder &builder) { … }
/// The sharding-propagation pass. It derives from the generated
/// mesh::impl::ShardingPropagationBase via CRTP (the base is instantiated with
/// this type). Members are elided in this view.
struct ShardingPropagation
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> { … };