#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
// Forward declarations. The declarations visible below refer to these classes
// only through pointers and references, so their full definitions are not
// required in this header.
class Operation;
class IRMapping;
class SymbolTableCollection;
namespace mesh {
// NOTE(review): the next two lines are bare identifiers, not declarations.
// They look like truncated type aliases (presumably
// `using ShardingArray = ...;` and `using ShardingArrayRef = ...;`). The
// aliased types are not visible here — restore them from the original header.
ShardingArray;
ShardingArrayRef;
/// Value type returned (wrapped in FailureOr) by
/// detail::defaultGetShardingOption and taken by const reference by
/// detail::defaultGetShardingAnnotations and
/// detail::defaultAddShardingAnnotations.
// NOTE(review): the member list is elided in this view, so individual fields
// are not documented here.
struct ShardingOption { … };
/// Retrieves the MeshSharding associated with the operation result `result`.
/// On success the returned pair holds a flag and the sharding; failure is
/// reported through FailureOr.
// NOTE(review): the meaning of the bool is not visible in this header;
// presumably it carries the sharding annotation's `annotate_for_users`
// information — confirm against the implementation.
FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpResult result);
/// Retrieves the MeshSharding associated with the operation operand
/// `opOperand`. On success the returned pair holds a flag and the sharding;
/// failure is reported through FailureOr.
// NOTE(review): the meaning of the bool is not visible in this header;
// presumably it carries the sharding annotation's `annotate_for_users`
// information — confirm against the implementation.
FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpOperand &opOperand);
// Helpers whose names mark them as default implementations.
// NOTE(review): presumably these back the default methods of the generated
// sharding interface included at the end of this header — confirm.
namespace detail {
/// Computes a ShardingOption for `op` from the shardings of its operands
/// (`operandShardings`) and of its results (`resultShardings`).
/// Failure is reported through FailureOr.
// NOTE(review): presumably each array is indexed by operand/result position —
// verify against the implementation.
FailureOr<ShardingOption>
defaultGetShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings);
/// Produces a list of MeshSharding values for `op` from `shardingOption`.
/// Failure is reported through FailureOr.
// NOTE(review): the order and length of the returned vector (e.g. operands
// followed by results) is not visible here — confirm before relying on it.
FailureOr<std::vector<MeshSharding>>
defaultGetShardingAnnotations(Operation *op,
const ShardingOption &shardingOption);
/// Applies `shardingOption` to `op` using the builder `b`, returning a
/// LogicalResult that signals success or failure.
// NOTE(review): presumably this inserts sharding annotation ops around `op`'s
// operands and results — confirm against the implementation.
LogicalResult
defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
const ShardingOption &shardingOption);
} // namespace detail
/// SPMD-izes `op` for the fully replicated case (per the function name).
///
/// \param op                 The operation to process.
/// \param spmdizedOperands   Already-processed values for the operands of `op`.
/// \param operandShardings   Shardings of the operands of `op`.
/// \param resultShardings    Shardings of the results of `op`.
/// \param spmdizationMap     IR mapping passed by mutable reference.
/// \param symbolTable        Symbol table collection passed by mutable
///                           reference.
/// \param builder            Builder passed by mutable reference.
// NOTE(review): presumably this assumes every operand and result is fully
// replicated, creates the new operation through `builder`, and records the
// original-to-new value mapping in `spmdizationMap` — none of this is visible
// in the header; confirm against the implementation.
void spmdizeFullyReplicatedOperation(Operation &op,
ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder);
} // namespace mesh
} // namespace mlir
// Generated include, placed after the namespaces are closed.
// NOTE(review): presumably it comes last because the generated interface
// declarations depend on the types declared above — confirm before reordering.
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_