llvm/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py

#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Model classes representing a tensor comprehension.

These classes model the language more at an AST level as evaluated. Reasoning
about it typically involves processing this form into config objects that
represent actual op definitions (i.e. YAML).
"""

from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from enum import Enum

from ..... import ir as _ir
from .affine import *
from .scalar_expr import *
from .types import *
from .yaml_helper import *

###############################################################################
# Tensor expression nodes.
###############################################################################


class TensorExpression:
    """An expression that can appear on the RHS of a comprehension."""

    def to_scalar_expression(self) -> ScalarExpression:
        raise NotImplementedError()

    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
        """Visits all tensor expression reachable by the expression."""
        callback(self)

    def collect_dim_uses(self, uses: Set["DimDef"]):
        """Collects all DimDefs reachable through this expression."""

        def visit_dim_def(dim_def: AffineExprDef):
            if isinstance(dim_def, DimDef):
                uses.add(dim_def)

        def visit_affine_exprs(expr: "TensorExpression"):
            if isinstance(expr, TensorUse):
                for ind in expr.indices:
                    ind.visit_affine_exprs(visit_dim_def)
            if isinstance(expr, TensorReduceFn):
                for ind in expr.reduce_fn.reduce_dims:
                    ind.visit_affine_exprs(visit_dim_def)

        self.visit_tensor_exprs(visit_affine_exprs)

    def collect_tensor_uses(self, uses: Set["TensorUse"]):
        """Collects all TensorUses reachable through this expression."""

        def visit_tensor_use(expr: "TensorExpression"):
            if isinstance(expr, TensorUse):
                uses.add(expr)

        self.visit_tensor_exprs(visit_tensor_use)

    def collect_indices(self, indices: Set["index"]):
        """Collects all index accesses reachable through this expression."""

        def visit_index(expr: "TensorExpression"):
            if isinstance(expr, index):
                indices.add(expr)

        self.visit_tensor_exprs(visit_index)

    def collect_scalar_uses(self, uses: Set["ScalarDef"]):
        """Collects all ScalarDefs reachable through this expression."""

        def visit_scalar_def(expr: "TensorExpression"):
            if isinstance(expr, ScalarDef):
                uses.add(expr)

        self.visit_tensor_exprs(visit_scalar_def)

    def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
        return BinaryFn.add(self, rhs)

    def __mul__(self, rhs) -> "TensorExpression":
        return BinaryFn.mul(self, rhs)

    def __sub__(self, rhs) -> "TensorExpression":
        return BinaryFn.sub(self, rhs)

    def __truediv__(self, rhs) -> "TensorExpression":
        return BinaryFn.div(self, rhs)

    def __hash__(self):
        return hash(id(self))


class TensorUse(TensorExpression):
    """A used tensor represented by its (tensor_name, indices).

    Note that forming a comprehension via direct assignment is performed through
    __setitem__ on the TensorDef level. However, performing a reduction with
    compound ops (+=, *=, etc) is done by doing a:
      TensorDef.__getitem__
      TensorUse.__iadd__
      TensorDef.__setitem__
    """

    def __init__(self, operand_def: "OperandDef", indices: Sequence[AffineExprDef]):
        self.operand_def = operand_def
        self.indices = tuple(indices)

    def to_scalar_expression(self) -> ScalarExpression:
        return ScalarArg(self.tensor_name).expr()

    @property
    def tensor_name(self) -> str:
        name = self.operand_def.name
        assert name is not None, "TensorDef not registered with an op"
        return name

    def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
        # Computes the reduction dims for implicit reductions. Assumes that the rhs
        # is the expression being reduced and self is being reduced into. Any
        # indices referenced on the rhs and not in self are considered reduction
        # dims and will be ordered as encountered on the rhs.
        rhs_dims = set()
        lhs_dims = set()
        rhs.collect_dim_uses(rhs_dims)
        self.collect_dim_uses(lhs_dims)
        return rhs_dims - lhs_dims

    def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
        return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)

    def __repr__(self):
        return (
            f"{self.operand_def.name}" f"[{', '.join([repr(i) for i in self.indices])}]"
        )


class TensorFn(TensorExpression):
    """Application of a tensor function."""

    def __init__(
        self,
        kind: "FunctionKind",
        name: Optional[str],
        operand_def: Optional["OperandDef"],
        type_var: Optional[TypeVar],
        args: Sequence[TensorExpression],
    ):
        if bool(name) + bool(operand_def) != 1:
            raise ValueError("One of 'name', 'operand_def' must be specified")
        self.name = name
        self.kind = kind
        self.operand_def = operand_def
        self.type_var = type_var
        self.args = args

    def to_scalar_expression(self) -> ScalarExpression:
        if self.operand_def:
            assert self.operand_def.name, "TensorFn not registered with an op"
        attr_name = self.operand_def.name if self.operand_def else None
        args = [arg.to_scalar_expression() for arg in self.args]
        return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()

    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
        super().visit_tensor_exprs(callback)
        for arg in self.args:
            arg.visit_tensor_exprs(callback)

    def __repr__(self):
        name = self.operand_def.name if self.operand_def else self.name
        return (
            f"{self.kind.name}.{name}(type_var={self.type_var}, "
            f"args={', '.join(repr(a) for a in self.args)})"
        )


class TensorReduceFn(TensorExpression):
    """Application of a reduction function.

    This captures the lhs (initial value) separately from the rhs.
    """

    def __init__(self, reduce_use: "ReduceFnUse", args: Sequence[TensorExpression]):
        self.reduce_use = reduce_use
        self.lhs = None  # type: Optional[TensorUse]
        self.args = args

    def to_scalar_expression(self) -> ScalarExpression:
        if self.lhs is None:
            raise ValueError(
                f"Cannot scalarize a TensorReduceFn that has not been "
                f"bound to its lhs: {self}"
            )
        full_args = [self.lhs.to_scalar_expression()] + [
            arg.to_scalar_expression() for arg in self.args
        ]
        fn_name = None
        attr_name = None
        if self.reduce_use.binary_fn:
            fn_name = self.reduce_use.binary_fn.fn_name
        if self.reduce_use.binary_attr:
            attr_name = self.reduce_use.binary_attr.operand_def.name
        return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, full_args).expr()

    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
        for arg in self.args:
            arg.visit_tensor_exprs(callback)

    def __repr__(self):
        return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"


class const(TensorExpression):
    """Returns the given constant floating point or integer value."""

    def __init__(self, value: Any):
        with _ir.Context():
            if isinstance(value, float):
                self.value = str(_ir.FloatAttr.get_f64(float(value)))
            elif isinstance(value, int):
                self.value = str(
                    _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))
                )
            else:
                raise ValueError(f"const requires int or float but got {type(value)}")

    def to_scalar_expression(self) -> ScalarExpression:
        return ScalarConst(self.value).expr()

    def __repr__(self):
        return f"const({self.value})"


class index(TensorExpression):
    """Returns the iteration index for a given dimension name.

    Resolves the given dimension name to obtain its position in the iteration
    domain of the operation.
    """

    def __init__(self, dim: DimDef):
        self.dim_def = dim
        self.dim = -1

    def resolve_dimension_name(self, affine_state: AffineBuildState):
        self.dim = affine_state.get_dim(self.dim_def.dimname)

    def to_scalar_expression(self) -> ScalarExpression:
        assert self.dim != -1, "Dimension name not resolved"
        return ScalarIndex(self.dim).expr()

    def __repr__(self):
        return f"index({repr(self.dim)})"


###############################################################################
# Function types and function definitions.
###############################################################################


class FunctionKind(Enum):
    UNARY = 0
    BINARY = 1
    TERNARY = 2
    TYPE = 3


class UnaryFnType:
    """Unary function.

    A unary function takes one tensor expression and returns the
    function evaluation result.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(self, arg: TensorExpression) -> "TensorFn":
        return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])

    def __repr__(self):
        return f"{self.fn_name}"


class UnaryFn:
    """Unary function namespace."""

    exp = UnaryFnType("exp")
    log = UnaryFnType("log")
    abs = UnaryFnType("abs")
    ceil = UnaryFnType("ceil")
    floor = UnaryFnType("floor")
    negf = UnaryFnType("negf")
    reciprocal = UnaryFnType("reciprocal")
    round = UnaryFnType("round")
    sqrt = UnaryFnType("sqrt")
    rsqrt = UnaryFnType("rsqrt")
    square = UnaryFnType("square")
    tanh = UnaryFnType("tanh")
    erf = UnaryFnType("erf")


class BinaryFnType:
    """Binary function.

    A binary function takes two tensor expressions and returns the
    function evaluation result.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> "TensorFn":
        return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1])

    def __repr__(self):
        return f"{self.fn_name}"


class BinaryFn:
    """Binary function namespace.

    As the integer types are signless, signedness is implement by different
    functions that treat integers as signed or unsigned values.

    Examples:
    - max -> `arith.MaxSIOp`
    - max_unsigned -> `arith.MaxUIOp`
    """

    add = BinaryFnType("add")
    sub = BinaryFnType("sub")
    mul = BinaryFnType("mul")
    div = BinaryFnType("div")
    div_unsigned = BinaryFnType("div_unsigned")
    max_signed = BinaryFnType("max_signed")
    min_signed = BinaryFnType("min_signed")
    max_unsigned = BinaryFnType("max_unsigned")
    min_unsigned = BinaryFnType("min_unsigned")
    powf = BinaryFnType("powf")


class TernaryFnType:
    """Ternary function.

    A ternary function takes three tensor expressions and returns the
    function evaluation result.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(
        self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression
    ) -> "TensorFn":
        return TensorFn(
            FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2]
        )

    def __repr__(self):
        return f"{self.fn_name}"


