#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
#include <utility>
#define DEBUG_TYPE …
#define DBGS() …
usingnamespacemlir;
usingnamespacemlir::mesh;
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
static LogicalResult
checkOperandAffineExprRecursively(AffineExpr expr,
SmallVectorImpl<bool> &seenIds) { … }
static FailureOr<llvm::SmallSet<unsigned, 2>>
checkOperandAffineExpr(AffineExpr expr, unsigned numDims) { … }
template <typename T>
SmallVector<MeshAxesAttr>
fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) { … }
FailureOr<std::pair<bool, MeshSharding>>
mesh::getMeshSharding(OpResult result) { … }
FailureOr<std::pair<bool, MeshSharding>>
mesh::getMeshSharding(OpOperand &opOperand) { … }
LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { … }
void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) { … }
namespace {
static LogicalResult fillShardingOption(Operation *op,
ShardingOption &shardingOption,
FlatSymbolRefAttr mesh,
ArrayRef<MeshAxis> meshAxes,
unsigned loopIdx) { … }
}
FailureOr<ShardingOption>
mesh::detail::defaultGetShardingOption(Operation *op,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings) { … }
MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
AffineMap map, ArrayRef<utils::IteratorType> loopTypes,
ArrayRef<ReductionKind> reductionLoopKinds) { … }
static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
const ShardingOption &shardingOption,
AffineMap map) { … }
FailureOr<std::vector<MeshSharding>>
mesh::detail::defaultGetShardingAnnotations(
Operation *op, const ShardingOption &shardingOption) { … }
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
const ShardingOption &shardingOption,
AffineMap map,
ArrayRef<utils::IteratorType> loopTypes,
ArrayRef<ReductionKind> reductionLoopKinds) { … }
static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
const ShardingOption &shardingOption,
AffineMap map) { … }
LogicalResult mesh::detail::defaultAddShardingAnnotations(
Operation *op, OpBuilder &b, const ShardingOption &shardingOption) { … }
#ifndef NDEBUG
/// Debug-only predicate: does `sharding` match what a fully replicated
/// operation expects for `value`?
///
/// - A value of `RankedTensorType` must carry a sharding that is set and
///   satisfies `isFullReplication`.
/// - Any other value must carry no sharding (i.e. `sharding` converts to
///   false).
static bool
isValueCompatibleWithFullReplicationSharding(Value value,
                                             MeshSharding sharding) {
  const bool isRankedTensor = isa<RankedTensorType>(value.getType());
  if (!isRankedTensor)
    return !sharding;
  return sharding && isFullReplication(sharding);
}
/// Debug-only predicate over parallel ranges: returns true when `values` and
/// `shardings` have the same length and every (value, sharding) pair
/// satisfies isValueCompatibleWithFullReplicationSharding.
///
/// \param values     range of values (e.g. operands or results).
/// \param shardings  range of shardings, positionally matched to `values`.
/// \returns false on a length mismatch, otherwise the conjunction of the
///          per-pair compatibility checks.
template <typename ValueRange, typename MeshShardingRange>
static bool
areValuesCompatibleWithFullReplicationShardings(ValueRange &&values,
                                                MeshShardingRange &&shardings) {
  // llvm::zip_equal requires ranges of equal length; check explicitly so a
  // mismatch is reported as "not compatible" instead of violating that
  // precondition.
  if (std::size(values) != std::size(shardings)) {
    return false;
  }
  return llvm::all_of(
      llvm::zip_equal(std::forward<ValueRange>(values),
                      std::forward<MeshShardingRange>(shardings)),
      [](auto valueAndSharding) {
        return isValueCompatibleWithFullReplicationSharding(
            std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
      });
}
#endif
void mesh::spmdizeFullyReplicatedOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable, OpBuilder &builder) { … }
static void updateMeshAxisAssignmentForLoopIterators(
ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
SmallVector<std::optional<SmallVector<MeshAxis>>>
&meshAxesAssignmentForLoopIterators) { … }
ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<AffineMap> indexingMaps) { … }
bool mesh::isAtLeastOneReductionIteratorSharded(
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { … }
SmallVector<MeshAxis> mesh::getReductionMeshAxes(
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { … }
void mesh::spmdizeTriviallyShardableOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable, OpBuilder &builder) { … }