llvm/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

//===- Spmdization.cpp --------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <type_traits>

namespace mlir::mesh {

// Boolean predicate over two axis collections. It is generic in both
// container types, so the source and target axes may use different
// containers.
// NOTE(review): judging by the name, it checks whether the partial axes of a
// source sharding are compatible with those of a target sharding. The body is
// not shown in this view, so the exact rule cannot be confirmed here.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {}

// Return the reduced value and its corresponding sharding, as a
// (value, sharding) tuple.
// Example:
//   sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
//   targetSharding = <@mesh_1d, [[]]>
// In this case an all-reduce is applied to the source value, and the result
// is returned together with the sharding <@mesh_1d, [[0]]>.
//
// `builder` is the builder handed to the function for creating operations and
// `sourceShard` is the value to be reduced.
// NOTE(review): the body is not shown in this view; the description above is
// taken from the original comment and the signature only.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
handlePartialAxesDuringResharding(OpBuilder &builder,
                                  MeshSharding sourceSharding,
                                  MeshSharding targetSharding,
                                  TypedValue<ShapedType> sourceShard) {}

// Takes a source sharding, a tensor axis and a mesh axis, and returns a
// MeshSharding.
// NOTE(review): judging by the name, the result is the sharding obtained once
// `splitMeshAxis` has been added as the last split axis of `splitTensorAxis`.
// The body is not shown in this view, so confirm against the implementation.
static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
                                                  MeshSharding sourceSharding,
                                                  int64_t splitTensorAxis,
                                                  MeshAxis splitMeshAxis) {}

// Split a replicated tensor along a mesh axis,
// e.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value together with its sharding, as a
// (value, sharding) tuple.
//
// `splitTensorAxis` and `splitMeshAxis` identify the tensor axis to split and
// the mesh axis to split it along.
// NOTE(review): the body is not shown in this view; the description is taken
// from the original comment and the signature only.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                          MeshSharding sourceSharding,
                          TypedValue<ShapedType> sourceShard, MeshOp mesh,
                          int64_t splitTensorAxis, MeshAxis splitMeshAxis) {}

// Detect whether the resharding appends a mesh axis at the end of a tensor
// axis' split axes, e.g.
//   [[0, 1]] -> [[0, 1, 2]].
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
// Does not detect insertions that are not at the end, like
//   [[0, 1]] -> [[0, 2, 1]].
// NOTE(review): presumably returns std::nullopt when the pattern is not
// detected. The body is not shown in this view, so confirm.
static std::optional<std::tuple<int64_t, MeshAxis>>
detectSplitLastAxisInResharding(MeshSharding sourceSharding,
                                MeshSharding targetSharding) {}

// Returns an optional (value, sharding) tuple for a resharding from
// `sourceSharding` to `targetSharding`.
// NOTE(review): judging by the name, this attempts the "split last axis"
// resharding and yields an empty optional when that pattern does not apply.
// The body is not shown in this view, so confirm against the implementation.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                             MeshSharding sourceSharding,
                             MeshSharding targetSharding,
                             TypedValue<ShapedType> sourceShard) {}

// Detect whether the resharding removes the last mesh axis from a tensor
// axis' split axes, e.g.
//   [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
// NOTE(review): presumably returns std::nullopt when the pattern is not
// detected. The body is not shown in this view, so confirm.
static std::optional<std::tuple<int64_t, MeshAxis>>
detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
                                  MeshSharding targetSharding) {}

// Takes a source sharding and a tensor axis, and returns a MeshSharding.
// NOTE(review): judging by the name, the result is the sharding obtained once
// the last split mesh axis of `splitTensorAxis` has been dropped. The body is
// not shown in this view, so confirm against the implementation.
static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                                    MeshSharding sourceSharding,
                                                    int64_t splitTensorAxis) {}

