#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <type_traits>
namespace mlir::mesh {
// File-local predicate over two axis collections. The two collections are
// separate template parameters, so they may be of different container types.
// Both are taken by const reference and the result is a bool.
// NOTE(review): the body is collapsed in this view. Judging by the name, this
// presumably decides whether the source partial axes can be reconciled with
// the target partial axes -- confirm against the implementation.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
const TargetAxes &targetAxes) { … }
// File-local resharding helper.
//
// Parameters:
//   builder        - OpBuilder passed by reference.
//   sourceSharding - sharding description of the input (by value).
//   targetSharding - sharding description being aimed for (by value).
//   sourceShard    - the shaped value for the source shard.
// Returns a tuple of (shaped value, MeshSharding).
// NOTE(review): the body is collapsed in this view. Presumably the returned
// pair is the shard after partial axes have been dealt with, together with
// the sharding that describes it -- confirm against the implementation.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
handlePartialAxesDuringResharding(OpBuilder &builder,
MeshSharding sourceSharding,
MeshSharding targetSharding,
TypedValue<ShapedType> sourceShard) { … }
// File-local helper that produces a MeshSharding from a source sharding, a
// tensor axis index and a mesh axis. Takes the MLIRContext by pointer.
// NOTE(review): the body is collapsed in this view. Presumably the result is
// sourceSharding with splitMeshAxis added as the last split axis of
// splitTensorAxis -- confirm against the implementation.
static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
MeshSharding sourceSharding,
int64_t splitTensorAxis,
MeshAxis splitMeshAxis) { … }
// File-local resharding step.
//
// Parameters:
//   builder         - ImplicitLocOpBuilder passed by reference.
//   sourceSharding  - sharding description of sourceShard.
//   sourceShard     - the shaped value for the source shard.
//   mesh            - the MeshOp involved.
//   splitTensorAxis - tensor axis index.
//   splitMeshAxis   - mesh axis.
// Returns a tuple of (shaped value, MeshSharding).
// NOTE(review): the body is collapsed in this view. Presumably it splits
// sourceShard further along splitTensorAxis over splitMeshAxis and returns
// the new shard with its sharding -- confirm against the implementation.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshSharding sourceSharding,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) { … }
// File-local detector comparing a source and a target sharding.
// Returns an optional tuple of (int64_t, MeshAxis).
// NOTE(review): the body is collapsed in this view. Presumably the tuple is
// (tensor axis, mesh axis) of a detected "split last axis" pattern, and the
// optional is empty when the pattern does not apply -- confirm.
static std::optional<std::tuple<int64_t, MeshAxis>>
detectSplitLastAxisInResharding(MeshSharding sourceSharding,
MeshSharding targetSharding) { … }
// File-local resharding attempt.
//
// Parameters:
//   builder        - ImplicitLocOpBuilder passed by reference.
//   mesh           - the MeshOp involved.
//   sourceSharding - sharding description of sourceShard.
//   targetSharding - sharding description being aimed for.
//   sourceShard    - the shaped value for the source shard.
// Returns an optional tuple of (shaped value, MeshSharding).
// NOTE(review): the body is collapsed in this view. Presumably the optional
// is empty when the "split last axis" pattern is not detected, and otherwise
// holds the resharded value and its sharding -- confirm.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshSharding sourceSharding,
MeshSharding targetSharding,
TypedValue<ShapedType> sourceShard) { … }
// File-local detector comparing a source and a target sharding.
// Returns an optional tuple of (int64_t, MeshAxis).
// NOTE(review): the body is collapsed in this view. Presumably the tuple is
// (tensor axis, mesh axis) of a detected "unsplit last axis" pattern, and the
// optional is empty when the pattern does not apply -- confirm.
static std::optional<std::tuple<int64_t, MeshAxis>>
detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
MeshSharding targetSharding) { … }
// File-local helper that produces a MeshSharding from a source sharding and
// a tensor axis index. Takes the MLIRContext by pointer.
// NOTE(review): the body is collapsed in this view. Presumably the result is
// sourceSharding with the last split mesh axis of splitTensorAxis removed --
// confirm against the implementation.
static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
MeshSharding sourceSharding,
int64_t splitTensorAxis) { … }
// File-local helper that maps a source ShapedType, a split count and a tensor
// axis index to a ShapedType.
// NOTE(review): the body is collapsed in this view. Presumably this is the
// result type of an all-gather along splitTensorAxis (that dimension scaled
// by splitCount) -- confirm, including how dynamic dimensions are treated.
static ShapedType allGatherResultShapeInUnsplitLastAxis(
ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) { … }
// File-local resharding step.
//
// Parameters:
//   builder              - ImplicitLocOpBuilder passed by reference.
//   sourceSharding       - sharding description of sourceShard.
//   sourceUnshardedShape - ShapedType; by its name, the type of the value
//                          before sharding (TODO confirm).
//   sourceShard          - the shaped value for the source shard.
//   mesh                 - the MeshOp involved.
//   splitTensorAxis      - tensor axis index.
//   splitMeshAxis        - mesh axis.
// Returns a tuple of (shaped value, MeshSharding).
// NOTE(review): the body is collapsed in this view. Presumably it gathers
// sourceShard along splitTensorAxis over splitMeshAxis and returns the new
// shard with its sharding -- confirm against the implementation.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshSharding sourceSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) { … }
// File-local resharding attempt.
//
// Parameters:
//   builder              - ImplicitLocOpBuilder passed by reference.
//   mesh                 - the MeshOp involved.
//   sourceSharding       - sharding description of sourceShard.
//   targetSharding       - sharding description being aimed for.
//   sourceUnshardedShape - ShapedType; by its name, the type of the value
//                          before sharding (TODO confirm).
//   sourceShard          - the shaped value for the source shard.
// Returns an optional tuple of (shaped value, MeshSharding).
// NOTE(review): the body is collapsed in this view. Presumably the optional
// is empty when the "unsplit last axis" pattern is not detected -- confirm.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshSharding sourceSharding,
MeshSharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) { … }
// File-local detector comparing a source and a target sharding.
// Returns an optional tuple of (int64_t, int64_t, MeshAxis).
// NOTE(review): the body is collapsed in this view. Presumably the tuple is
// (source tensor axis, target tensor axis, mesh axis) of a detected "move
// last split axis" pattern, empty when the pattern does not apply -- confirm
// the element order against the implementation.
static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
MeshSharding targetSharding) { … }
// File-local helper that produces a MeshSharding from a source sharding and
// two tensor axis indices. Takes the MLIRContext by pointer.
// NOTE(review): the body is collapsed in this view. Presumably the result is
// sourceSharding with the last split mesh axis of sourceTensorAxis moved to
// the end of targetTensorAxis' split axes -- confirm.
static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
MeshSharding sourceSharding,
int64_t sourceTensorAxis,
int64_t targetTensorAxis) { … }
// File-local helper that maps a source ShapedType, a split count and two
// tensor axis indices to a ShapedType.
// NOTE(review): the body is collapsed in this view. Presumably this is the
// result type of an all-to-all that gathers along sourceTensorAxis and
// shards along targetTensorAxis by splitCount -- confirm, including how
// dynamic dimensions are treated.
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
int64_t splitCount,
int64_t sourceTensorAxis,
int64_t targetTensorAxis) { … }
// File-local resharding step.
//
// Parameters:
//   builder              - ImplicitLocOpBuilder passed by reference.
//   mesh                 - the MeshOp involved.
//   sourceSharding       - sharding description of sourceShard.
//   sourceUnshardedShape - ShapedType; by its name, the type of the value
//                          before sharding (TODO confirm).
//   sourceShard          - the shaped value for the source shard.
//   sourceTensorAxis     - tensor axis index.
//   targetTensorAxis     - tensor axis index.
//   meshAxis             - mesh axis.
// Returns a tuple of (shaped value, MeshSharding).
// NOTE(review): the body is collapsed in this view. Presumably it moves the
// split over meshAxis from sourceTensorAxis to targetTensorAxis and returns
// the new shard with its sharding -- confirm against the implementation.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshSharding sourceSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard,
int64_t sourceTensorAxis,
int64_t targetTensorAxis, MeshAxis meshAxis) { … }
// File-local resharding attempt.
//
// Parameters:
//   builder              - ImplicitLocOpBuilder passed by reference.
//   mesh                 - the MeshOp involved.
//   sourceSharding       - sharding description of sourceShard.
//   targetSharding       - sharding description being aimed for.
//   sourceUnshardedShape - ShapedType; by its name, the type of the value
//                          before sharding (TODO confirm).
//   sourceShard          - the shaped value for the source shard.
// Returns an optional tuple of (shaped value, MeshSharding).
// NOTE(review): the body is collapsed in this view. Presumably the optional
// is empty when the "move last split axis" pattern is not detected --
// confirm against the implementation.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshSharding sourceSharding,
MeshSharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) { … }
// File-local resharding routine returning a single shaped value.
//
// Parameters:
//   builder              - ImplicitLocOpBuilder passed by reference.
//   mesh                 - the MeshOp involved.
//   sourceSharding       - sharding description of sourceShard.
//   targetSharding       - sharding description being aimed for.
//   sourceUnshardedValue - shaped value; by its name, the value before
//                          sharding (TODO confirm).
//   sourceShard          - the shaped value for the source shard.
// NOTE(review): the body is collapsed in this view. The name suggests it is
// only meant for one-dimensional meshes; whether that is asserted or merely
// assumed cannot be seen from here -- confirm.
static TypedValue<ShapedType>
reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshSharding sourceSharding, MeshSharding targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) { … }
// reshard overload taking explicit source/target shardings. Not declared
// static, unlike the helpers above it in this file.
//
// Parameters:
//   builder              - ImplicitLocOpBuilder passed by reference.
//   mesh                 - the MeshOp involved.
//   sourceSharding       - sharding description of sourceShard.
//   targetSharding       - sharding description being aimed for.
//   sourceUnshardedValue - shaped value; by its name, the value before
//                          sharding (TODO confirm).
//   sourceShard          - the shaped value for the source shard.
// Returns a shaped value.
// NOTE(review): the body is collapsed in this view; presumably the result is
// the shard laid out according to targetSharding -- confirm.
TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshSharding sourceSharding,
MeshSharding targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) { … }
// reshard overload driven by a pair of ShardOps with the mesh supplied
// explicitly.
//
// Parameters:
//   builder          - OpBuilder passed by reference.
//   mesh             - the MeshOp involved.
//   source           - ShardOp for the source side.
//   target           - ShardOp for the target side.
//   sourceShardValue - the shaped value for the source shard.
// Returns a shaped value.
// NOTE(review): the body is collapsed in this view; presumably it derives
// the shardings from the two ShardOps and forwards to another reshard
// overload -- confirm.
TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue) { … }
// reshard overload driven by a pair of ShardOps, with a
// SymbolTableCollection supplied instead of a MeshOp.
//
// Parameters:
//   builder               - OpBuilder passed by reference.
//   source                - ShardOp for the source side.
//   target                - ShardOp for the target side.
//   sourceShardValue      - the shaped value for the source shard.
//   symbolTableCollection - passed by non-const reference.
// Returns a shaped value.
// NOTE(review): the body is collapsed in this view; presumably the symbol
// table collection is used to look up the mesh referenced by the ShardOps --
// confirm.
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue,
SymbolTableCollection &symbolTableCollection) { … }
void reshardingRegisterDependentDialects(DialectRegistry ®istry) { … }
#define GEN_PASS_DEF_SPMDIZATION
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
// NOTE(review): as shown, this line is a bare identifier followed by ';'.
// It looks like the tail of a declaration (possibly a type alias) whose
// beginning is not present in this view -- confirm against the full file.
UnshardedToShardedValueMap;
// Returns a SmallVector of Types computed from a Block, with a
// SymbolTableCollection passed by non-const reference.
// NOTE(review): the body is collapsed in this view. Presumably there is one
// entry per block argument, giving that argument's type after sharding --
// confirm, including what is returned for arguments without a sharding
// annotation.
SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
SymbolTableCollection &symbolTableCollection) { … }
// Declaration only: no body is given here, so the definition must be provided
// elsewhere. It has the same parameter list as the file-local
// spmdizeOperation overload that takes explicit operands and shardings, but
// returns void instead of LogicalResult.
// NOTE(review): the semantics cannot be seen from this declaration; by its
// name it presumably handles operations that can be spmdized without
// op-specific logic -- confirm at the definition.
void spmdizeTriviallyShardableOperation(Operation &op,
ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder);
// File-local spmdizeOperation overload taking already-prepared operands and
// shardings.
//
// Parameters:
//   op                    - the operation, by reference.
//   spmdizedOperands      - array of Values.
//   operandShardings      - array of MeshSharding for the operands.
//   resultShardings       - array of MeshSharding for the results.
//   spmdizationMap        - IRMapping passed by non-const reference.
//   symbolTableCollection - passed by non-const reference.
//   builder               - OpBuilder passed by reference.
// Returns a LogicalResult.
// NOTE(review): the body is collapsed in this view; presumably it dispatches
// to the op's sharding interface and reports failure when the op does not
// support it -- confirm.
static LogicalResult spmdizeOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { … }
// File-local helper returning a std::vector of MeshSharding for an operation.
// NOTE(review): the body is collapsed in this view; presumably there is one
// entry per operand of `op` -- confirm, including what is stored for operands
// that carry no sharding.
static std::vector<MeshSharding> getOperandShardings(Operation &op) { … }
// File-local helper returning a std::vector of MeshSharding for an operation.
// NOTE(review): the body is collapsed in this view; presumably there is one
// entry per result of `op` -- confirm, including what is stored for results
// that carry no sharding.
static std::vector<MeshSharding> getResultShardings(Operation &op) { … }
// File-local spmdizeOperation overload for a ShardOp.
//
// Parameters:
//   shardOp               - the ShardOp, by value.
//   spmdizationMap        - IRMapping passed by non-const reference.
//   symbolTableCollection - passed by non-const reference.
//   builder               - OpBuilder passed by reference.
// Returns a LogicalResult.
// NOTE(review): the body is collapsed in this view; presumably it records the
// spmdized counterpart of shardOp's result in spmdizationMap, resharding
// where needed -- confirm.
static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) { … }
// File-local spmdizeOperation overload for a generic Operation.
//
// Parameters:
//   op                    - the operation, by reference.
//   spmdizationMap        - IRMapping passed by non-const reference.
//   symbolTableCollection - passed by non-const reference.
//   builder               - OpBuilder passed by reference.
// Returns a LogicalResult.
// NOTE(review): the body is collapsed in this view; presumably it gathers the
// operand/result shardings and forwards to one of the other spmdizeOperation
// overloads -- confirm.
static LogicalResult
spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) { … }
// File-local routine operating on a Block.
//
// Parameters:
//   block                 - the block, by reference.
//   spmdizationMap        - IRMapping passed by non-const reference.
//   symbolTableCollection - passed by non-const reference.
//   builder               - OpBuilder passed by reference.
// Returns a LogicalResult.
// NOTE(review): the body is collapsed in this view; presumably it spmdizes
// the block's operations in order and stops at the first failure -- confirm.
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) { … }
// File-local routine operating on a function-like op.
//
// Parameters:
//   op                    - FunctionOpInterface handle, by value.
//   spmdizationMap        - IRMapping passed by non-const reference.
//   symbolTableCollection - passed by non-const reference.
// Returns a LogicalResult.
// NOTE(review): the body is collapsed in this view. No OpBuilder is passed
// in, so one is presumably created internally; whether the function's
// signature is also rewritten cannot be seen from here -- confirm.
static LogicalResult
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection) { … }
// Anonymous namespace: the pass class below is local to this translation
// unit.
namespace {
// Pass class that publicly derives from impl::SpmdizationBase, instantiated
// with Spmdization itself as the template argument.
// NOTE(review): the members are collapsed in this view. The base class is
// presumably the one made available by GEN_PASS_DEF_SPMDIZATION and the
// Passes.h.inc include earlier in this file -- confirm.
struct Spmdization : public impl::SpmdizationBase<Spmdization> { … };
}
}