class TernaryFn:
    """Ternary function namespace."""

    select = TernaryFnType("select")


class TypeFnType:
    """Type conversion function.

    A type conversion function takes a target type and a tensor expression and
    returns the casted tensor expression.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
        return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])

    def __repr__(self):
        return f"{self.fn_name}"


class TypeFn:
    """Type conversion function namespace.

    As the integer types are signless, signedness is implement by different cast
    functions that treat integers as signed (`cast_signed`) or unsigned
    (`cast_unsigned`) values.

    Examples:
    - cast_signed(I32 -> I64) -> `arith.ExtSIOp`
    - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
    """

    cast_signed = TypeFnType("cast_signed")
    cast_unsigned = TypeFnType("cast_unsigned")


class ReduceFnUse:
    """Reduction function use.

    A reduction use specifies the reduction function and dimensions.
    """

    def __init__(
        self,
        binary_fn: Optional[BinaryFnType],
        binary_attr: Optional["BinaryFnAttrDef"],
        *reduce_dims: DimDef,
    ):
        if bool(binary_fn) + bool(binary_attr) != 1:
            raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
        self.binary_fn = binary_fn
        self.binary_attr = binary_attr
        self.reduce_dims = reduce_dims

    def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
        return TensorReduceFn(self, args)

    def __repr__(self):
        fn = self.binary_fn if self.binary_fn else self.binary_attr
        return f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})"


class ReduceFnType:
    """Reduction function.

    A binary function that reduces its RHS into its LHS.
    """

    def __init__(self, binary_fn: BinaryFnType):
        if not isinstance(binary_fn, BinaryFnType):
            raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
        self.binary_fn = binary_fn

    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
        return ReduceFnUse(self.binary_fn, None, *reduce_dims)

    def __repr__(self):
        return f"reduce_{repr(self.binary_fn)}"


class ReduceFn:
    add = ReduceFnType(BinaryFn.add)
    mul = ReduceFnType(BinaryFn.mul)
    max_signed = ReduceFnType(BinaryFn.max_signed)
    min_signed = ReduceFnType(BinaryFn.min_signed)
    max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
    min_unsigned = ReduceFnType(BinaryFn.min_unsigned)


###############################################################################
# Operand definitions.
###############################################################################


class OperandKind(Enum):
    INPUT_TENSOR = 0
    SCALAR = 1
    OUTPUT_TENSOR = 2
    INDEX_ATTR = 3
    UNARY_FN_ATTR = 4
    BINARY_FN_ATTR = 5
    TERNARY_FN_ATTR = 6
    TYPE_FN_ATTR = 7


class OperandDef:
    """Definition of an operand passed to an operation.

    Keep the meta information of Tensor, Scalar, and Attribute operands and
    provide the shared registration functionality.
    """

    def __init__(
        self,
        kind: OperandKind,
        type_var: Optional[TypeVar] = None,
        size_exprs: Optional[Sequence[AffineExprDef]] = None,
        index_dims: Optional[Sequence[DimDef]] = None,
        default_indices: Optional[Sequence[int]] = None,
        default_fn: Optional[str] = None,
    ):
        if type_var and not isinstance(type_var, TypeVar):
            raise ValueError(f"OperandDef requires a TypeVar but got {repr(type_var)}")
        self.owner = None  # type: Optional["LinalgOpDef"]
        self.type_var = type_var
        self.size_exprs = size_exprs
        self.index_dims = index_dims
        self.default_indices = default_indices
        self.default_fn = default_fn
        self.kind = kind
        self.name = None  # type: Optional[str]
        self.registered_index = -1  # type: int

    def attach(self, index: int, name: str, owner: "LinalgOpDef"):
        if self.owner:
            raise ValueError(f"OperandDef already registered with an op: {self}")
        self.registered_index = index
        self.name = name
        self.owner = owner

    def is_input(self) -> bool:
        return self.kind == OperandKind.SCALAR or self.kind == OperandKind.INPUT_TENSOR

    def is_tensor(self) -> bool:
        return (
            self.kind == OperandKind.INPUT_TENSOR
            or self.kind == OperandKind.OUTPUT_TENSOR
        )

    def is_attribute(self) -> bool:
        return (
            self.kind == OperandKind.INDEX_ATTR
            or self.kind == OperandKind.UNARY_FN_ATTR
            or self.kind == OperandKind.BINARY_FN_ATTR
            or self.kind == OperandKind.TERNARY_FN_ATTR
            or self.kind == OperandKind.TYPE_FN_ATTR
        )

    def __hash__(self):
        return hash(id(self))

    def __repr__(self):
        return (
            f"{self.name}:OperandDef(kind={self.kind.name}, "
            f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, "
            f"index_dims={self.index_dims}, "
            f"default_indices={self.default_indices}, "
            f"default_fn={self.default_fn})"
        )


class TensorDef:
    """Tensor operand definition.

    Tensor operands are indexed using the associated indexing_map when forwarded
    to the body of the structured op. A unique name identifies the tensor operands
    and an index determines their position in the operation's parameter list. A
    tensor definition takes type, a shape, and an optional flag to mark output
    tensors. Additionally, a tuple of index dimensions may be used to map the
    tensor to the loop dimensions of the operation. This mapping is needed to
    compute the indexing map of shape-only tensors that have no uses.
    """

    def __init__(
        self,
        type_var: TypeVar,
        *shape: AffineExprDef,
        index_dims: Optional[Sequence[DimDef]] = None,
        output: bool = False,
    ):
        if index_dims and len(shape) != len(index_dims):
            raise ValueError(
                f"Expected the shape rank {len(shape)} to match the "
                f"number of index_dims {len(index_dims)}"
            )
        if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
            raise ValueError(
                f"TensorDef requires index dims of type DimDef but " f"got {index_dims}"
            )
        kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
        self.operand_def = OperandDef(
            kind, type_var=type_var, size_exprs=shape, index_dims=index_dims
        )

    def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
        assert self.operand_def.owner, "TensorDef is not registered with an op"
        state = AffineBuildState(
            global_state=self.operand_def.owner._affine_state, allow_new_symbols=False
        )
        if not isinstance(dims, tuple):
            dims = (dims,)  # Handle single subscript case.
        # Special case: (None) is a 0d-scalar use.
        if dims == (None,):
            dims = ()

        exprs = []
        for expr_def in dims:
            if not isinstance(expr_def, AffineExprDef):
                raise KeyError(
                    "A TensorDef can only be subscripted by a tuple of affine dims"
                )
            exprs.append(expr_def)
        return TensorUse(self.operand_def, exprs)

    def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
        """Creates a new 1:1 comprehension by binding this tensor to an expression.

        Note that due to the way assignment works in Python, we have to capture
        direct assignment as a setitem on the TensorDef.
        """
        if not isinstance(value, TensorExpression):
            raise ValueError(
                f"Only TensorExpressions can be assigned to TensorDefs. "
                f"Got: {repr(value)}"
            )
        use = self[dims]
        comp = Comprehension((use, value))
        self.operand_def.owner.comprehensions.append(comp)


class ScalarDef(TensorExpression):
    """Scalar operand definition.

    Scalar operands are forwarded to the body of the structured op as they are.
    A unique name identifies the scalars and an index determines their position in
    the operation's parameter list.
    """

    def __init__(self, type_var: TypeVar):
        self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)

    @property
    def scalar_name(self) -> str:
        name = self.operand_def.name
        assert name is not None, "ScalarDef not registered with an op"
        return name

    def to_scalar_expression(self) -> ScalarExpression:
        return ScalarArg(self.scalar_name).expr()


class IndexAttrDef:
    """Index attribute definition.

    Index attributes provide a way to define and set symbols that can be used in
    indexing expressions. Every attribute specifies a tuple of symbols that at
    compile-time are replaced by integer values as well as their default values.
    """

    def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
        if any(not isinstance(size, SymbolDef) for size in sizes):
            raise ValueError(
                f"IndexAttrDef requires sizes of type SymbolDef " f"but got {sizes}"
            )
        if any(not isinstance(default_val, int) for default_val in default):
            raise ValueError(
                f"IndexAttrDef requires default values of type int "
                f"but got {default}"
            )
        if len(sizes) != len(default):
            raise ValueError(
                f"IndexAttrDef expects {len(sizes)} default values "
                f"but got {len(default)}"
            )
        self.operand_def = OperandDef(
            OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default
        )


class UnaryFnAttrDef:
    """Unary function attribute definition.

    Unary function attributes provide a way to make the arithmetic computation
    parametrizable. Every attribute specifies a default unary function
    that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "UnaryFnType"):
        if not isinstance(default, UnaryFnType):
            raise ValueError(
                f"UnaryFnAttrDef requires default of type UnaryFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(self, arg: TensorExpression) -> TensorFn:
        return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])


class BinaryFnAttrDef:
    """Binary function attribute definition.

    Binary function attributes provide a way to make the arithmetic computation
    parametrizable. Every attribute specifies a default binary function
    that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "BinaryFnType"):
        if not isinstance(default, BinaryFnType):
            raise ValueError(
                f"BinaryFnAttrDef requires default of type BinaryFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
        return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, [arg0, arg1])

    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
        return ReduceFnUse(None, self, *reduce_dims)


class TernaryFnAttrDef:
    """Ternary function attribute definition.

    Ternary function attributes provide a way to make the arithmetic computation
    parametrizable. Every attribute specifies a default Ternary function
    that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "TernaryFnType"):
        if not isinstance(default, TernaryFnType):
            raise ValueError(
                f"TernaryFnAttrDef requires default of type TernaryFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.TERNARY_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
        return TensorFn(
            FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1]
        )

    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
        return ReduceFnUse(None, self, *reduce_dims)


class TypeFnAttrDef:
    """Type conversion function attribute definition.

    Type conversion function attributes provide a way to make type conversions
    parameterizable. Every attribute specifies a default type conversion function
    that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "TypeFnType"):
        if not isinstance(default, TypeFnType):
            raise ValueError(
                f"TypeFnAttrDef requires default of type TypeFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
        return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])


###############################################################################
# Operation definition.
###############################################################################


class Comprehension:
    """Represents a single comprehension."""

    def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
        self.definitions = list()  # List[TensorUse]
        self.values = list()  # List[TensorExpression]

        # Find the lhs to reduction rhs.
        for assign, value in bindings:
            if isinstance(value, TensorReduceFn):
                if value.lhs:
                    raise ValueError(f"Reduction expression already assigns: {value}")
                value.lhs = assign
            self.definitions.append(assign)
            self.values.append(value)

    @property
    def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
        """Gets the reduction dims for the comprehension or None."""
        result = set()
        for use in self.values:
            if isinstance(use, TensorReduceFn):
                result.add(use.reduce_use.reduce_dims)
            else:
                result.add(tuple())
        return result

    def __repr__(self):
        if len(self.definitions) > 1:
            defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
            values_repr = f"({', '.join(repr(v) for v in self.values)})"
        else:
            defs_repr = f"{repr(self.definitions[0])}"
            values_repr = f"{repr(self.values[0])}"

        return f"{defs_repr} = {values_repr}"


class OpInterfaceDef:
    """An interface that an op implements."""

    def __init__(self, cpp_name: str):
        self.cpp_name = cpp_name


ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")


class OpDefinitionDef:
    """A method that an op implements."""

    def __init__(self, def_name: str):
        self.def_name = def_name


Canonicalizer = OpDefinitionDef("hasCanonicalizer")


class OpMetadataDef(YAMLObject):
    """Metadata about the op (generally not behavior impacting)."""

    yaml_tag = "!LinalgOpMetadata"

    def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]):
        self.name = name
        self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
        self.doc = doc
        self.implements = []  # type: List[OpInterfaceDef]
        self.defines = []  # type: List[OpDefinitionsDef]

    def to_yaml_custom_dict(self):
        d = dict(
            name=self.name,
            cpp_class_name=self.cpp_class_name,
            doc=self.doc,
        )
        if self.implements:
            d["implements"] = [intr.cpp_name for intr in self.implements]
        if self.defines:
            d["defines"] = [defi.def_name for defi in self.defines]
        return d


class LinalgOpDef:
    """Definition of a linalg op."""

    def __init__(
        self, name: str, cpp_class_name: Optional[str] = None, doc: Optional[str] = None
    ):
        self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc)
        self.registered_operands = dict()  # type: Dict[str, OperandDef]
        self.domain = list()  # type: List[DimDef]
        self.comprehensions = list()  # type: List[Comprehension]
        self._affine_state = AffineBuildState()

    def add_operand(self, name: str, operand: OperandDef):
        """Registers an operand."""
        if name in self.registered_operands:
            raise ValueError(
                f"The operand {name} is already registered "
                f"to {self.registered_operands['name']}"
            )
        structured_op_methods = [
            "inputs",
            "outputs",
            "result_tensors",
            "region",
            "iterator_types",
            "indexing_maps",
            "getRegionBuilder",
            "getLibraryCallName",
        ]
        if operand.is_attribute() and name in structured_op_methods:
            raise ValueError(
                f"The attribute name {name} conflicts with a structured "
                f"op method name"
            )
        # Ensure output tensors are registered after input tensors and scalars and
        # attributes are registered after all other operand types.
        if operand.is_input() and any(
            not op_def.is_input() for op_def in self.registered_operands.values()
        ):
            raise ValueError(f"Input {name} registered after an output or attribute")
        if operand.kind == OperandKind.OUTPUT_TENSOR and any(
            op_def.is_attribute() for op_def in self.registered_operands.values()
        ):
            raise ValueError(f"Output {name} registered after an attribute")
        operand.attach(len(self.registered_operands), name, self)
        self.registered_operands[name] = operand

    def __repr__(self):
        lines = [f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"]
        for name, operand in self.registered_operands.items():
            lines.append(f"  {operand}")
        if self.comprehensions:
            lines[-1] += " {"
            for comprehension in self.comprehensions:
                lines.append(f"    {comprehension}")
            lines.append("}")
        return "\n".join(lines)