llvm/mlir/python/mlir/dialects/affine.py

#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._affine_ops_gen import *
from ._affine_ops_gen import _Dialect

try:
    from ..ir import *
    from ._ods_common import (
        get_op_result_or_value as _get_op_result_or_value,
        get_op_results_or_values as _get_op_results_or_values,
        _cext as _ods_cext,
        ResultValueTypeTuple as _ResultValueTypeTuple,
        ResultValueT as _ResultValueT,
        VariadicResultValueT as _VariadicResultValueT,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class AffineForOp(AffineForOp):
    """Specialization for the Affine for op class."""

    def __init__(
        self,
        lower_bound: Union[int, _ResultValueT, AffineMap],
        upper_bound: Optional[Union[int, _ResultValueT, AffineMap]],
        step: Optional[Union[int, Attribute]] = None,
        iter_args: Optional[_ResultValueT] = None,
        *,
        lower_bound_operands: Optional[_VariadicResultValueT] = None,
        upper_bound_operands: Optional[_VariadicResultValueT] = None,
        loc=None,
        ip=None,
    ):
        """Creates an Affine `for` operation.

        - `lower_bound` is the affine map to use as lower bound of the loop.
        - `upper_bound` is the affine map to use as upper bound of the loop.
        - `step` is the value to use as loop step.
        - `iter_args` is a list of additional loop-carried arguments or an operation
          producing them as results.
        - `lower_bound_operands` is the list of arguments to substitute the dimensions,
          then symbols in the `lower_bound` affine map, in an increasing order.
        - `upper_bound_operands` is the list of arguments to substitute the dimensions,
          then symbols in the `upper_bound` affine map, in an increasing order.
        """

        if lower_bound_operands is None:
            lower_bound_operands = []
        if upper_bound_operands is None:
            upper_bound_operands = []

        if step is None:
            step = 1

        bounds_operands = [lower_bound_operands, upper_bound_operands]
        bounds = [lower_bound, upper_bound]
        bounds_names = ["lower", "upper"]
        for i, name in enumerate(bounds_names):
            if isinstance(bounds[i], int):
                bounds[i] = AffineMap.get_constant(bounds[i])
            elif isinstance(bounds[i], _ResultValueTypeTuple):
                if len(bounds_operands[i]):
                    raise ValueError(
                        f"Either a concrete {name} bound or an AffineMap in combination "
                        f"with {name} bound operands, but not both, is supported."
                    )
                if (
                    isinstance(bounds[i], (OpView, Operation))
                    and len(bounds[i].results) > 1
                ):
                    raise ValueError(
                        f"Only a single concrete value is supported for {name} bound."
                    )

                bounds_operands[i].append(_get_op_result_or_value(bounds[i]))
                bounds[i] = AffineMap.get_identity(1)

            if not isinstance(bounds[i], AffineMap):
                raise ValueError(
                    f"{name} bound must be int | ResultValueT | AffineMap."
                )
            if len(bounds_operands[i]) != bounds[i].n_inputs:
                raise ValueError(
                    f"Wrong number of {name} bound operands passed to AffineForOp; "
                    + f"Expected {bounds[i].n_inputs}, got {len(bounds_operands[i])}."
                )

        lower_bound, upper_bound = bounds

        if iter_args is None:
            iter_args = []
        iter_args = _get_op_results_or_values(iter_args)

        results = [arg.type for arg in iter_args]
        super().__init__(
            results_=results,
            lowerBoundOperands=_get_op_results_or_values(lower_bound_operands),
            upperBoundOperands=_get_op_results_or_values(upper_bound_operands),
            inits=list(iter_args),
            lowerBoundMap=AffineMapAttr.get(lower_bound),
            upperBoundMap=AffineMapAttr.get(upper_bound),
            step=step,
            loc=loc,
            ip=ip,
        )
        self.regions[0].blocks.append(IndexType.get(), *results)

    @property
    def body(self):
        """Returns the body (block) of the loop."""
        return self.regions[0].blocks[0]

    @property
    def induction_variable(self):
        """Returns the induction variable of the loop."""
        return self.body.arguments[0]

    @property
    def inner_iter_args(self):
        """Returns the loop-carried arguments usable within the loop.

        To obtain the loop-carried operands, use `iter_args`.
        """
        return self.body.arguments[1:]


def for_(
    start,
    stop,
    step=None,
    iter_args: Optional[Sequence[Value]] = None,
    *,
    loc=None,
    ip=None,
):
    for_op = AffineForOp(
        start,
        stop,
        step,
        iter_args=iter_args,
        loc=loc,
        ip=ip,
    )
    iv = for_op.induction_variable
    iter_args = tuple(for_op.inner_iter_args)
    with InsertionPoint(for_op.body):
        if len(iter_args) > 1:
            yield iv, iter_args
        elif len(iter_args) == 1:
            yield iv, iter_args[0]
        else:
            yield iv