llvm/mlir/python/mlir/dialects/transform/loop.py

#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._loop_transform_ops_gen import *
from .._loop_transform_ops_gen import _Dialect

try:
    from ...ir import *
    from .._ods_common import (
        get_op_result_or_value as _get_op_result_or_value,
        _cext as _ods_cext,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class LoopOutlineOp(LoopOutlineOp):
    """Python-side builder for LoopOutlineOp.

    Subclasses the generated ``LoopOutlineOp`` and is registered with
    ``replace=True`` on the dialect, so this constructor is used in place of
    the generated one.
    """

    def __init__(
        self,
        function_type: Type,
        call_type: Type,
        target: Union[Operation, Value],
        *,
        func_name: Union[str, StringAttr],
        ip=None,
        loc=None,
    ):
        """Create the op after normalizing ``target`` and ``func_name``.

        Args:
          function_type: Forwarded unchanged to the generated builder
            (presumably the type of the outlined-function handle).
          call_type: Forwarded unchanged to the generated builder
            (presumably the type of the call handle).
          target: An ``Operation`` or ``Value``; normalized with
            ``_get_op_result_or_value`` before being forwarded.
          func_name: Name for the outlined function. A ``StringAttr`` is
            forwarded as-is; any other value is wrapped via
            ``StringAttr.get``.
          ip: Forwarded unchanged to the generated builder.
          loc: Forwarded unchanged to the generated builder.
        """
        target_value = _get_op_result_or_value(target)

        # Accept either a plain Python string or an already-built attribute.
        name_attr = func_name
        if not isinstance(name_attr, StringAttr):
            name_attr = StringAttr.get(name_attr)

        super().__init__(
            function_type,
            call_type,
            target_value,
            func_name=name_attr,
            ip=ip,
            loc=loc,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class LoopPeelOp(LoopPeelOp):
    """Python-side builder for LoopPeelOp.

    Subclasses the generated ``LoopPeelOp`` and is registered with
    ``replace=True`` on the dialect, so this constructor is used in place of
    the generated one.
    """

    def __init__(
        self,
        main_loop_type: Type,
        remainder_loop_type: Type,
        target: Union[Operation, Value],
        *,
        peel_front: Union[bool, BoolAttr] = False,
        fail_if_already_divisible: Union[bool, BoolAttr] = False,
        ip=None,
        loc=None,
    ):
        """Create the op after normalizing ``target`` and the boolean flags.

        Args:
          main_loop_type: Forwarded unchanged to the generated builder.
          remainder_loop_type: Forwarded unchanged to the generated builder.
          target: An ``Operation`` or ``Value``; normalized with
            ``_get_op_result_or_value`` before being forwarded.
          peel_front: ``bool`` or ``BoolAttr`` (default ``False``). A plain
            value is wrapped via ``BoolAttr.get``.
          fail_if_already_divisible: ``bool`` or ``BoolAttr`` (default
            ``False``). A plain value is wrapped via ``BoolAttr.get``.
          ip: Forwarded unchanged to the generated builder.
          loc: Forwarded unchanged to the generated builder.
        """
        target_value = _get_op_result_or_value(target)

        def as_bool_attr(flag):
            # Existing BoolAttr instances pass through; anything else is wrapped.
            return flag if isinstance(flag, BoolAttr) else BoolAttr.get(flag)

        # Keyword arguments are evaluated left to right, so the flags are
        # converted in the same order as before (peel_front first).
        super().__init__(
            main_loop_type,
            remainder_loop_type,
            target_value,
            peel_front=as_bool_attr(peel_front),
            fail_if_already_divisible=as_bool_attr(fail_if_already_divisible),
            ip=ip,
            loc=loc,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class LoopPipelineOp(LoopPipelineOp):
    """Python-side builder for LoopPipelineOp.

    Subclasses the generated ``LoopPipelineOp`` and is registered with
    ``replace=True`` on the dialect, so this constructor is used in place of
    the generated one.
    """

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        iteration_interval: Optional[Union[int, IntegerAttr]] = None,
        read_latency: Optional[Union[int, IntegerAttr]] = None,
        ip=None,
        loc=None,
    ):
        """Create the op, filling in defaults for omitted tuning knobs.

        Args:
          result_type: Forwarded unchanged to the generated builder.
          target: An ``Operation`` or ``Value``; normalized with
            ``_get_op_result_or_value`` before being forwarded.
          iteration_interval: ``int`` or ``IntegerAttr``; ``None`` means 1.
          read_latency: ``int`` or ``IntegerAttr``; ``None`` means 10.
          ip: Forwarded unchanged to the generated builder.
          loc: Forwarded unchanged to the generated builder.
        """
        # Plain ints are forwarded as-is; presumably the generated builder
        # converts them to IntegerAttr — confirm against the generated code.
        interval = 1 if iteration_interval is None else iteration_interval
        latency = 10 if read_latency is None else read_latency

        super().__init__(
            result_type,
            _get_op_result_or_value(target),
            iteration_interval=interval,
            read_latency=latency,
            ip=ip,
            loc=loc,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class LoopUnrollOp(LoopUnrollOp):
    """Python-side builder for LoopUnrollOp.

    Subclasses the generated ``LoopUnrollOp`` and is registered with
    ``replace=True`` on the dialect, so this constructor is used in place of
    the generated one.
    """

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        factor: Union[int, IntegerAttr],
        ip=None,
        loc=None,
    ):
        """Create the op from a target handle and an unroll factor.

        Args:
          target: An ``Operation`` or ``Value``; normalized with
            ``_get_op_result_or_value`` before being forwarded.
          factor: ``int`` or ``IntegerAttr``, forwarded unchanged
            (presumably converted by the generated builder — confirm).
          ip: Forwarded unchanged to the generated builder.
          loc: Forwarded unchanged to the generated builder.
        """
        target_value = _get_op_result_or_value(target)
        super().__init__(target_value, factor=factor, ip=ip, loc=loc)