llvm/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

//===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
#define MLIR_DIALECT_MESH_IR_MESHOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>
#include <iterator>
#include <utility>

namespace mlir {
namespace mesh {

// Dialect-wide type names used by the generated enum/attribute/op headers
// included below.
// NOTE(review): the four lines below are not valid C++ as written: each is a
// bare identifier followed by `;`, i.e. a declaration with no type. They look
// like type aliases whose right-hand sides were lost (presumably
// `using MeshAxis = <integer type>;` and `using ...Attr = <dense array attr>;`).
// TODO: restore the aliased types from the dialect's ODS definitions.
MeshAxis;
MeshAxesAttr;
ShardShapeAttr;
HaloSizePairAttr;

} // namespace mesh
} // namespace mlir

#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"

namespace mlir {
namespace mesh {

// Sharding descriptor, passed by value to the helpers declared later in this
// header (e.g. `isFullReplication`, `shardShapedType`, `shardType`,
// `maybeInsertTargetShardingAnnotation`).
// NOTE(review): as written the class has no members, so it cannot carry any
// sharding information (mesh, split axes, ...). Its definition presumably
// was elided here. TODO: confirm and restore.
class MeshSharding {};

} // namespace mesh
} // namespace mlir

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"

namespace mlir {
namespace mesh {

// Returns true if `iType` is the reduction iterator type, i.e. the loop it
// describes is a reduction loop.
inline bool isReductionLoop(utils::IteratorType iType) {
  return iType == utils::IteratorType::reduction;
}

// Removes every trailing empty inner vector from `array`. Empty entries that
// are followed by a non-empty one are kept, so positional meaning is
// preserved; e.g. {{0}, {}, {1}, {}, {}} becomes {{0}, {}, {1}}.
template <typename T>
void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
  while (!array.empty() && array.back().empty())
    array.pop_back();
}

// Returns whether `sharding` means the same tensor is replicated on all
// processes.
// NOTE(review): the body is empty, so any call falls off the end of a function
// returning `bool`, which is undefined behavior. The implementation depends on
// `MeshSharding` accessors that are not visible in this header. TODO: restore.
inline bool isFullReplication(MeshSharding sharding) {}

// Resolves `meshSymbol` to a `MeshOp`, searching from the nearest symbol table
// enclosing `op` and caching tables in `symbolTableCollection`.
// Returns a null `MeshOp` if `meshSymbol` is null, does not resolve, or
// resolves to an op that is not a `MeshOp`.
inline mesh::MeshOp
getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
              SymbolTableCollection &symbolTableCollection) {
  // A null reference names nothing; return early instead of handing a null
  // attribute to the symbol lookup.
  if (!meshSymbol)
    return mesh::MeshOp();
  return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
      op, meshSymbol);
}

// Like `getMeshOrNull`, but the caller guarantees that `meshSymbol` resolves
// to a `MeshOp`; this is checked with an assertion in debug builds.
inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
                            SymbolTableCollection &symbolTableCollection) {
  mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
  assert(meshOp && "expected the symbol to resolve to a mesh op");
  return meshOp;
}

// Get the corresponding mesh op using the standard attribute nomenclature:
// `Op` must expose its mesh symbol through a `getMeshAttr()` accessor. The
// symbol must resolve to a `MeshOp` (asserted by the `Operation *` overload).
template <typename Op>
mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
  return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
}

// Specialization of `getMesh` for `ShardOp`, replacing the primary template's
// generic mesh-attribute lookup for this op.
// NOTE(review): the body is empty, so any call falls off the end of a function
// returning `mesh::MeshOp`, which is undefined behavior. How a `ShardOp`
// reaches its mesh is not visible in this header. TODO: restore.
template <>
inline mesh::MeshOp
getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {}

// Get the number of processes that participate in each group
// induced by `meshAxes`.
template <typename MeshAxesRange, typename MeshShapeRange>
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
                                   MeshShapeRange &&meshShape) {}

template <typename MeshAxesRange>
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {}

// Get the size of a sharded dimension.
inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {}

// Get the size of an unsharded dimension.
inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {}

// Return the sharded counterpart of `shape` according to `sharding`, i.e. the
// shape of the tensor held by each device in `mesh`.
// Example:
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
// result in a shape for each shard of ?x2x?.
// (Tensor dim 1 of size 5 is split over mesh axis 1 of size 4, giving 2 per
// shard; dims that are dynamic or split over a dynamic mesh axis stay
// dynamic.)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
                           MeshSharding sharding);

// If `type` is a ranked tensor type, return its sharded counterpart for `mesh`
// and `sharding`.
//
// Otherwise return `type` unchanged; in that case `sharding` must be null.
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);

// Insert a shard op carrying `sharding`, unless one that already has the same
// sharding is present. May insert resharding if required.
// `builder` is presumably used to create the inserted ops; its insertion point
// may be changed by the call (TODO confirm at the definitions).
//
// Target annotation on a use of a value (`operand`).
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                         OpOperand &operand,
                                         OpBuilder &builder);
// Target annotation on an op result (`result`).
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
                                         OpBuilder &builder);
// Source annotation on a use of a value (`operand`).
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                         OpOperand &operand,
                                         OpBuilder &builder);

} // namespace mesh
} // namespace mlir

#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H