llvm/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td

//===- ShardingInterface.td --------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD

include "mlir/IR/OpBase.td"

// Op interface through which an operation exposes its loop structure
// (iterator types, reduction kinds, indexing maps) and the hooks used to
// choose, annotate and spmdize shardings. Every method has a default
// implementation, so an op only overrides what it needs.
def ShardingInterface : OpInterface<"ShardingInterface"> {
    let description = [{
        Interface for allowing operations to expose information needed to
        shard them.
    }];
    let cppNamespace = "::mlir::mesh";

    let methods = [
      InterfaceMethod<
        /*desc=*/[{
          Return a list of iterator types, one per loop of the operation.
          The iterator types determine how the operation traverses its input and
          output tensors.

          Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
          types are parallel, parallel, reduction. This indicates that M and
          N are traversed in parallel, while the K dimension is used for
          reduction.

          The default implementation returns an empty list.
        }],
        /*retTy=*/"SmallVector<mlir::utils::IteratorType>",
        /*methodName=*/"getLoopIteratorTypes",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return {};"
      >,
      InterfaceMethod<
        /*desc=*/[{
          Return the kinds of all reduction loop iterators.
          The order is the same as in the result from
          `getLoopIteratorTypes`.

          Example 1:
          iterator types =  (parallel, reduction, parallel, reduction)
                                             ||                   ||
          reduction kinds = (                sum,                 max)

          Example 2:
          A softmax op's loop iterator types are parallel and
          reduction.
          The reduction iterator will be of kind `generic`, since it is none of
          the available presets.

          The default implementation returns an empty list.
        }],
        /*retTy=*/"SmallVector<ReductionKind>",
        /*methodName=*/"getReductionLoopIteratorKinds",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return {};"
      >,
      InterfaceMethod<
        /*desc=*/[{
          Return the indexing maps attribute within the current operation.
          Indexing maps determine how indices in the iteration space map to
          tensor indices. They are specified using `affine_map` in MLIR, which
          provides an affine transformation of indices.

          The default implementation returns an empty list.
        }],
        /*retTy=*/"SmallVector<AffineMap>",
        /*methodName=*/"getIndexingMaps",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return {};"
      >,
      InterfaceMethod<
        /*desc=*/[{
          Given that certain operands or results of the operation may have
          sharding annotations, this method leverages this information to
          deduce how the operation should be sharded.
          The passed shardings may be incomplete; this gives the op the
          freedom to select the most appropriate shardings for all the
          operands and results and the op itself.

          The default implementation forwards to
          `detail::defaultGetShardingOption`.
        }],
        /*retTy=*/"FailureOr<ShardingOption>",
        /*methodName=*/"getShardingOption",
        /*args=*/(ins
          "ArrayRef<MeshSharding>":$operandShardings,
          "ArrayRef<MeshSharding>":$resultShardings
        ),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          return detail::defaultGetShardingOption(
            $_op.getOperation(), operandShardings, resultShardings);
        }]
      >,
      InterfaceMethod<
        /*desc=*/[{
          Based on a given ShardingOption, get the sharding annotations for
          the operands and results.
          These are the shardings the operands and results need to have in
          order to shard the op according to `shardingOption`.

          The default implementation forwards to
          `detail::defaultGetShardingAnnotations`.
        }],
        /*retTy=*/"FailureOr<std::vector<MeshSharding>>",
        /*methodName=*/"getShardingAnnotations",
        /*args=*/(ins
          "const ShardingOption &":$shardingOption
        ),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          return detail::defaultGetShardingAnnotations(
            $_op.getOperation(), shardingOption);
        }]
      >,
      InterfaceMethod<
        /*desc=*/[{
          Based on a given ShardingOption, this method adds `mesh.shard`
          operations for the operands and results that previously lacked
          sharding annotations.

          The default implementation forwards to
          `detail::defaultAddShardingAnnotations`.
        }],
        /*retTy=*/"LogicalResult",
        /*methodName=*/"addShardingAnnotations",
        /*args=*/(ins
          "OpBuilder &":$b,
          "const ShardingOption &":$shardingOption
        ),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          return detail::defaultAddShardingAnnotations(
            $_op.getOperation(), b, shardingOption);
        }]
      >,
      InterfaceMethod<
        /*desc=*/[{
          Convert self to SPMD form.
          This method is used during the spmdization pass of a program fully
          annotated with shardings.

          The spmdization algorithm would read the surrounding sharding
          annotations from the IR for each argument/result and prepare
          `operandShardings` and `resultShardings`.
          Values that are not ranked tensors do not have sharding annotations.
          In this case their corresponding MeshSharding is null.

          For convenience it will also prepare `spmdizedOperands`, although
          they can be retrieved from the `spmdizationMap`.

          The `spmdizationMap` contains a mapping from unsharded to
          sharded/spmdized values that are constructed during the spmdization
          pass. The interface implementation must populate `spmdizationMap`
          with the mapping for this op's results.

          `builder` is set to insert new operations in the appropriate point.
          The implementation should not return the builder to the original
          insertion point.
          It should leave it as is after all insertions are done.

          The default implementation does full replication.
          This assumes that all sharding annotations are for full replication.
        }],
        /*retTy=*/"LogicalResult",
        /*methodName=*/"spmdize",
        /*args=*/(ins
          "ArrayRef<Value>":$spmdizedOperands,
          "ArrayRef<MeshSharding>":$operandShardings,
          "ArrayRef<MeshSharding>":$resultShardings,
          "IRMapping&":$spmdizationMap,
          "SymbolTableCollection &":$symbolTableCollection,
          "OpBuilder &":$builder
        ),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          spmdizeFullyReplicatedOperation(
            *$_op.getOperation(), spmdizedOperands, operandShardings,
            resultShardings, spmdizationMap, symbolTableCollection, builder);
          return success();
        }]>
    ];

    let extraClassDeclaration = [{
      LogicalResult verifyShardingInterfaceImpl();

      void printLoopTypesAndIndexingMaps(raw_ostream &os);
    }];
}


#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD