// llvm/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

//===- MeshBase.td - Mesh Dialect --------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_IR_MESHBASE_TD
#define MLIR_DIALECT_MESH_IR_MESHBASE_TD

include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/EnumAttr.td"

//===----------------------------------------------------------------------===//
// Mesh Dialect
//===----------------------------------------------------------------------===//

// ODS record for the `mesh` dialect. Everything generated from it lives in
// the `::mlir::mesh` C++ namespace.
def Mesh_Dialect : Dialect {
  let cppNamespace = "::mlir::mesh";
  let name = "mesh";

  // Dialects that must be loaded whenever this dialect is loaded.
  let dependentDialects = [
    "arith::ArithDialect" // For materializeConstant()
  ];

  // Request the ODS-generated default parse/print hooks for the dialect's
  // types and attributes instead of hand-written ones.
  let useDefaultTypePrinterParser = 1;
  let useDefaultAttributePrinterParser = 1;

  // The dialect declares a materializeConstant() hook; its definition is not
  // part of this file.
  let hasConstantMaterializer = 1;

  let description = [{
    See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
  }];
}

// 16-bit signless integer type constraint.
// NOTE(review): presumably the type used to identify a single mesh axis --
// confirm against the ops that use it.
def Mesh_MeshAxis : I<16>;
// Dense array attribute with `int16_t` elements (C++ class
// `DenseI16ArrayAttr`); the element width matches Mesh_MeshAxis.
def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
// Dense array attribute with `int64_t` elements (C++ class
// `DenseI64ArrayAttr`).
// NOTE(review): presumably carries shard shapes/sizes -- confirm with users.
def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;

//===----------------------------------------------------------------------===//
// Mesh Enums.
//===----------------------------------------------------------------------===//

// 32-bit integer enum `ReductionKind`, emitted into `::mlir::mesh`.
// Each case is (C++ symbol, integer value, assembly keyword). Values are
// assigned explicitly: 1-8 are contiguous and `Generic` sits apart at 100.
def Mesh_ReductionKind
    : I32EnumAttr<"ReductionKind", "Reduction of an iterator/mesh dimension.", [
        I32EnumAttrCase<"Sum", 1, "sum">,
        I32EnumAttrCase<"Max", 2, "max">,
        I32EnumAttrCase<"Min", 3, "min">,
        I32EnumAttrCase<"Product", 4, "product">,
        // Arithmetic mean.
        I32EnumAttrCase<"Average", 5, "average">,
        I32EnumAttrCase<"BitwiseAnd", 6, "bitwise_and">,
        I32EnumAttrCase<"BitwiseOr", 7, "bitwise_or">,
        I32EnumAttrCase<"BitwiseXor", 8, "bitwise_xor">,
        I32EnumAttrCase<"Generic", 100, "generic">
      ]> {
  let cppNamespace = "::mlir::mesh";
  // Do not generate a specialized attribute class from this record; the
  // dialect attribute wrapping the enum is declared separately via EnumAttr.
  let genSpecializedAttr = 0;
}

// Mesh-dialect attribute wrapping the ReductionKind enum, registered under
// the mnemonic `partial`.
def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
  // The attribute body consists solely of the enum value.
  let assemblyFormat = "$value";
}

//===----------------------------------------------------------------------===//
// Mesh Types
//===----------------------------------------------------------------------===//

// Base class for types defined by the Mesh dialect.
//   name         - forwarded to TypeDef as the type's name.
//   typeMnemonic - mnemonic under which the type is registered.
//   traits       - optional list of traits forwarded to TypeDef.
//   baseCppClass - C++ base class forwarded to TypeDef; defaults to
//                  `::mlir::Type`.
class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
                   string baseCppClass = "::mlir::Type">
    : TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
  let mnemonic = typeMnemonic;
}

// The `Sharding` type, registered under the mnemonic `sharding`. No
// parameters are declared for it in this record.
def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
  let summary = "sharding definition";
  // Empty format: nothing follows the mnemonic in the textual form.
  let assemblyFormat = "";
}

//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//

// Attribute `MeshAxesArray` (mnemonic `axisarray`): holds an array of
// `MeshAxesAttr` values in its single `axes` parameter.
// NOTE(review): `MeshAxesAttr` is a C++ type name resolved outside this
// file -- presumably the C++ counterpart of Mesh_MeshAxesAttr; confirm.
def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
  let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);

  // Textual form: the `axes` elements enclosed in square brackets.
  let mnemonic = "axisarray";
  let assemblyFormat = "`[` $axes `]`";

  // Container-style helpers that forward to getAxes(), so the attribute can
  // be sized and iterated directly.
  let extraClassDeclaration = [{
    size_t size() const { return getAxes().size(); }
    auto begin() const { return getAxes().begin(); }
    auto end() const { return getAxes().end(); }
  }];
}

#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD