llvm/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp

//===- ShardingPropagation.cpp ------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <vector>

namespace mlir {
namespace mesh {
// Select the ShardingPropagation pass definition from the generated pass
// declarations. This provides mesh::impl::ShardingPropagationBase, which the
// ShardingPropagation pass struct at the end of this file derives from.
#define GEN_PASS_DEF_SHARDINGPROPAGATION
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
} // namespace mesh
} // namespace mlir

// Debug category for LLVM_DEBUG output emitted from this file. It must be a
// string literal; an empty definition breaks every LLVM_DEBUG use.
#define DEBUG_TYPE "sharding-propagation"
// Debug stream with a "[<DEBUG_TYPE>]: " prefix, for use inside LLVM_DEBUG.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

// The previous text ("usingnamespacemlir;") was not a valid using-directive
// and did not compile.
using namespace mlir;
using namespace mlir::mesh;

// Classification of how much resharding a candidate combination of operand and
// result shardings would require. It is the return type of
// getReshardingRquirementKind(), whose comment lists the preference order.
// NOTE(review): no enumerators are declared here, although three preference
// levels are described for getReshardingRquirementKind(). Confirm that the
// expected enumerators exist. (The "Rquirement" spelling is part of the
// identifier; renaming it requires updating every use.)
enum class ReshardingRquirementKind {};

// Stream helpers for printing sharding candidates in debug output.
// NOTE(review): LLVM_DEBUG is a function-like macro that llvm/Support/Debug.h
// (included above) normally defines in both assertion-enabled and release
// builds. This guard therefore presumably does not exclude these helpers from
// release builds; the [[maybe_unused]] markers below suggest that is expected.
// Confirm.
// NOTE(review): every definition in this section has an empty body in this
// view. Calling a non-void function that falls off its end is undefined
// behavior, so confirm that the real implementations are present.
#ifdef LLVM_DEBUG

// Forward declarations of the operator<< overloads. Presumably these let the
// templated printers below stream nested types (e.g. a SmallVector of tuples)
// regardless of definition order.
template <typename T>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
                                     const SmallVector<T> &vec);
template <typename... Ts>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
                                     const std::tuple<Ts...> &t);
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
                                     ReshardingRquirementKind v);

// Going by its name and signature, intended to print the elements of `range`
// to `stream` and return `stream` for chaining.
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) {}

// Prints a SmallVector (presumably via printRange).
template <typename T>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
                                     const SmallVector<T> &vec) {}

// Prints a ShardingOption.
[[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
                                                      const ShardingOption &v) {}

// Helper for the tuple operator<< below. The index_sequence parameter supplies
// the element indices Is... so the tuple elements can be expanded in order.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, std::tuple<Ts...> tuple,
                          std::index_sequence<Is...>) {}

// Prints a std::tuple (presumably by forwarding to printTuple).
template <typename... Ts>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
                                     const std::tuple<Ts...> &t) {}

// Prints a ReshardingRquirementKind value.
[[maybe_unused]] static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {}

#endif // LLVM_DEBUG

//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//

// Retrieves all potential sharding combinations for a list of values,
// prioritizing specific (non-None) shardings so that more fully specified
// combinations come first. For example, given mustShardings = [shard0, None]
// and optionalShardings = [None, shard1], the result is
// [[shard0, shard1], [shard0, None]]. An entry from `mustShardings` is kept in
// every combination; an entry present only in `optionalShardings` is tried
// first with that sharding and then without it.
//
// \param mustShardings      per-value shardings kept in every combination.
// \param optionalShardings  per-value shardings that may be used or dropped.
//                           Presumably it has the same length as
//                           `mustShardings` (one entry per value); confirm.
// \returns one std::vector<MeshSharding> per combination, most preferred
//          first.
// NOTE(review): the body is empty in this view. Falling off the end of this
// non-void function is undefined behavior; confirm the implementation.
static SmallVector<std::vector<MeshSharding>>
getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
                                ArrayRef<MeshSharding> optionalShardings) {}

// Classifies how much resharding `op` would need under the candidate
// operand/result shardings. The order of preference is from highest to lowest:
// 1. No resharding is required (all existing annotations are compatible).
// 2. No resharding for operands/results that have an annotation specifically
//    targeting this operation. This means
//    * operands that are the result of `mesh.shard` ops marked with
//      `annotate_for_users`.
//    * results that are annotated with `mesh.shard` ops without
//      `annotate_for_users`.
// 3. All other cases. Resharding is required for operands/results with
//    annotations explicitly targeting this operation.
//
// \param op                         the operation being sharded.
// \param operandAndResultShardings  candidate shardings. Going by the
//                                   parameter name, presumably operands
//                                   first, then results; confirm.
// \returns the ReshardingRquirementKind of this candidate.
// NOTE(review): unlike the neighbouring helpers, this function is not
// `static` (it has external linkage). Its body is empty in this view, and
// falling off the end of a non-void function is undefined behavior. Confirm.
ReshardingRquirementKind getReshardingRquirementKind(
    Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) {}

// From all the operand and result sharding combinations, returns the one that
// is most desirable. The order of preference is:
// 1. No resharding with respect to existing sharding annotations.
// 2. Resharding for values that already have annotations that do not target
//    this op.
// 3. Resharding of existing explicit sharding annotations for this op.
//
// \param shardingOp                    the operation to choose a sharding
//                                      option for.
// \param possibleOperandShardingAttrs  candidate operand sharding
//                                      combinations (presumably as produced
//                                      by getOrderedPossibleShardingAttrs).
// \param possibleResultShardingAttrs   candidate result sharding
//                                      combinations (same presumption).
// \returns the selected ShardingOption, or failure (presumably when no
//          combination yields a usable option; confirm).
// NOTE(review): the body is empty in this view. Falling off the end of this
// non-void function is undefined behavior; confirm the implementation.
static FailureOr<ShardingOption> selectShardingOption(
    ShardingInterface shardingOp,
    ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
    ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {}

// For an operation that implements the ShardingInterface, infers the sharding
// option of the operation from its operands and/or results using the
// `getShardingOption` method. If the inferred sharding option is not empty,
// adds a `mesh.shard` operation for all remaining operands and results that do
// not have sharding annotations.
//
// \param op       the operation to visit.
// \param builder  builder presumably used to create the new `mesh.shard` ops.
// \returns success or failure of the visit (the failure conditions are not
//          visible here; confirm).
// NOTE(review): the body is empty in this view. Falling off the end of this
// non-void function is undefined behavior; confirm the implementation.
static LogicalResult visitOp(Operation *op, OpBuilder &builder) {}

//===----------------------------------------------------------------------===//
// ShardingPropagation
//===----------------------------------------------------------------------===//
// The sharding propagation pass. Its CRTP base class,
// mesh::impl::ShardingPropagationBase, is provided by the
// GEN_PASS_DEF_SHARDINGPROPAGATION include of Passes.h.inc near the top of
// this file.
// NOTE(review): the struct body is empty in this view. A pass built on this
// base is normally expected to override runOnOperation() (presumably applying
// visitOp to the operations it walks); confirm the implementation is present.
struct ShardingPropagation
    : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {};