llvm/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp

//===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#include <utility>

#define DEBUG_TYPE
#define DBGS()

usingnamespacemlir;
usingnamespacemlir::mesh;

#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"

//===----------------------------------------------------------------------===//
// common util functions
//===----------------------------------------------------------------------===//

// Judging by its name, this is the recursive helper behind
// checkOperandAffineExpr. `seenIds` is taken by mutable reference, so it is
// presumably updated as dimension ids are encountered in `expr` (TODO confirm;
// the body is not visible here).
// NOTE(review): the body is empty. Control flows off the end of a function
// with a non-void (LogicalResult) return type, which is undefined behavior if
// this is ever called. Confirm the implementation was not lost.
static LogicalResult
checkOperandAffineExprRecursively(AffineExpr expr,
                                  SmallVectorImpl<bool> &seenIds) {}

// Checks an operand's affine expression `expr` over `numDims` dimensions.
// Returns failure, or a small set of unsigned values. Presumably these are the
// dimension positions referenced by `expr` (TODO confirm against the
// implementation).
// NOTE(review): the body is empty, so no FailureOr is returned. Reaching the
// end of this non-void function is undefined behavior.
static FailureOr<llvm::SmallSet<unsigned, 2>>
checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {}

// Converts a nested vector of `T` into a vector of MeshAxesAttr created in
// `ctxt`. Presumably this is one attribute per inner vector (TODO confirm).
// NOTE(review): the body is empty. Any instantiation that is called falls off
// the end of a non-void function, which is undefined behavior.
template <typename T>
SmallVector<MeshAxesAttr>
fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {}

//===----------------------------------------------------------------------===//
// mesh::getMeshSharding
//===----------------------------------------------------------------------===//

// Retrieves the MeshSharding associated with op result `result`, paired with a
// boolean flag. The meaning of the flag is not visible here (TODO confirm).
// Returns failure when no sharding can be determined (presumably).
// NOTE(review): the body is empty. Reaching the end of this non-void function
// is undefined behavior.
FailureOr<std::pair<bool, MeshSharding>>
mesh::getMeshSharding(OpResult result) {}

// Operand counterpart of getMeshSharding(OpResult). It retrieves the
// MeshSharding associated with `opOperand` plus a boolean flag whose meaning
// is not visible here (TODO confirm).
// NOTE(review): the body is empty. Reaching the end of this non-void function
// is undefined behavior.
FailureOr<std::pair<bool, MeshSharding>>
mesh::getMeshSharding(OpOperand &opOperand) {}

//===----------------------------------------------------------------------===//
// ShardingInterface::verifyShardingInterfaceImpl
//===----------------------------------------------------------------------===//

// Verifies this op's ShardingInterface implementation. The exact checks are
// not visible here (TODO confirm).
// NOTE(review): the body is empty, so no LogicalResult is returned (undefined
// behavior if called).
LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {}

//===----------------------------------------------------------------------===//
// ShardingInterface::printLoopTypesAndIndexingMaps
//===----------------------------------------------------------------------===//

// Debug helper that, judging by its name, prints the op's loop iterator types
// and indexing maps to `os`.
// NOTE(review): as written the body is empty, so nothing is printed.
void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {}

//===----------------------------------------------------------------------===//
// detail::defaultGetShardingOption
//===----------------------------------------------------------------------===//

namespace {

// Update the given `shardingOption` according to `meshAxes` and `loopIdx`.
//   op             - the operation being sharded (presumably used for
//                    diagnostics; TODO confirm).
//   shardingOption - updated in place.
//   mesh           - symbol reference to the mesh the axes belong to.
//   meshAxes       - mesh axes to assign.
//   loopIdx        - index of the loop these axes are assigned to.
// NOTE(review): `static` is redundant inside an anonymous namespace, because
// both already give internal linkage. The body is also empty, so no
// LogicalResult is returned (undefined behavior if called).
static LogicalResult fillShardingOption(Operation *op,
                                        ShardingOption &shardingOption,
                                        FlatSymbolRefAttr mesh,
                                        ArrayRef<MeshAxis> meshAxes,
                                        unsigned loopIdx) {}

} // namespace

// Default ShardingInterface hook. It derives a ShardingOption for `op` from the
// shardings of its operands and results, or returns failure.
// NOTE(review): the body is empty. Reaching the end of this non-void function
// is undefined behavior.
FailureOr<ShardingOption>
mesh::detail::defaultGetShardingOption(Operation *op,
                                       ArrayRef<MeshSharding> operandShardings,
                                       ArrayRef<MeshSharding> resultShardings) {}

// Get the sharding attribute for the given result and sharding option.
// `map` is the result's indexing map. `loopTypes` and `reductionLoopKinds`
// describe the op's loop iterators (presumably the kinds are used for
// reduction loops; TODO confirm).
// NOTE(review): unlike the OpOperand overload below, this function is not
// `static` and so has external linkage. Confirm whether that is intended. Its
// body is also empty, which is undefined behavior for a non-void function if
// called.
MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
                         AffineMap map, ArrayRef<utils::IteratorType> loopTypes,
                         ArrayRef<ReductionKind> reductionLoopKinds) {}

// Get the sharding for the given operand, based on `shardingOption` and the
// operand's indexing `map`. Returns failure if no sharding can be produced.
// NOTE(review): the body is empty. Reaching the end of this non-void function
// is undefined behavior.
static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
                                           const ShardingOption &shardingOption,
                                           AffineMap map) {}

