#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>
// Debug helper macros for this translation unit.
// NOTE(review): presumably consumed by LLVM_DEBUG-style logging — confirm; no
// use of either macro is visible in this view.
#define DEBUG_TYPE …
#define DBGS() …
usingnamespacemlir;
usingnamespacemlir::mesh;
#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
namespace {
/// File-local helper type (anonymous namespace). It is the operand and result
/// type of the `operator/` and `operator*` overloads defined right after it.
/// NOTE(review): presumably wraps a possibly-dynamic dimension extent — confirm.
struct DimensionSize { … };
}
/// Division overload for DimensionSize; both operands are taken by value and
/// the result is a DimensionSize.
/// NOTE(review): handling of dynamic sizes is not visible here — confirm.
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) { … }
/// Multiplication overload for DimensionSize; both operands are taken by
/// value and the result is a DimensionSize.
/// NOTE(review): handling of dynamic sizes is not visible here — confirm.
static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) { … }
/// Definition of MeshDialect::initialize().
/// NOTE(review): presumably registers the ops, attributes and types whose
/// generated definitions are included at the bottom of this file — confirm.
void MeshDialect::initialize() { … }
/// Constant-materialization entry point of MeshDialect. Receives a builder,
/// the attribute `value`, the requested `type` and a location, and returns an
/// Operation pointer.
/// NOTE(review): presumably delegates to an arith constant (Arith.h is
/// included by this file) — confirm.
Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) { … }
/// File-local helper that takes an operation, a flat symbol reference to a
/// mesh and a symbol table collection, and returns either a MeshOp or failure.
/// NOTE(review): presumably looks up `meshSymbol` relative to `op` and emits
/// a diagnostic when it does not resolve to a MeshOp — confirm.
static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTable) { … }
/// Predicate over the iterator range [begin, end); returns a bool.
/// NOTE(review): presumably true when the range holds no duplicate elements;
/// whether the range must be pre-sorted cannot be told from here — confirm.
template <typename It>
bool isUnique(It begin, It end) { … }
/// File-local check of a list of mesh axes against a mesh; returns a
/// LogicalResult. `loc` is the location available for diagnostics.
/// NOTE(review): presumably rejects duplicate or out-of-range axes — confirm.
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
MeshOp mesh) { … }
/// Generic (per op type) file-local helper returning either a MeshOp or
/// failure for the given op.
/// NOTE(review): presumably combines getMeshAndVerify and verifyMeshAxes from
/// this file using the op's mesh symbol and mesh axes — confirm.
template <typename Op>
static FailureOr<MeshOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) { … }
/// Generic file-local helper over an input shape, a mesh shape and split
/// axes. It returns void; `outShape` is the only non-const reference
/// parameter. `shardedDimsSizes` carries a default argument.
/// NOTE(review): presumably writes the per-shard shape into `outShape` —
/// confirm. The end of the parameter list and the body are not visible here.
template <typename InShape, typename MeshShape, typename SplitAxes,
typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
const SplitAxes &splitAxes, OutShape &outShape,
ArrayRef<int64_t> shardedDimsSizes = { … }
/// Maps a ShapedType, a mesh and a sharding to a ShapedType.
/// NOTE(review): presumably the result is the shard-local type of `shape`
/// under `sharding` — confirm.
ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
MeshSharding sharding) { … }
/// Maps a Type, a mesh and a sharding to a Type.
/// NOTE(review): presumably forwards shaped types to shardShapedType; the
/// handling of other types is not visible here — confirm.
Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) { … }
/// OpOperand overload: takes a sharding, the operand to annotate and a
/// builder; returns nothing.
/// NOTE(review): the "maybe" in the name suggests the annotation is skipped
/// under some condition — confirm which.
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder) { … }
/// OpResult overload: takes a sharding, the result to annotate and a
/// builder; returns nothing.
/// NOTE(review): presumably applies the OpOperand overload to the uses of
/// `result` — confirm.
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) { … }
/// Source-side counterpart of maybeInsertTargetShardingAnnotation: takes a
/// sharding, an operand and a builder; returns nothing.
/// NOTE(review): the condition under which nothing is inserted is not
/// visible here — confirm.
void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder) { … }
/// Verifier for MeshOp; returns a LogicalResult.
/// NOTE(review): the invariants checked (e.g. on the mesh shape) are not
/// visible here — confirm and list them.
LogicalResult MeshOp::verify() { … }
/// Symbol-use verifier for MeshShapeOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the referenced mesh through
/// `symbolTable` and validates the requested axes — confirm.
LogicalResult
MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Builder overload taking only the mesh op.
/// NOTE(review): presumably forwards to another overload with an empty axis
/// list — confirm.
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshOp mesh) { … }
/// Builder overload taking the mesh op and an explicit list of mesh axes.
/// NOTE(review): how the result types are derived is not visible — confirm.
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshOp mesh, ArrayRef<MeshAxis> axes) { … }
/// Builder overload taking the mesh as a StringRef plus a list of mesh axes.
/// NOTE(review): `mesh` is presumably the mesh symbol name — confirm.
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef mesh, ArrayRef<MeshAxis> axes) { … }
/// Result-naming entry point for MeshShapeOp. `setNameFn` is a callback
/// taking a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's results —
/// confirm the name used.
void MeshShapeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Builder overload taking the mesh symbol, split axes, partial axes with
/// their reduction kind, and static halo / sharded-dimension sizes.
/// NOTE(review): no dynamic sizes are passed to this overload; how the
/// corresponding operands are populated is not visible — confirm.
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh,
ArrayRef<MeshAxesAttr> split_axes,
ArrayRef<MeshAxis> partial_axes,
mesh::ReductionKind partial_type,
ArrayRef<int64_t> static_halo_sizes,
ArrayRef<int64_t> static_sharded_dims_sizes) { … }
/// Builder overload taking only the mesh symbol and the split axes.
/// NOTE(review): defaults applied to the remaining attributes are not
/// visible here — confirm.
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh,
ArrayRef<MeshAxesAttr> split_axes) { … }
/// Builder overload taking the mesh symbol, the split axes, and halo /
/// sharded-dimension sizes given as OpFoldResult lists.
/// NOTE(review): presumably splits each OpFoldResult list into static and
/// dynamic parts — confirm.
void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) { … }
/// Builder overload taking an existing MeshSharding by value.
/// NOTE(review): presumably copies every component of `from` into the new
/// op — confirm.
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
mlir::mesh::MeshSharding from) { … }
/// Verifier for ShardingOp; returns a LogicalResult.
/// NOTE(review): the invariants checked are not visible here — confirm and
/// list them.
LogicalResult ShardingOp::verify() { … }
/// Result-naming entry point for ShardingOp. `setNameFn` is a callback taking
/// a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void ShardingOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for ShardingOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the referenced mesh through
/// `symbolTable` and checks the sharding against it — confirm.
LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Const comparison of this sharding with `rhs`; returns a bool.
/// NOTE(review): per its name it presumably compares only split and partial
/// axes (whether the mesh is also compared is not visible) — confirm.
bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const { … }
/// Const comparison of this sharding with `rhs`; returns a bool.
/// NOTE(review): per its name it presumably compares halo sizes and sharded
/// dimension sizes, static and dynamic — confirm.
bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const { … }
/// Equality of this sharding against a Value.
/// NOTE(review): presumably compares with the sharding the Value denotes (a
/// MeshSharding can be constructed from a Value in this file) — confirm.
bool MeshSharding::operator==(Value rhs) const { … }
/// Inequality of this sharding against a Value.
/// NOTE(review): presumably the negation of operator==(Value) — confirm.
bool MeshSharding::operator!=(Value rhs) const { … }
/// Equality of two MeshSharding objects.
/// NOTE(review): presumably combines equalSplitAndPartialAxes and
/// equalHaloAndShardSizes (plus the mesh) — confirm.
bool MeshSharding::operator==(const MeshSharding &rhs) const { … }
/// Inequality of two MeshSharding objects.
/// NOTE(review): presumably the negation of operator==(const MeshSharding &)
/// — confirm.
bool MeshSharding::operator!=(const MeshSharding &rhs) const { … }
/// Constructs a MeshSharding from a Value.
/// NOTE(review): presumably expects `rhs` to be produced by a ShardingOp and
/// reads its attributes/operands; behavior for other values is not visible —
/// confirm.
MeshSharding::MeshSharding(Value rhs) { … }
/// Factory-style member returning a MeshSharding from a mesh symbol, split
/// axes, partial axes with their reduction kind, static halo / sharded
/// dimension sizes, and the matching dynamic (Value) sizes.
/// NOTE(review): any normalization of the inputs is not visible — confirm.
MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<MeshAxesAttr> split_axes_,
ArrayRef<MeshAxis> partial_axes_,
ReductionKind partial_type_,
ArrayRef<int64_t> static_halo_sizes_,
ArrayRef<int64_t> static_sharded_dims_sizes_,
ArrayRef<Value> dynamic_halo_sizes_,
ArrayRef<Value> dynamic_sharded_dims_sizes_) { … }
/// Builder for ShardShapeOp taking a static shape, a sharding Value and a
/// device Value.
/// NOTE(review): the number and type of results are derived in the body,
/// which is not visible here — confirm.
void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState,
::llvm::ArrayRef<int64_t> shape,
::mlir::Value sharding, ::mlir::Value device) { … }
/// Result-naming entry point for ShardOp. `setNameFn` is a callback taking a
/// Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void ShardOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for ProcessMultiIndexOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the referenced mesh through
/// `symbolTable` and validates the requested axes — confirm.
LogicalResult
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Builder overload taking only the mesh op.
/// NOTE(review): presumably forwards to the (StringRef, axes) overload —
/// confirm.
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshOp mesh) { … }
/// Builder overload taking the mesh as a StringRef plus a list of mesh axes.
/// NOTE(review): `mesh` is presumably the mesh symbol name — confirm.
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef mesh, ArrayRef<MeshAxis> axes) { … }
/// Result-naming entry point for ProcessMultiIndexOp. `setNameFn` is a
/// callback taking a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's results —
/// confirm the name used.
void ProcessMultiIndexOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for ProcessLinearIndexOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the referenced mesh through
/// `symbolTable` — confirm.
LogicalResult
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Builder for ProcessLinearIndexOp taking the mesh op.
/// NOTE(review): presumably uses the mesh's symbol name for the op's mesh
/// reference — confirm.
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState, MeshOp mesh) { … }
/// Result-naming entry point for ProcessLinearIndexOp. `setNameFn` is a
/// callback taking a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void ProcessLinearIndexOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
namespace {
/// File-local rewrite pattern, generic over the op type it matches (derives
/// from OpRewritePattern<Op>).
/// NOTE(review): per its name it presumably simplifies ops whose mesh-axes
/// list is empty; the exact rewrite is not visible here — confirm.
template <typename Op>
struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> { … };
}
/// File-local check of an in-group device index given as static values
/// (`device`) plus dynamic operands (`deviceDynamic`), against the mesh axes
/// and mesh shape; returns a LogicalResult. `deviceName` is a label for the
/// index being checked.
/// NOTE(review): presumably checks rank and per-axis bounds — confirm.
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
ArrayRef<int64_t> device,
Operation::operand_range deviceDynamic,
ArrayRef<MeshAxis> meshAxes,
ArrayRef<int64_t> meshShape) { … }
/// Iterator-pair overload; the return type is deduced (`auto`).
/// NOTE(review): presumably the product of all elements in [begin, end); the
/// value for an empty range is not visible — confirm.
template <typename It>
static auto product(It begin, It end) { … }
/// Range overload taking a forwarding reference; the return type is deduced.
/// NOTE(review): presumably forwards to the iterator-pair overload using the
/// range's begin/end — confirm.
template <typename R>
static auto product(R &&range) { … }
/// File-local check of an expected dimension size against a result dimension
/// size for a given result axis; returns a LogicalResult.
/// NOTE(review): treatment of dynamic sizes is not visible here — confirm.
static LogicalResult verifyDimensionCompatibility(Location loc,
int64_t expectedDimSize,
int64_t resultDimSize,
int64_t resultAxis) { … }
/// File-local shape check relating an operand and a result for a gather
/// along `gatherAxis`, given the mesh axes and mesh shape; returns a
/// LogicalResult.
/// NOTE(review): presumably expects the result extent on `gatherAxis` to
/// scale with the device-group size — confirm.
static LogicalResult verifyGatherOperandAndResultShape(
Value operand, Value result, int64_t gatherAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { … }
/// File-local shape check relating an operand and a result for an
/// all-to-all with a split axis and a concat axis, given the mesh axes and
/// mesh shape; returns a LogicalResult.
/// NOTE(review): the exact per-axis rule is not visible here — confirm.
static LogicalResult verifyAllToAllOperandAndResultShape(
Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { … }
/// File-local shape check relating an operand and a result for a scatter or
/// slice along `tensorAxis`, given the mesh axes and mesh shape; returns a
/// LogicalResult.
/// NOTE(review): presumably requires the operand extent on `tensorAxis` to
/// be divisible by the device-group size — confirm.
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
Value operand, Value result, int64_t tensorAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { … }
/// File-local helper returning a RankedTensorType from an operand type, a
/// mesh, mesh axes and a slice axis.
/// NOTE(review): presumably assumes `operandType` is a ranked tensor and
/// divides its extent on `sliceAxis` by the group size — confirm.
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
ArrayRef<MeshAxis> meshAxes,
int64_t sliceAxis) { … }
/// Symbol-use verifier for AllGatherOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the mesh and mesh axes, then checks the
/// operand/result shapes for the gather axis — confirm.
LogicalResult
AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Canonicalization entry point for AllGatherOp: receives the pattern set to
/// populate and the MLIR context.
/// NOTE(review): presumably registers EmptyMeshAxesCanonicalizationPattern
/// for this op — confirm.
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
/// Result-naming entry point for AllGatherOp. `setNameFn` is a callback
/// taking a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void AllGatherOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for AllReduceOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the mesh and validates the mesh axes —
/// confirm.
LogicalResult
AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Canonicalization entry point for AllReduceOp: receives the pattern set to
/// populate and the MLIR context.
/// NOTE(review): presumably registers EmptyMeshAxesCanonicalizationPattern
/// for this op — confirm.
void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
/// Builder for AllReduceOp taking the input value, the mesh as a StringRef,
/// the mesh axes and the reduction kind.
/// NOTE(review): the result type is presumably taken from `input` — confirm.
void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
Value input, StringRef mesh,
ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) { … }
/// Result-naming entry point for AllReduceOp. `setNameFn` is a callback
/// taking a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void AllReduceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for AllSliceOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the mesh and mesh axes, then checks the
/// operand/result shapes for the slice axis — confirm.
LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Canonicalization entry point for AllSliceOp: receives the pattern set to
/// populate and the MLIR context.
/// NOTE(review): presumably registers EmptyMeshAxesCanonicalizationPattern
/// for this op — confirm.
void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
/// Builder overload taking the input value, the mesh op, the mesh axes and
/// the slice axis; no result type is passed in.
/// NOTE(review): the result type is presumably computed via sliceResultType
/// from this file — confirm.
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
int64_t sliceAxis) { … }
/// Builder overload taking an explicit result type, the input value, the
/// mesh as a StringRef, the mesh axes and the slice axis.
/// NOTE(review): `mesh` is presumably the mesh symbol name — confirm.
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
Type resultType, Value input, StringRef mesh,
ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) { … }
/// Result-naming entry point for AllSliceOp. `setNameFn` is a callback
/// taking a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void AllSliceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for AllToAllOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the mesh and mesh axes, then checks the
/// operand/result shapes for the split and concat axes — confirm.
LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Canonicalization entry point for AllToAllOp: receives the pattern set to
/// populate and the MLIR context.
/// NOTE(review): presumably registers EmptyMeshAxesCanonicalizationPattern
/// for this op — confirm.
void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
/// Result-naming entry point for AllToAllOp. `setNameFn` is a callback
/// taking a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void AllToAllOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for BroadcastOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the mesh and mesh axes and validates
/// the root device index — confirm.
LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Canonicalization entry point for BroadcastOp: receives the pattern set to
/// populate and the MLIR context.
/// NOTE(review): presumably registers EmptyMeshAxesCanonicalizationPattern
/// for this op — confirm.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
/// Result-naming entry point for BroadcastOp. `setNameFn` is a callback
/// taking a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void BroadcastOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for GatherOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the mesh and mesh axes, validates the
/// root device index and checks the gather shapes — confirm.
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Canonicalization entry point for GatherOp: receives the pattern set to
/// populate and the MLIR context.
/// NOTE(review): presumably registers EmptyMeshAxesCanonicalizationPattern
/// for this op — confirm.
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
/// Result-naming entry point for GatherOp. `setNameFn` is a callback taking
/// a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void GatherOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for RecvOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the mesh and mesh axes and validates
/// the source device index when one is given — confirm.
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Canonicalization entry point for RecvOp: receives the pattern set to
/// populate and the MLIR context.
/// NOTE(review): presumably registers EmptyMeshAxesCanonicalizationPattern
/// for this op — confirm.
void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
/// Result-naming entry point for RecvOp. `setNameFn` is a callback taking a
/// Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for ReduceOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the mesh and mesh axes and validates
/// the root device index — confirm.
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Canonicalization entry point for ReduceOp: receives the pattern set to
/// populate and the MLIR context.
/// NOTE(review): presumably registers EmptyMeshAxesCanonicalizationPattern
/// for this op — confirm.
void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
/// Result-naming entry point for ReduceOp. `setNameFn` is a callback taking
/// a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void ReduceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for ReduceScatterOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the mesh and mesh axes, then checks the
/// operand/result shapes for the scatter axis — confirm.
LogicalResult
ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Canonicalization entry point for ReduceScatterOp: receives the pattern
/// set to populate and the MLIR context.
/// NOTE(review): presumably registers EmptyMeshAxesCanonicalizationPattern
/// for this op — confirm.
void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
/// Result-naming entry point for ReduceScatterOp. `setNameFn` is a callback
/// taking a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void ReduceScatterOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for ScatterOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the mesh and mesh axes, validates the
/// root device index and checks the scatter shapes — confirm.
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Canonicalization entry point for ScatterOp: receives the pattern set to
/// populate and the MLIR context.
/// NOTE(review): presumably registers EmptyMeshAxesCanonicalizationPattern
/// for this op — confirm.
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
/// Result-naming entry point for ScatterOp. `setNameFn` is a callback taking
/// a Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void ScatterOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for SendOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the mesh and mesh axes and validates
/// the destination device index — confirm.
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Canonicalization entry point for SendOp: receives the pattern set to
/// populate and the MLIR context.
/// NOTE(review): presumably registers EmptyMeshAxesCanonicalizationPattern
/// for this op — confirm.
void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
/// Result-naming entry point for SendOp. `setNameFn` is a callback taking a
/// Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for ShiftOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the mesh and mesh axes and checks that
/// the shift axis is among them — confirm.
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
/// Canonicalization entry point for ShiftOp: receives the pattern set to
/// populate and the MLIR context.
/// NOTE(review): which patterns (if any) are registered is not visible here
/// — confirm.
void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
/// Result-naming entry point for ShiftOp. `setNameFn` is a callback taking a
/// Value and a StringRef.
/// NOTE(review): presumably suggests a printed name for the op's result —
/// confirm the name used.
void ShiftOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { … }
/// Symbol-use verifier for UpdateHaloOp; returns a LogicalResult.
/// NOTE(review): presumably resolves the referenced mesh through
/// `symbolTable`; further checks are not visible here — confirm.
LogicalResult
UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) { … }
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"