// Umbrella switch: defining GEN_PASS_DECL before including this file turns on
// the declaration sections of both passes in this file (ShardingPropagation
// and Spmdization). The umbrella macro is then undefined so it only affects
// this inclusion.
#ifdef GEN_PASS_DECL
#define GEN_PASS_DECL_SHARDINGPROPAGATION
#define GEN_PASS_DECL_SPMDIZATION
#undef GEN_PASS_DECL
#endif
// Declares the public factory for the "sharding-propagation" pass. The
// definition is emitted by the GEN_PASS_DEF_SHARDINGPROPAGATION section.
#ifdef GEN_PASS_DECL_SHARDINGPROPAGATION
std::unique_ptr<::mlir::Pass> createShardingPropagation();
#undef GEN_PASS_DECL_SHARDINGPROPAGATION
#endif
// Definition section for the "sharding-propagation" pass.
//
// Define GEN_PASS_DEF_SHARDINGPROPAGATION in exactly one translation unit
// before including this file, because createShardingPropagation() below is a
// non-inline definition.
//
// NOTE(review): this file appears to be mlir-tblgen output. The parameter
// `&registry` had been mangled to `®istry` (an `&reg` entity decode), which
// does not compile. If so, regenerate rather than relying on this hand fix.
#ifdef GEN_PASS_DEF_SHARDINGPROPAGATION
namespace impl {
// Factory declared here; defined by the hidden friend inside
// ShardingPropagationBase<DerivedT> once that template is instantiated.
std::unique_ptr<::mlir::Pass> createShardingPropagation();
}
namespace impl {
/// CRTP base for the ShardingPropagation pass. It runs on operations
/// implementing mlir::FunctionOpInterface. DerivedT is the concrete pass
/// class and supplies the TypeID and the clone/create targets.
template <typename DerivedT>
class ShardingPropagationBase : public ::mlir::InterfacePass<mlir::FunctionOpInterface> {
public:
using Base = ShardingPropagationBase;
// The pass is identified by the TypeID of the derived class, not of this base.
ShardingPropagationBase() : ::mlir::InterfacePass<mlir::FunctionOpInterface>(::mlir::TypeID::get<DerivedT>()) {}
ShardingPropagationBase(const ShardingPropagationBase &other) : ::mlir::InterfacePass<mlir::FunctionOpInterface>(other) {}
ShardingPropagationBase& operator=(const ShardingPropagationBase &) = delete;
ShardingPropagationBase(ShardingPropagationBase &&) = delete;
ShardingPropagationBase& operator=(ShardingPropagationBase &&) = delete;
~ShardingPropagationBase() = default;
/// Command-line argument used to invoke this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("sharding-propagation");
}
::llvm::StringRef getArgument() const override { return "sharding-propagation"; }
::llvm::StringRef getDescription() const override { return "sharding propagation"; }
/// Human-readable pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("ShardingPropagation");
}
::llvm::StringRef getName() const override { return "ShardingPropagation"; }
/// LLVM-style RTTI support: true when `pass` is an instance of DerivedT.
static bool classof(const ::mlir::Pass *pass) {
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}
/// Copy-constructs a new DerivedT from this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
/// Registers the mesh dialect as a dependency of this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
registry.insert<mesh::MeshDialect>();
}
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShardingPropagationBase<DerivedT>)
protected:
private:
// Hidden-friend definition of impl::createShardingPropagation declared above.
friend std::unique_ptr<::mlir::Pass> createShardingPropagation() {
return std::make_unique<DerivedT>();
}
};
}
// Public factory that forwards to the impl:: factory.
std::unique_ptr<::mlir::Pass> createShardingPropagation() {
return impl::createShardingPropagation();
}
#undef GEN_PASS_DEF_SHARDINGPROPAGATION
#endif
// Declares the public factory for the "mesh-spmdization" pass. The definition
// is emitted by the GEN_PASS_DEF_SPMDIZATION section.
#ifdef GEN_PASS_DECL_SPMDIZATION
std::unique_ptr<::mlir::Pass> createSpmdization();
#undef GEN_PASS_DECL_SPMDIZATION
#endif
// Definition section for the "mesh-spmdization" pass.
//
// Define GEN_PASS_DEF_SPMDIZATION in exactly one translation unit before
// including this file, because createSpmdization() below is a non-inline
// definition.
//
// NOTE(review): this file appears to be mlir-tblgen output. The parameter
// `&registry` had been mangled to `®istry` (an `&reg` entity decode), which
// does not compile. If so, regenerate rather than relying on this hand fix.
#ifdef GEN_PASS_DEF_SPMDIZATION
namespace impl {
// Factory declared here; defined by the hidden friend inside
// SpmdizationBase<DerivedT> once that template is instantiated.
std::unique_ptr<::mlir::Pass> createSpmdization();
}
namespace impl {
/// CRTP base for the Spmdization pass ("Partition a function into SPMD
/// form."). It runs on operations implementing mlir::FunctionOpInterface.
/// DerivedT is the concrete pass class.
template <typename DerivedT>
class SpmdizationBase : public ::mlir::InterfacePass<mlir::FunctionOpInterface> {
public:
using Base = SpmdizationBase;
// The pass is identified by the TypeID of the derived class, not of this base.
SpmdizationBase() : ::mlir::InterfacePass<mlir::FunctionOpInterface>(::mlir::TypeID::get<DerivedT>()) {}
SpmdizationBase(const SpmdizationBase &other) : ::mlir::InterfacePass<mlir::FunctionOpInterface>(other) {}
SpmdizationBase& operator=(const SpmdizationBase &) = delete;
SpmdizationBase(SpmdizationBase &&) = delete;
SpmdizationBase& operator=(SpmdizationBase &&) = delete;
~SpmdizationBase() = default;
/// Command-line argument used to invoke this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("mesh-spmdization");
}
::llvm::StringRef getArgument() const override { return "mesh-spmdization"; }
::llvm::StringRef getDescription() const override { return "Partition a function into SPMD form."; }
/// Human-readable pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("Spmdization");
}
::llvm::StringRef getName() const override { return "Spmdization"; }
/// LLVM-style RTTI support: true when `pass` is an instance of DerivedT.
static bool classof(const ::mlir::Pass *pass) {
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}
/// Copy-constructs a new DerivedT from this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
/// Registers the mesh dialect as a dependency of this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
registry.insert<mesh::MeshDialect>();
}
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpmdizationBase<DerivedT>)
protected:
private:
// Hidden-friend definition of impl::createSpmdization declared above.
friend std::unique_ptr<::mlir::Pass> createSpmdization() {
return std::make_unique<DerivedT>();
}
};
}
// Public factory that forwards to the impl:: factory.
std::unique_ptr<::mlir::Pass> createSpmdization() {
return impl::createSpmdization();
}
#undef GEN_PASS_DEF_SPMDIZATION
#endif
// Registration section: defines inline helpers that register the mesh passes
// with the global MLIR pass registry via ::mlir::registerPass.
#ifdef GEN_PASS_REGISTRATION
/// Registers the "sharding-propagation" pass.
inline void registerShardingPropagation() {
  auto makeShardingPropagation = []() -> std::unique_ptr<::mlir::Pass> {
    return createShardingPropagation();
  };
  ::mlir::registerPass(makeShardingPropagation);
}
/// Alternate spelling; performs the same registration as
/// registerShardingPropagation().
inline void registerShardingPropagationPass() {
  registerShardingPropagation();
}
/// Registers the "mesh-spmdization" pass.
inline void registerSpmdization() {
  auto makeSpmdization = []() -> std::unique_ptr<::mlir::Pass> {
    return createSpmdization();
  };
  ::mlir::registerPass(makeSpmdization);
}
/// Alternate spelling; performs the same registration as
/// registerSpmdization().
inline void registerSpmdizationPass() {
  registerSpmdization();
}
/// Registers every pass in the mesh pass group.
inline void registerMeshPasses() {
  registerShardingPropagation();
  registerSpmdization();
}
#undef GEN_PASS_REGISTRATION
#endif
// Legacy section: defines the pass base classes directly in the including
// namespace. It has no impl:: namespace and no friend create*() factory, so
// the user must provide the create*() functions.
//
// NOTE(review): `&registry` had been mangled to `®istry` (an `&reg` entity
// decode) in both classes below, which does not compile. If this file is
// generated, regenerate it rather than relying on this hand fix.
#ifdef GEN_PASS_CLASSES
/// CRTP base for the ShardingPropagation pass on FunctionOpInterface ops.
template <typename DerivedT>
class ShardingPropagationBase : public ::mlir::InterfacePass<mlir::FunctionOpInterface> {
public:
using Base = ShardingPropagationBase;
// The pass is identified by the TypeID of the derived class, not of this base.
ShardingPropagationBase() : ::mlir::InterfacePass<mlir::FunctionOpInterface>(::mlir::TypeID::get<DerivedT>()) {}
ShardingPropagationBase(const ShardingPropagationBase &other) : ::mlir::InterfacePass<mlir::FunctionOpInterface>(other) {}
ShardingPropagationBase& operator=(const ShardingPropagationBase &) = delete;
ShardingPropagationBase(ShardingPropagationBase &&) = delete;
ShardingPropagationBase& operator=(ShardingPropagationBase &&) = delete;
~ShardingPropagationBase() = default;
/// Command-line argument used to invoke this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("sharding-propagation");
}
::llvm::StringRef getArgument() const override { return "sharding-propagation"; }
::llvm::StringRef getDescription() const override { return "sharding propagation"; }
/// Human-readable pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("ShardingPropagation");
}
::llvm::StringRef getName() const override { return "ShardingPropagation"; }
/// LLVM-style RTTI support: true when `pass` is an instance of DerivedT.
static bool classof(const ::mlir::Pass *pass) {
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}
/// Copy-constructs a new DerivedT from this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
/// Registers the mesh dialect as a dependency of this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
registry.insert<mesh::MeshDialect>();
}
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShardingPropagationBase<DerivedT>)
protected:
};
/// CRTP base for the Spmdization pass on FunctionOpInterface ops.
template <typename DerivedT>
class SpmdizationBase : public ::mlir::InterfacePass<mlir::FunctionOpInterface> {
public:
using Base = SpmdizationBase;
// The pass is identified by the TypeID of the derived class, not of this base.
SpmdizationBase() : ::mlir::InterfacePass<mlir::FunctionOpInterface>(::mlir::TypeID::get<DerivedT>()) {}
SpmdizationBase(const SpmdizationBase &other) : ::mlir::InterfacePass<mlir::FunctionOpInterface>(other) {}
SpmdizationBase& operator=(const SpmdizationBase &) = delete;
SpmdizationBase(SpmdizationBase &&) = delete;
SpmdizationBase& operator=(SpmdizationBase &&) = delete;
~SpmdizationBase() = default;
/// Command-line argument used to invoke this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("mesh-spmdization");
}
::llvm::StringRef getArgument() const override { return "mesh-spmdization"; }
::llvm::StringRef getDescription() const override { return "Partition a function into SPMD form."; }
/// Human-readable pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("Spmdization");
}
::llvm::StringRef getName() const override { return "Spmdization"; }
/// LLVM-style RTTI support: true when `pass` is an instance of DerivedT.
static bool classof(const ::mlir::Pass *pass) {
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}
/// Copy-constructs a new DerivedT from this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
/// Registers the mesh dialect as a dependency of this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
registry.insert<mesh::MeshDialect>();
}
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpmdizationBase<DerivedT>)
protected:
};
#undef GEN_PASS_CLASSES
#endif