// Default ShardingInterface hook. It computes a list of MeshSharding values
// for `op` under `shardingOption`, presumably one per operand and result (TODO
// confirm ordering), or returns failure.
// NOTE(review): the body is empty. Reaching the end of this non-void function
// is undefined behavior.
FailureOr<std::vector<MeshSharding>>
mesh::detail::defaultGetShardingAnnotations(
    Operation *op, const ShardingOption &shardingOption) {}

//===----------------------------------------------------------------------===//
// detail::defaultAddShardingAnnotations
//===----------------------------------------------------------------------===//

// Adds a `mesh.shard` op for the given result, built with `b`, based on the
// details provided in `shardingOption`, `map` (the result's indexing map),
// `loopTypes`, and `reductionLoopKinds`.
// NOTE(review): the body is empty, so no LogicalResult is returned (undefined
// behavior if called).
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                const ShardingOption &shardingOption,
                                AffineMap map,
                                ArrayRef<utils::IteratorType> loopTypes,
                                ArrayRef<ReductionKind> reductionLoopKinds) {}

// Adds a `mesh.shard` op for the given operand, built with `b`, based on the
// details provided in `shardingOption` and `map` (the operand's indexing map).
// NOTE(review): the body is empty, so no LogicalResult is returned (undefined
// behavior if called).
static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
                                const ShardingOption &shardingOption,
                                AffineMap map) {}

// Default ShardingInterface hook. It materializes `shardingOption` on `op` as
// sharding annotations, built with `b`. Presumably this is done via the
// addShardOp helpers above (TODO confirm).
// NOTE(review): the body is empty, so no LogicalResult is returned (undefined
// behavior if called).
LogicalResult mesh::detail::defaultAddShardingAnnotations(
    Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {}

#ifndef NDEBUG
// Debug-only check: does `sharding` describe full replication of `value`?
//   - A ranked-tensor value must carry a non-null sharding that
//     isFullReplication accepts.
//   - Any other value must carry no sharding at all.
static bool
isValueCompatibleWithFullReplicationSharding(Value value,
                                             MeshSharding sharding) {
  const bool isRankedTensor = isa<RankedTensorType>(value.getType());
  if (!isRankedTensor)
    return !sharding;
  return sharding && isFullReplication(sharding);
}

// Debug-only check applied element-wise. It is true iff `values` and
// `shardings` have the same length and every (value, sharding) pair passes
// isValueCompatibleWithFullReplicationSharding.
// The template parameters are deliberately not named `ValueRange`, to avoid
// shadowing mlir::ValueRange.
template <typename ValuesT, typename ShardingsT>
static bool
areValuesCompatibleWithFullReplicationShardings(ValuesT &&values,
                                                ShardingsT &&shardings) {
  // A length mismatch is simply "incompatible". Checking it up front also
  // satisfies llvm::zip_equal's expectation of equal-length inputs.
  if (std::size(values) != std::size(shardings))
    return false;

  for (auto [value, sharding] :
       llvm::zip_equal(std::forward<ValuesT>(values),
                       std::forward<ShardingsT>(shardings))) {
    if (!isValueCompatibleWithFullReplicationSharding(value, sharding))
      return false;
  }
  return true;
}
#endif // NDEBUG

// SPMD-izes `op` under the assumption that all operands and results are fully
// replicated. It uses `spmdizedOperands`, records the results in
// `spmdizationMap`, and builds new IR with `builder`. Presumably it is meant to
// assert the NDEBUG-only compatibility helpers above (TODO confirm).
// NOTE(review): as written the body is empty, so this is a no-op. Nothing in
// this view calls those NDEBUG-only helpers, so debug builds may warn about
// unused static functions.
void mesh::spmdizeFullyReplicatedOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder) {}

// Updates `meshAxesAssignmentForLoopIterators` in place, using the mesh axes
// assigned to one tensor axis and that axis's `indexingExpr`. Presumably the
// axes are propagated to the loop iterator(s) referenced by the expression
// (TODO confirm).
// NOTE(review): as written the body is empty, so the output argument is never
// modified.
static void updateMeshAxisAssignmentForLoopIterators(
    ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
    SmallVector<std::optional<SmallVector<MeshAxis>>>
        &meshAxesAssignmentForLoopIterators) {}

// Computes, for each loop iterator, the mesh axes it is sharded over. The
// inputs are the operand/result shardings, the loop iterator types, and the
// indexing maps.
// NOTE(review): the body is empty, so no ShardingArray is returned. Reaching
// the end of this non-void function is undefined behavior.
ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<AffineMap> indexingMaps) {}

// Judging by its name, this returns true if some reduction loop iterator has a
// non-empty mesh-axis assignment in `meshAxisAssignmentForLoopIterators`
// (TODO confirm).
// NOTE(review): the body is empty, so no bool is returned (undefined behavior
// if called).
bool mesh::isAtLeastOneReductionIteratorSharded(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {}

// Judging by its name, this collects the mesh axes assigned to reduction loop
// iterators (TODO confirm ordering and deduplication).
// NOTE(review): the body is empty, so no vector is returned (undefined behavior
// if called).
SmallVector<MeshAxis> mesh::getReductionMeshAxes(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {}

// SPMD-izes an op that can be sharded trivially, presumably by cloning it onto
// `spmdizedOperands` and adjusting result types (TODO confirm). Results are
// recorded in `spmdizationMap`, and new IR is built with `builder`.
// NOTE(review): as written the body is empty, so this is a no-op.
void mesh::spmdizeTriviallyShardableOperation(
    Operation &op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    SymbolTableCollection &symbolTable, OpBuilder &builder) {}