llvm/mlir/python/mlir/dialects/transform/__init__.py

#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._transform_enum_gen import *
from .._transform_ops_gen import *
from .._transform_ops_gen import _Dialect
from ..._mlir_libs._mlirDialectsTransform import *
from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType

try:
    from ...ir import *
    from .._ods_common import (
        get_op_result_or_value as _get_op_result_or_value,
        get_op_results_or_values as _get_op_results_or_values,
        _cext as _ods_cext,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union, NewType


@_ods_cext.register_operation(_Dialect, replace=True)
class CastOp(CastOp):
    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        loc=None,
        ip=None,
    ):
        super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyPatternsOp(ApplyPatternsOp):
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        loc=None,
        ip=None,
    ):
        super().__init__(target, loc=loc, ip=ip)
        self.regions[0].blocks.append()

    @property
    def patterns(self) -> Block:
        return self.regions[0].blocks[0]


@_ods_cext.register_operation(_Dialect, replace=True)
class GetParentOp(GetParentOp):
    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        isolated_from_above: bool = False,
        op_name: Optional[str] = None,
        deduplicate: bool = False,
        nth_parent: int = 1,
        loc=None,
        ip=None,
    ):
        super().__init__(
            result_type,
            _get_op_result_or_value(target),
            isolated_from_above=isolated_from_above,
            op_name=op_name,
            deduplicate=deduplicate,
            nth_parent=nth_parent,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class MergeHandlesOp(MergeHandlesOp):
    def __init__(
        self,
        handles: Sequence[Union[Operation, Value]],
        *,
        deduplicate: bool = False,
        loc=None,
        ip=None,
    ):
        super().__init__(
            [_get_op_result_or_value(h) for h in handles],
            deduplicate=deduplicate,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class ReplicateOp(ReplicateOp):
    def __init__(
        self,
        pattern: Union[Operation, Value],
        handles: Sequence[Union[Operation, Value]],
        *,
        loc=None,
        ip=None,
    ):
        super().__init__(
            [_get_op_result_or_value(h).type for h in handles],
            _get_op_result_or_value(pattern),
            [_get_op_result_or_value(h) for h in handles],
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class SequenceOp(SequenceOp):
    def __init__(
        self,
        failure_propagation_mode,
        results: Sequence[Type],
        target: Union[Operation, Value, Type],
        extra_bindings: Optional[
            Union[Sequence[Value], Sequence[Type], Operation, OpView]
        ] = None,
    ):
        root = (
            _get_op_result_or_value(target)
            if isinstance(target, (Operation, Value))
            else None
        )
        root_type = root.type if not isinstance(target, Type) else target

        if extra_bindings is None:
            extra_bindings = []
        if isinstance(extra_bindings, (Operation, OpView)):
            extra_bindings = _get_op_results_or_values(extra_bindings)

        extra_binding_types = []
        if len(extra_bindings) != 0:
            if isinstance(extra_bindings[0], Type):
                extra_binding_types = extra_bindings
                extra_bindings = []
            else:
                extra_binding_types = [v.type for v in extra_bindings]

        super().__init__(
            results_=results,
            failure_propagation_mode=failure_propagation_mode,
            root=root,
            extra_bindings=extra_bindings,
        )
        self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))

    @property
    def body(self) -> Block:
        return self.regions[0].blocks[0]

    @property
    def bodyTarget(self) -> Value:
        return self.body.arguments[0]

    @property
    def bodyExtraArgs(self) -> BlockArgumentList:
        return self.body.arguments[1:]


@_ods_cext.register_operation(_Dialect, replace=True)
class NamedSequenceOp(NamedSequenceOp):
    def __init__(
        self,
        sym_name,
        input_types: Sequence[Type],
        result_types: Sequence[Type],
        sym_visibility=None,
        arg_attrs=None,
        res_attrs=None,
    ):
        function_type = FunctionType.get(input_types, result_types)
        super().__init__(
            sym_name=sym_name,
            function_type=TypeAttr.get(function_type),
            sym_visibility=sym_visibility,
            arg_attrs=arg_attrs,
            res_attrs=res_attrs,
        )
        self.regions[0].blocks.append(*input_types)

    @property
    def body(self) -> Block:
        return self.regions[0].blocks[0]

    @property
    def bodyTarget(self) -> Value:
        return self.body.arguments[0]

    @property
    def bodyExtraArgs(self) -> BlockArgumentList:
        return self.body.arguments[1:]


@_ods_cext.register_operation(_Dialect, replace=True)
class YieldOp(YieldOp):
    def __init__(
        self,
        operands: Optional[Union[Operation, Sequence[Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        if operands is None:
            operands = []
        super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)


AnyOpTypeT = NewType("AnyOpType", AnyOpType)


def any_op_t() -> AnyOpTypeT:
    return AnyOpTypeT(AnyOpType.get())