llvm/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td

//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//


#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD

include "mlir/Pass/PassBase.td"

//===----------------------------------------------------------------------===//
// ShardingPropagation
//===----------------------------------------------------------------------===//

// Declares the `sharding-propagation` pass. It is an interface pass, so it is
// scheduled on operations that implement `mlir::FunctionOpInterface`.
def ShardingPropagation
    : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> {
  // One-line summary shown in the pass listing.
  let summary = "sharding propagation";

  // Long-form text emitted into the generated pass documentation.
  let description = [{
    Propagates sharding information throughout the graph. After this pass, each
    of the operations' operands and results is annotated with a `mesh.shard`
    operation, and the operations themselves are added with sharding option
    attributes.
  }];

  // Dialects the pass declares as dependencies so they are loaded up front.
  let dependentDialects = ["mesh::MeshDialect"];
}

//===----------------------------------------------------------------------===//
// Spmdization
//===----------------------------------------------------------------------===//

// Declares the `mesh-spmdization` pass. It is an interface pass, so it is
// scheduled on operations that implement `mlir::FunctionOpInterface`.
def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
  let summary = "Partition a function into SPMD form.";
  // Long-form text emitted into the generated pass documentation. It states
  // the preconditions on the input IR and shows a before/after example.
  let description = [{
    This pass fits in right after a pass that annotates the function with
    shardings like the `ShardingPropagation` pass.
    It operates on a fully annotated IR.

    A fully annotated IR requires that all ranked tensor operands, results and
    block arguments are annotated with the `mesh.shard` operation.

    All direct descendant operations in the function must implement the
    `ShardingInterface` interface or all their ranked tensor operands and
    results must have full replication sharding.

    The input IR must have sharding annotations such that each operation
    that implements `ShardingInterface` can handle them during spmdization
    with its `spmdize` method.
    This can be achieved with the `ShardingPropagation` pass.

    If the function has multiple terminating blocks,
    it is the responsibility of the one who annotates the function with
    shardings to make sure that all returns are consistent, that is,
    have the same sharding.

    Example:
    ```mlir
    mesh.mesh @mesh_1d(shape = 2)

    func.func @f(
      %arg0: tensor<2xi8>
    ) -> tensor<2xi8> {
      %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
      %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
      %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
      %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
      %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
      return %4 : tensor<2xi8>
    }
    ```
    Spmdizing the above would result in
    * Performing the element-wise `abs` operation on each device.
    * Resharding to full replication with an all-gather.

    ```mlir
    mesh.mesh @mesh_1d(shape = 2)

    func.func @f(%arg0: tensor<1xi8>) -> tensor<2xi8> {
      %0 = tosa.abs %arg0 : (tensor<1xi8>) -> tensor<1xi8>
      %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
      return %1 : tensor<2xi8>
    }
    ```
  }];
  // Dialects the pass declares as dependencies so they are loaded up front.
  let dependentDialects = [
    "mesh::MeshDialect"
  ];
}

#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD