#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
#define MLIR_DIALECT_MESH_IR_MESHOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Support/MathExtras.h"
namespace mlir {
namespace mesh {
// Mesh-dialect names declared ahead of the generated enum/attribute/op headers
// included below (presumably so that generated code can refer to them).
// NOTE(review): as shown here, each of the following lines is only a bare
// identifier followed by `;`, which is not a valid namespace-scope
// declaration. They look like type aliases whose right-hand side
// (`using MeshAxis = ...;`) was elided in this view — TODO confirm against
// the full source before relying on this header compiling as-is.
MeshAxis;
MeshAxesAttr;
ShardShapeAttr;
HaloSizePairAttr;
} // namespace mesh
} // namespace mlir
#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
namespace mlir {
namespace mesh {
/// Sharding description used by the mesh dialect helpers in this header.
/// Its definition is elided in this view.
/// It is declared after the generated attribute classes (MeshAttributes.h.inc)
/// and before the generated type/op classes (MeshTypes.h.inc /
/// MeshOps.h.inc). Presumably this is because it uses the former and the
/// generated ops refer to it — TODO confirm.
/// Within this header it is always taken by value (see `isFullReplication`,
/// `shardShapedType`, `shardType` and the `maybeInsert*ShardingAnnotation`
/// declarations).
class MeshSharding { … };
} // namespace mesh
} // namespace mlir
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
namespace mlir {
namespace mesh {
/// Returns whether `iType` denotes a reduction loop.
/// NOTE(review): body elided in this view; presumably a comparison against
/// `utils::IteratorType::reduction` — confirm.
inline bool isReductionLoop(utils::IteratorType iType) { … }
/// Removes empty inner vectors from the back of `array`, modifying it in
/// place.
/// NOTE(review): body elided; whether a minimum number of entries is always
/// kept is not visible here — confirm before relying on it.
template <typename T>
void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) { … }
/// Returns whether `sharding` describes a fully replicated value.
/// Presumably this means no mesh axis splits any dimension. The body is
/// elided in this view — confirm.
inline bool isFullReplication(MeshSharding sharding) { … }
/// Looks up the `mesh::MeshOp` named by `meshSymbol`, resolved relative to
/// `op`, using `symbolTableCollection`.
/// The collection is taken by non-const reference, presumably so lookups can
/// be cached across calls.
/// As the name suggests, a null `MeshOp` is presumably returned when the
/// symbol cannot be resolved — body elided; confirm.
inline mesh::MeshOp
getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTableCollection) { … }
/// Same lookup as `getMeshOrNull`, but presumably requires the mesh to
/// exist, e.g. by asserting on a failed lookup. Body elided — confirm.
inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTableCollection) { … }
/// Generic overload for a concrete op type `Op`.
/// Presumably it takes the mesh symbol from the op's own mesh attribute and
/// forwards to the `Operation *` overload — body elided; confirm.
template <typename Op>
mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) { … }
/// Explicit specialization of the generic overload above for `ShardOp`,
/// presumably because a `ShardOp` obtains its mesh differently. Body elided —
/// confirm how the mesh symbol is found.
/// `inline` is required: an explicit full specialization of a function
/// template is not implicitly inline, so defining it in a header would
/// otherwise violate the ODR.
template <>
inline mesh::MeshOp
getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) { … }
/// Returns the number of processes in each process group formed by the mesh
/// axes listed in `meshAxes`, where `meshShape` gives the size of every mesh
/// axis. Both parameters are forwarding references, so any range type is
/// accepted.
/// NOTE(review): body elided; presumably the product of the selected axis
/// sizes, with dynamic sizes propagated — confirm.
template <typename MeshAxesRange, typename MeshShapeRange>
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
MeshShapeRange &&meshShape) { … }
/// Convenience overload that takes the mesh shape from `mesh`.
/// Presumably it forwards to the overload above — body elided; confirm.
template <typename MeshAxesRange>
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) { … }
/// Returns the size of one shard when a dimension of size `dimSize` is split
/// across `shardCount` shards.
/// NOTE(review): body elided; handling of dynamic sizes and of non-divisible
/// sizes is not visible here — confirm.
inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) { … }
/// Returns the size of the full dimension reconstructed from `shardCount`
/// shards of size `dimSize` (the counterpart of `shardDimension`,
/// presumably). Body elided; dynamic-size handling not visible — confirm.
inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) { … }
/// Returns the sharded form of `shape` under `sharding` on `mesh`,
/// presumably the per-process local type. Declared here and defined
/// elsewhere.
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
MeshSharding sharding);
/// Type-generic counterpart of `shardShapedType`. Declared here and defined
/// elsewhere. How non-shaped types are treated is not visible here — confirm.
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
/// Uses `builder` to attach a sharding annotation carrying `sharding` to
/// `operand`. The "maybe" prefix suggests the insertion is conditional, e.g.
/// skipped when an equivalent annotation already exists — defined elsewhere;
/// confirm. What distinguishes "target" from "source" annotations is not
/// visible in this header.
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);
/// Overload of the above that annotates the op result `result` instead of an
/// operand. Defined elsewhere.
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
/// "Source" counterpart of `maybeInsertTargetShardingAnnotation` for
/// `operand`. Defined elsewhere; its exact semantics are not visible in this
/// header — confirm.
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);
} // namespace mesh
} // namespace mlir
#endif