llvm/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp

//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>

namespace mlir::linalg {

MeshAxis;
ReductionKind;
MeshSharding;
ShardingArray;
MeshOp;

// Returns the corresponding mesh reduction kind for the given arith op.
//
// `op` is the operation to classify; the result is a `ReductionKind`.
// NOTE(review): the body is empty here, so the op-to-kind mapping and the
// fallback for unrecognized ops are not visible. Flowing off the end of a
// non-void function is undefined behavior — confirm the implementation.
static ReductionKind getReductionKind(Operation *op) {}

// Takes a Linalg op and yields an optional `Operation *`.
// NOTE(review): presumably the result is the op that combines partial values
// in the Linalg op's reduction, with `std::nullopt` when none can be
// identified — the body is empty here, so this is unverified; TODO confirm.
// A non-void function with an empty body has undefined behavior when called.
static std::optional<Operation *> getCombinerOp(LinalgOp op) {}

// Takes a Linalg op and yields a `ReductionKind` for it.
// NOTE(review): presumably this classifies the op's combiner (see
// getCombinerOp / getReductionKind in this file) — the body is empty here, so
// the actual derivation and the fallback value are unverified; TODO confirm.
// A non-void function with an empty body has undefined behavior when called.
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {}

// Yields a `MeshOp` given an operation, the shardings of its operands and
// results, and a symbol table collection.
//
// Parameters:
//   op               - the operation the lookup is performed for.
//   operandShardings - shardings of the op's operands.
//   resultShardings  - shardings of the op's results.
//   symbolTable      - symbol table collection passed by mutable reference.
// NOTE(review): presumably the mesh is resolved through `symbolTable` from a
// mesh symbol carried by one of the shardings — the body is empty here, so the
// selection rule is unverified; TODO confirm. A non-void function with an
// empty body has undefined behavior when called.
static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
                      ArrayRef<MeshSharding> resultShardings,
                      SymbolTableCollection &symbolTable) {}

// Choose the operand based on the current process index along the reduction
// mesh axes.
// We need to use the initial value only once to avoid including it in the
// reduction multiple times.
// In each process group only the leading process with linear index 0 would use
// the original operand.
// The other processes would use the reduction operation neutral tensor.
//
// Parameters:
//   op                - the Linalg op the init operand belongs to.
//   spmdizedOperand   - the already spmdized init operand value.
//   reductionMeshAxes - mesh axes along which the reduction is sharded.
//   meshOp            - the mesh the op is sharded over.
//   builder           - builder with an implicit location, passed by reference.
// Returns a `Value` to use as the init operand.
// NOTE(review): the body is empty here, so how the neutral tensor and the
// selection are materialized is not visible; a non-void function with an
// empty body has undefined behavior when called — confirm the implementation.
static Value createDestinationPassingStyleInitOperand(
    LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
    MeshOp meshOp, ImplicitLocOpBuilder &builder) {}

// Create the DPS init operands for the spmdized Linalg op.
// Return all the new spmdized operands.
//
// Parameters:
//   op                - the Linalg op being spmdized.
//   meshOp            - the mesh the op is sharded over.
//   spmdizedOperands  - the already spmdized operand values.
//   reductionMeshAxes - mesh axes along which the reduction is sharded.
//   spmdizationMap    - IR mapping passed by mutable reference.
//   builder           - builder with an implicit location, passed by reference.
// NOTE(review): presumably the non-init operands are forwarded unchanged and
// only the DPS inits are replaced — the body is empty here, so this is
// unverified; TODO confirm. A non-void function with an empty body has
// undefined behavior when called.
static SmallVector<Value> createDestinationPassingStyleInitOperands(
    LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
    ImplicitLocOpBuilder &builder) {}