// Takes a source shape, a split count and a tensor axis, and returns a
// ShapedType.
// NOTE(review): judging by the name, this is the result shape of the
// all-gather used when un-splitting `splitTensorAxis`, with `splitCount`
// being the number of pieces gathered. The body is not shown in this view, so
// confirm against the implementation.
static ShapedType allGatherResultShapeInUnsplitLastAxis(
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {}

// Returns a (value, sharding) tuple computed from the source shard, its
// sharding, the unsharded source shape and a (tensor axis, mesh axis) pair.
// NOTE(review): judging by the name, this performs the "unsplit last axis"
// resharding (e.g. [[0, 1, 2]] -> [[0, 1]]) for the given pair. The body is
// not shown in this view, so confirm against the implementation.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                            MeshSharding sourceSharding,
                            ShapedType sourceUnshardedShape,
                            TypedValue<ShapedType> sourceShard, MeshOp mesh,
                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {}

// Returns an optional (value, sharding) tuple for a resharding from
// `sourceSharding` to `targetSharding`.
// NOTE(review): judging by the name, this attempts the "unsplit last axis"
// resharding and yields an empty optional when that pattern does not apply.
// The body is not shown in this view, so confirm against the implementation.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {}

// Detect whether the resharding moves a mesh axis from one tensor axis to
// another, e.g.
//   [[0, 1], [2]] -> [[0], [1, 2]].
// Only moving the last axis counts.
// If detected, returns the corresponding
// (source_tensor_axis, target_tensor_axis, mesh_axis) tuple.
// NOTE(review): presumably returns std::nullopt when the pattern is not
// detected. The body is not shown in this view, so confirm.
static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
                                    MeshSharding targetSharding) {}

// Takes a source sharding plus a source and a target tensor axis, and returns
// a MeshSharding.
// NOTE(review): judging by the name, the result is the sharding obtained once
// the last split mesh axis of `sourceTensorAxis` has been moved to
// `targetTensorAxis`. The body is not shown in this view, so confirm against
// the implementation.
static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
                                                 MeshSharding sourceSharding,
                                                 int64_t sourceTensorAxis,
                                                 int64_t targetTensorAxis) {}

// Takes a source shape, a split count and a source/target tensor axis pair,
// and returns a ShapedType.
// NOTE(review): judging by the name, this is the result shape of the
// all-to-all used when moving a split axis from `sourceTensorAxis` to
// `targetTensorAxis`. The body is not shown in this view, so confirm against
// the implementation.
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
                                                    int64_t splitCount,
                                                    int64_t sourceTensorAxis,
                                                    int64_t targetTensorAxis) {}

// Returns a (value, sharding) tuple computed from the source shard, its
// sharding, the unsharded source shape and a
// (source tensor axis, target tensor axis, mesh axis) triple.
// NOTE(review): judging by the name, this performs the "move last split axis"
// resharding (e.g. [[0, 1], [2]] -> [[0], [1, 2]]) for the given triple. The
// body is not shown in this view, so confirm against the implementation.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                              MeshSharding sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, MeshAxis meshAxis) {}

// Returns an optional (value, sharding) tuple for a resharding from
// `sourceSharding` to `targetSharding`.
// NOTE(review): judging by the name, this attempts the "move last split axis"
// resharding and yields an empty optional when that pattern does not apply.
// The body is not shown in this view, so confirm against the implementation.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                 MeshSharding sourceSharding,
                                 MeshSharding targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {}

// Handles only resharding on a 1D mesh.
// Currently the sharded tensor axes must be exactly divisible by the size of
// the single mesh axis.
//
// Receives both the unsharded source value and the corresponding source
// shard, and returns a single value.
// NOTE(review): presumably the returned value is the shard under
// `targetSharding`. The body is not shown in this view, so confirm.
static TypedValue<ShapedType>
reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
                MeshSharding sourceSharding, MeshSharding targetSharding,
                TypedValue<ShapedType> sourceUnshardedValue,
                TypedValue<ShapedType> sourceShard) {}

// `reshard` overload that takes the mesh and explicit source/target
// shardings, together with the unsharded source value and the source shard.
// Not file-local (no `static`).
// NOTE(review): presumably returns the shard laid out according to
// `targetSharding`. The body is not shown in this view, so confirm against
// the implementation.
TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               TypedValue<ShapedType> sourceUnshardedValue,
                               TypedValue<ShapedType> sourceShard) {}

// `reshard` overload that takes the mesh plus the source and target ShardOp
// annotations instead of explicit shardings. Not file-local (no `static`).
// NOTE(review): presumably reads the shardings from `source` and `target` and
// returns `sourceShardValue` resharded accordingly. The body is not shown in
// this view, so confirm against the implementation.
TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {}

// `reshard` overload that takes no MeshOp; it receives a
// SymbolTableCollection instead. Not file-local (no `static`).
// NOTE(review): presumably the mesh is looked up through
// `symbolTableCollection` before resharding. The body is not shown in this
// view, so confirm against the implementation.
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue,
                               SymbolTableCollection &symbolTableCollection) {}

