#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>
namespace mlir::linalg {
MeshAxis;
ReductionKind;
MeshSharding;
ShardingArray;
MeshOp;
// Maps a single operation `op` to a Mesh ReductionKind.
// NOTE(review): the body is elided in this view. Presumably `op` is the
// combiner op of a reduction region (e.g. an arith add/mul/min/max) and the
// result names the matching collective reduction. Confirm against the body.
static ReductionKind getReductionKind(Operation *op) { … }
// Looks up the combiner operation inside the Linalg op `op`.
// The std::optional return type means a combiner may not be found.
// NOTE(review): the body is elided in this view, so the exact matching rule
// (what counts as "the combiner") is not visible. Confirm.
static std::optional<Operation *> getCombinerOp(LinalgOp op) { … }
// Returns the ReductionKind that describes the reduction performed by the
// Linalg op `op`.
// NOTE(review): presumably built from getCombinerOp + getReductionKind, with
// a fallback when no combiner exists. The body is elided here. Confirm.
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) { … }
// Returns the MeshOp associated with `op`, given the shardings of its
// operands (`operandShardings`) and results (`resultShardings`).
// `symbolTable` is passed in, presumably to resolve the mesh symbol.
// NOTE(review): the body is elided. Behavior when the shardings reference
// different meshes, or when all are empty, is not visible here. Confirm.
static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
SymbolTableCollection &symbolTable) { … }
// Builds, via `builder`, the init (destination) operand value that the
// spmdized form of `op` should use in place of `spmdizedOperand`, given the
// mesh axes the reduction is sharded over (`reductionMeshAxes`) on `meshOp`.
// NOTE(review): the body is elided. Presumably only one process along the
// reduction axes keeps the original init value and the others receive the
// reduction's neutral element, so a later all-reduce does not double count.
// Confirm.
static Value createDestinationPassingStyleInitOperand(
LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
MeshOp meshOp, ImplicitLocOpBuilder &builder) { … }
// Returns the destination-passing-style init operands for the spmdized `op`.
// `spmdizedOperands` are the already-spmdized operand values.
// `spmdizationMap` is presumably the unsharded -> spmdized value mapping
// (assumed from the name; verify).
// NOTE(review): presumably applies createDestinationPassingStyleInitOperand
// to each DPS init of `op`. The body is elided here. Confirm.
static SmallVector<Value> createDestinationPassingStyleInitOperands(
LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
ImplicitLocOpBuilder &builder) { … }
// Handles one result (`unshardedLinalgOpResult`) of a Linalg op whose
// reduction is sharded over `opReductionMeshAxes`, for the case where
// `resultSharding` is not partial. Creates IR through `builder` and updates
// `spmdizationMap`.
// NOTE(review): inferred from the name and signature only (body elided).
// Presumably emits an all-reduce of kind `reductionKind` over the reduction
// axes not covered by partial sharding of the result. Confirm.
static void createAllReduceForResultWithoutPartialSharding(
Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
MeshSharding resultSharding, ReductionKind reductionKind,
IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) { … }
// Results-wide counterpart of createAllReduceForResultWithoutPartialSharding
// for the op `unshardedOp`, using the per-result `resultShardings`.
// NOTE(review): the body is elided. Presumably it derives the reduction kind
// from `unshardedOp` and processes each result in turn. Confirm.
static void createAllReduceForResultsWithoutPartialShardings(
LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
ImplicitLocOpBuilder &builder) { … }
// Spmdizes the Linalg op `op` for the case where at least one of its
// reduction loops is sharded across the mesh.
// Inputs:
//   - `loopIteratorTypes` and `meshAxisAssignmentForLoopIterators` describe,
//     per loop, its iterator type and the mesh axes it is split over.
//   - `operandShardings` / `resultShardings` give the shardings of `op`.
// Emitted IR goes through `builder`, and new values are recorded in
// `spmdizationMap`.
// NOTE(review): the body is elided in this view. Presumably it combines the
// helpers above (DPS init creation, cloning, then all-reduce of results).
// Confirm the exact sequence.
static void spmdizeLinalgOpWithShardedReduction(
LinalgOp op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshSharding> operandShardings,
ArrayRef<MeshSharding> resultShardings,
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
ImplicitLocOpBuilder &builder) { … }
namespace {
// External model implementing mesh::ShardingInterface for the op type `Op`,
// attached from outside the op definition (the ExternalModel CRTP pattern).
// NOTE(review): the member implementations are elided in this view.
template <typename Op>
struct StructuredOpShardingInterface
: public mesh::ShardingInterface::ExternalModel<
StructuredOpShardingInterface<Op>, Op> { … };
} // namespace
// Registers the sharding external model for the single op type `OpType` in
// context `ctx`.
// NOTE(review): the body is elided. Presumably it attaches
// StructuredOpShardingInterface<OpType> to `OpType`. Confirm.
template <typename OpType>
static void registerOne(MLIRContext *ctx) { … }
// Variadic helper that registers the sharding external model for every op
// type in `OpTypes` in context `ctx`.
// NOTE(review): the body is elided. Presumably it expands to registerOne for
// each type in the pack. Confirm.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) { … }
void registerMeshShardingInterfaceExternalModels(DialectRegistry ®istry) { … }
}