// Handles a single result of the unsharded Linalg op.
//
// Parameters:
//   unshardedLinalgOpResult - one result value of the original op.
//   opReductionMeshAxes     - mesh axes along which the op's reduction is
//                             sharded.
//   resultSharding          - sharding of that result.
//   reductionKind           - reduction kind to apply.
//   spmdizationMap          - IR mapping passed by mutable reference.
//   builder                 - builder with an implicit location.
// NOTE(review): going by the name, this presumably emits an all-reduce over
// the reduction axes that are not already partial axes of `resultSharding`,
// and records the reduced value in `spmdizationMap` — the body is empty here,
// so none of this is verified; TODO confirm.
static void createAllReduceForResultWithoutPartialSharding(
    Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
    MeshSharding resultSharding, ReductionKind reductionKind,
    IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {}

// Plural counterpart of createAllReduceForResultWithoutPartialSharding: takes
// the whole unsharded op and one sharding per result.
//
// Parameters:
//   unshardedOp         - the original (unsharded) Linalg op.
//   opReductionMeshAxes - mesh axes along which the op's reduction is sharded.
//   resultShardings     - shardings of the op's results.
//   spmdizationMap      - IR mapping passed by mutable reference.
//   builder             - builder with an implicit location.
// NOTE(review): presumably iterates the op's results paired with
// `resultShardings` and delegates to the single-result helper — the body is
// empty here, so this is unverified; TODO confirm.
static void createAllReduceForResultsWithoutPartialShardings(
    LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
    ImplicitLocOpBuilder &builder) {}

// Spmdizes a Linalg op whose reduction is sharded (per the function name).
//
// Parameters:
//   op                                 - the Linalg op to spmdize.
//   spmdizedOperands                   - the already spmdized operand values.
//   operandShardings                   - shardings of the op's operands.
//   resultShardings                    - shardings of the op's results.
//   loopIteratorTypes                  - iterator type of each loop.
//   meshAxisAssignmentForLoopIterators - for each loop iterator, the mesh axes
//                                        assigned to it.
//   spmdizationMap                     - IR mapping passed by mutable
//                                        reference.
//   symbolTable                        - symbol table collection.
//   builder                            - builder with an implicit location.
// NOTE(review): presumably this resolves the mesh, builds the DPS init
// operands, clones the op with the spmdized operands and then all-reduces the
// results, using the other helpers in this file — the body is empty here, so
// the sequence is unverified; TODO confirm.
static void spmdizeLinalgOpWithShardedReduction(
    LinalgOp op, ArrayRef<Value> spmdizedOperands,
    ArrayRef<MeshSharding> operandShardings,
    ArrayRef<MeshSharding> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
    IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
    ImplicitLocOpBuilder &builder) {}

namespace {

// ShardingInterface for ops that implement LinalgStructuredInterface.
// The supported ops are only those where the indexing maps are projected
// permutations.
//
// `Op` is the concrete op type the model is attached to. The struct derives
// from mesh::ShardingInterface::ExternalModel, passing itself as the first
// template argument and `Op` as the second.
// NOTE(review): the struct body is empty here, so it declares no interface
// methods of its own in this view — confirm against the implementation.
template <typename Op>
struct StructuredOpShardingInterface
    : public mesh::ShardingInterface::ExternalModel<
          StructuredOpShardingInterface<Op>, Op> {};

} // namespace

// Per-op-type registration helper, parameterized on a single `OpType` and
// taking the MLIR context to register into.
// NOTE(review): presumably attaches StructuredOpShardingInterface<OpType> to
// `OpType` in `ctx` — the body is empty here, so this is unverified; TODO
// confirm.
template <typename OpType>
static void registerOne(MLIRContext *ctx) {}

/// Variadic helper function.
///
/// Accepts any number of op types as a template parameter pack, plus the MLIR
/// context to register into.
/// NOTE(review): presumably expands the pack into one registerOne<OpType>(ctx)
/// call per type — the body is empty here, so this is unverified; TODO
/// confirm.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {}

// Public entry point of this file (the only function here with external
// linkage); takes the dialect registry by mutable reference.
// NOTE(review): going by the name, this presumably registers the mesh
// sharding interface external models for Linalg ops with `registry` — the
// body is empty here, so what is registered, and when, is unverified; TODO
// confirm.
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {}

} // namespace mlir::linalg