#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Value.h"
namespace mlir {
// IR classes that this header only names through references (see the
// parameters of spmdizeTriviallyShardableOperation below), so forward
// declarations are sufficient. OpBuilder is declared here as well, rather than
// depending on one of the includes above happening to declare it.
class Operation;
class IRMapping;
class SymbolTableCollection;
class OpBuilder;
namespace mesh {
/// Returns a ShardingArray that, judging by the name, assigns mesh axes to the
/// loop iterators of an operation.
///
/// \param operandShardings  shardings of the operation's operands.
/// \param resultShardings   shardings of the operation's results.
/// \param loopIteratorTypes iterator type of each loop of the operation.
/// \param indexingMaps      affine indexing maps of the operation.
///
/// NOTE(review): the expected correspondence between `indexingMaps` and the
/// operands/results (e.g. one map per operand followed by one per result), and
/// any structural requirement on those maps, is not visible in this header —
/// confirm against the definition.
ShardingArray getMeshAxisAssignmentForLoopIterators(
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<AffineMap> indexingMaps);
/// Returns true if (per the name) at least one reduction loop iterator has been
/// assigned to one or more mesh axes.
///
/// \param loopIteratorTypes                  iterator type of each loop.
/// \param meshAxisAssignmentForLoopIterators mesh axes assigned to each loop.
///
/// NOTE(review): the two arrays are presumably indexed in parallel by loop
/// iterator, and the second is presumably the result of
/// getMeshAxisAssignmentForLoopIterators — confirm against the definition.
bool isAtLeastOneReductionIteratorSharded(
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
/// Returns the mesh axes that (per the name) are assigned to reduction loop
/// iterators.
///
/// \param loopIteratorTypes                  iterator type of each loop.
/// \param meshAxisAssignmentForLoopIterators mesh axes assigned to each loop.
///
/// NOTE(review): the arrays are presumably indexed in parallel by loop
/// iterator. The order of the returned axes, and whether duplicates can occur,
/// is not visible here — confirm against the definition.
SmallVector<MeshAxis> getReductionMeshAxes(
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
/// Produces the SPMD (per-device) form of `op`, which per the name is expected
/// to be "trivially shardable".
///
/// \param op               the original operation.
/// \param spmdizedOperands operand values to use for the new operation
///                         (presumably the already-spmdized operands of `op`).
/// \param operandShardings shardings of `op`'s operands.
/// \param resultShardings  shardings of `op`'s results.
/// \param spmdizationMap   mapping passed by mutable reference; presumably
///                         updated with original -> spmdized values (TODO
///                         confirm).
/// \param symbolTable      symbol table collection, passed by mutable
///                         reference (presumably for mesh symbol lookup).
/// \param builder          builder presumably used to create the new IR.
///
/// NOTE(review): what "trivially shardable" requires of `op` is not stated in
/// this header — confirm against the definition.
void spmdizeTriviallyShardableOperation(Operation &op,
ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder);
/// External model implementing ShardingInterface for operation type `Op`.
/// It derives from ShardingInterface::ExternalModel, passing itself as the
/// first template argument (CRTP).
///
/// Judging by the name, it targets ops whose iteration domain consists of
/// independent parallel iterators. NOTE(review): the member implementations
/// are not visible here — confirm the exact applicability conditions there.
template <typename Op>
struct IndependentParallelIteratorDomainShardingInterface
: public ShardingInterface::ExternalModel<
IndependentParallelIteratorDomainShardingInterface<Op>, Op> { … };
/// External model implementing ShardingInterface for operation type
/// `ElemwiseOp`. It derives from ShardingInterface::ExternalModel, passing
/// itself as the first template argument (CRTP).
///
/// Judging by the name, it is intended for element-wise operations.
/// NOTE(review): the member implementations are not visible here — confirm
/// which element-wise ops (and operand/result shapes) it supports.
template <typename ElemwiseOp>
struct ElementwiseShardingInterface
: public ShardingInterface::ExternalModel<
ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> { … };
}
}
#endif