llvm/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/IR/MeshOps.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>

#define DEBUG_TYPE
#define DBGS()

usingnamespacemlir;
usingnamespacemlir::mesh;

#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"

namespace {

struct DimensionSize {};

} // namespace

static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {}

static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//

namespace {
struct MeshInlinerInterface : public DialectInlinerInterface {};
} // namespace

//===----------------------------------------------------------------------===//
// Mesh dialect
//===----------------------------------------------------------------------===//

void MeshDialect::initialize() {}

Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {}

//===----------------------------------------------------------------------===//
// Mesh utilities
//===----------------------------------------------------------------------===//

static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
                                          FlatSymbolRefAttr meshSymbol,
                                          SymbolTableCollection &symbolTable) {}

template <typename It>
bool isUnique(It begin, It end) {}

static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
                                    MeshOp mesh) {}

template <typename Op>
static FailureOr<MeshOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {}

template <typename InShape, typename MeshShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                       const SplitAxes &splitAxes, OutShape &outShape,
                       ArrayRef<int64_t> shardedDimsSizes = {}

ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
                                 MeshSharding sharding) {}

Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {}

void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                     OpOperand &operand,
                                                     OpBuilder &builder) {}

void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                     OpResult result,
                                                     OpBuilder &builder) {}

void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                                     OpOperand &operand,
                                                     OpBuilder &builder) {}

//===----------------------------------------------------------------------===//
// mesh.mesh op
//===----------------------------------------------------------------------===//

LogicalResult MeshOp::verify() {}

//===----------------------------------------------------------------------===//
// mesh.mesh_shape op
//===----------------------------------------------------------------------===//

LogicalResult
MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        MeshOp mesh) {}

void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        MeshOp mesh, ArrayRef<MeshAxis> axes) {}

void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        StringRef mesh, ArrayRef<MeshAxis> axes) {}

void MeshShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.sharding
//===----------------------------------------------------------------------===//

void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                       FlatSymbolRefAttr mesh,
                       ArrayRef<MeshAxesAttr> split_axes,
                       ArrayRef<MeshAxis> partial_axes,
                       mesh::ReductionKind partial_type,
                       ArrayRef<int64_t> static_halo_sizes,
                       ArrayRef<int64_t> static_sharded_dims_sizes) {}

void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                       FlatSymbolRefAttr mesh,
                       ArrayRef<MeshAxesAttr> split_axes) {}

void ShardingOp::build(
    ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
    FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
    ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {}

void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                       mlir::mesh::MeshSharding from) {}

LogicalResult ShardingOp::verify() {}

void ShardingOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

//===----------------------------------------------------------------------===//
// MeshSharding
//===----------------------------------------------------------------------===//

bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {}

bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {}

bool MeshSharding::operator==(Value rhs) const {}

bool MeshSharding::operator!=(Value rhs) const {}

bool MeshSharding::operator==(const MeshSharding &rhs) const {}

bool MeshSharding::operator!=(const MeshSharding &rhs) const {}

MeshSharding::MeshSharding(Value rhs) {}

MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
                               ArrayRef<MeshAxesAttr> split_axes_,
                               ArrayRef<MeshAxis> partial_axes_,
                               ReductionKind partial_type_,
                               ArrayRef<int64_t> static_halo_sizes_,
                               ArrayRef<int64_t> static_sharded_dims_sizes_,
                               ArrayRef<Value> dynamic_halo_sizes_,
                               ArrayRef<Value> dynamic_sharded_dims_sizes_) {}

//===----------------------------------------------------------------------===//
// mesh.shard_shape
//===----------------------------------------------------------------------===//

void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
                         ::mlir::OperationState &odsState,
                         ::llvm::ArrayRef<int64_t> shape,
                         ::mlir::Value sharding, ::mlir::Value device) {}

//===----------------------------------------------------------------------===//
// mesh.shard op
//===----------------------------------------------------------------------===//

void ShardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.process_multi_index op
//===----------------------------------------------------------------------===//

LogicalResult
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                MeshOp mesh) {}

void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                                StringRef mesh, ArrayRef<MeshAxis> axes) {}

void ProcessMultiIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.process_linear_index op
//===----------------------------------------------------------------------===//

LogicalResult
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
                                 OperationState &odsState, MeshOp mesh) {}

void ProcessLinearIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//

namespace {

template <typename Op>
struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {};

} // namespace

static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
                                         ArrayRef<int64_t> device,
                                         Operation::operand_range deviceDynamic,
                                         ArrayRef<MeshAxis> meshAxes,
                                         ArrayRef<int64_t> meshShape) {}

template <typename It>
static auto product(It begin, It end) {}

template <typename R>
static auto product(R &&range) {}

static LogicalResult verifyDimensionCompatibility(Location loc,
                                                  int64_t expectedDimSize,
                                                  int64_t resultDimSize,
                                                  int64_t resultAxis) {}

static LogicalResult verifyGatherOperandAndResultShape(
    Value operand, Value result, int64_t gatherAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {}

static LogicalResult verifyAllToAllOperandAndResultShape(
    Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {}

static LogicalResult verifyScatterOrSliceOperandAndResultShape(
    Value operand, Value result, int64_t tensorAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {}

static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
                                        ArrayRef<MeshAxis> meshAxes,
                                        int64_t sliceAxis) {}

//===----------------------------------------------------------------------===//
// mesh.all_gather op
//===----------------------------------------------------------------------===//

LogicalResult
AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {}

void AllGatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.all_reduce op
//===----------------------------------------------------------------------===//

LogicalResult
AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {}

void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        Value input, StringRef mesh,
                        ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {}

void AllReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.all_slice op
//===----------------------------------------------------------------------===//

LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
                       int64_t sliceAxis) {}

void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                       Type resultType, Value input, StringRef mesh,
                       ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {}

void AllSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.all_to_all op
//===----------------------------------------------------------------------===//

LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {}

void AllToAllOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.broadcast op
//===----------------------------------------------------------------------===//

LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {}

void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.gather op
//===----------------------------------------------------------------------===//

LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {}

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.recv op
//===----------------------------------------------------------------------===//

LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {}

void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.reduce op
//===----------------------------------------------------------------------===//

LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.reduce_scatter op
//===----------------------------------------------------------------------===//

LogicalResult
ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {}

void ReduceScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.scatter op
//===----------------------------------------------------------------------===//

LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {}

void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.send op
//===----------------------------------------------------------------------===//

LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {}

void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.shift op
//===----------------------------------------------------------------------===//

LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                          MLIRContext *context) {}

void ShiftOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// mesh.update_halo op
//===----------------------------------------------------------------------===//

LogicalResult
UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"

#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"