// Takes the dialect registry by mutable reference and returns nothing.
// NOTE(review): judging by the name, it registers the dialects that
// resharding depends on. The body is not shown in this view, so which
// dialects are involved cannot be confirmed here.
void reshardingRegisterDependentDialects(DialectRegistry &registry) {}

#define GEN_PASS_DEF_SPMDIZATION
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"

// NOTE(review): this line names an identifier but declares nothing. It looks
// like the remainder of an alias declaration
// (`using UnshardedToShardedValueMap = ...;`) whose other tokens are missing
// from this view. As written it does not compile; restore the full
// declaration.
UnshardedToShardedValueMap;

// Get the types of the block arguments for an spmdized block.
// The sharding annotations of the arguments are read to deduce the sharded
// types. Types that are not ranked tensors are left unchanged.
//
// Not file-local (no `static`). Returns the types as a SmallVector<Type>.
// NOTE(review): presumably one entry per argument of `block`, in argument
// order. The body is not shown in this view, so confirm.
SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
                          SymbolTableCollection &symbolTableCollection) {}

// Declaration only; no definition of this function is visible in this view.
// It takes the original operation, its spmdized operands, the operand and
// result shardings, the IR mapping, a symbol table collection and a builder.
// NOTE(review): where the definition lives, and whether this declaration
// duplicates one from an included header, cannot be told from here.
void spmdizeTriviallyShardableOperation(Operation &op,
                                        ArrayRef<Value> spmdizedOperands,
                                        ArrayRef<MeshSharding> operandShardings,
                                        ArrayRef<MeshSharding> resultShardings,
                                        IRMapping &spmdizationMap,
                                        SymbolTableCollection &symbolTable,
                                        OpBuilder &builder);

// `spmdizeOperation` overload that receives the operands and the
// operand/result shardings explicitly, and reports success or failure as a
// LogicalResult.
// NOTE(review): presumably `spmdizedOperands` are the already-spmdized
// counterparts of `op`'s operands, and results are recorded in
// `spmdizationMap`. The body is not shown in this view, so confirm.
static LogicalResult spmdizeOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {}

// Retrieve the sharding annotations for the operands of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
// NOTE(review): what is stored for an operand without an annotation cannot be
// told from here; the body is not shown in this view.
static std::vector<MeshSharding> getOperandShardings(Operation &op) {}

// Retrieve the sharding annotations for the results of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
// NOTE(review): what is stored for a result without an annotation cannot be
// told from here; the body is not shown in this view.
static std::vector<MeshSharding> getResultShardings(Operation &op) {}

// `spmdizeOperation` overload for a ShardOp; reports success or failure as a
// LogicalResult.
// NOTE(review): how a sharding annotation op is translated cannot be told
// from here. The body is not shown in this view, so confirm against the
// implementation.
static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {}

// `spmdizeOperation` overload for a generic Operation, without explicit
// operands or shardings; reports success or failure as a LogicalResult.
// NOTE(review): presumably it gathers the operand/result shardings itself and
// dispatches to the other overloads. The body is not shown in this view, so
// confirm against the implementation.
static LogicalResult
spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {}

// Block-level entry point; reports success or failure as a LogicalResult.
// NOTE(review): judging by the name, it spmdizes the contents of `block`,
// recording the mapping in `spmdizationMap`. The body is not shown in this
// view, so confirm against the implementation.
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
                                  SymbolTableCollection &symbolTableCollection,
                                  OpBuilder &builder) {}

// Function-level entry point; takes any op implementing FunctionOpInterface
// and reports success or failure as a LogicalResult. Unlike the block- and
// operation-level functions it takes no OpBuilder.
// NOTE(review): judging by the name, it spmdizes the body of `op`. The body
// of this function is not shown in this view, so confirm against the
// implementation.
static LogicalResult
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
              SymbolTableCollection &symbolTableCollection) {}

namespace {

// Pass class for spmdization. It derives from impl::SpmdizationBase and
// passes itself as the template argument (CRTP).
// NOTE(review): the base is presumably supplied by the Passes.h.inc include
// guarded by GEN_PASS_DEF_SPMDIZATION earlier in this file. The struct has no
// members in this view, so the pass logic is not shown here.
struct Spmdization : public impl::SpmdizationBase<Spmdization> {};

} // namespace

} // namespace mlir::mesh