#  llvm/mlir/python/mlir/dialects/transform/memref.py

#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._memref_transform_ops_gen import *
from .._memref_transform_ops_gen import _Dialect

try:
    from ...ir import *
    from ...dialects import transform
    from .._ods_common import _cext as _ods_cext
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, overload, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class MemRefAllocaToGlobalOp(MemRefAllocaToGlobalOp):
    """Extension of the generated MemRefAllocaToGlobalOp with a short constructor.

    Two positional call shapes are accepted:

    * ``MemRefAllocaToGlobalOp(get_global_type, global_type, alloca)``
    * ``MemRefAllocaToGlobalOp(alloca)`` -- both types are filled in with
      ``transform.AnyOpType.get()``.

    The form is selected by checking whether the first positional argument is
    a ``Type``.
    """

    @overload
    def __init__(
        self,
        get_global_type: Type,
        global_type: Type,
        alloca: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
        ...

    def __init__(
        self,
        get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
        global_type_or_none: Optional[Type] = None,
        alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Resolve the overload in use and delegate to the generated builder.

        Args:
            get_global_type_or_alloca: Either the first type (explicit form)
                or the alloca operand (short form).
            global_type_or_none: Second type in the explicit form; not read
                in the short form.
            alloca_or_none: The alloca operand in the explicit form; not read
                in the short form.
            loc: Passed through unchanged to the generated constructor.
            ip: Passed through unchanged to the generated constructor.
        """
        leading = get_global_type_or_alloca
        if isinstance(leading, Type):
            # Explicit form: the caller supplied every positional argument.
            builder_args = (leading, global_type_or_none, alloca_or_none)
        else:
            # Short form: only the alloca was given, so both types fall back
            # to the generic transform "any op" type.
            builder_args = (
                transform.AnyOpType.get(),
                transform.AnyOpType.get(),
                leading,
            )

        super().__init__(*builder_args, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class MemRefMultiBufferOp(MemRefMultiBufferOp):
    """Specialization for MemRefMultiBufferOp class.

    Adds a shorter constructor form on top of the generated op class:

    * ``MemRefMultiBufferOp(transformed_type, target, factor)``
    * ``MemRefMultiBufferOp(target, factor)`` -- ``transformed_type`` defaults
      to ``transform.AnyOpType.get()``.

    The form is selected by checking whether the first positional argument is
    a ``Type``.
    """

    @overload
    def __init__(
        self,
        transformed_type: Type,
        target: Union[Operation, OpView, Value],
        factor: Union[int, IntegerAttr],
        *,
        skip_analysis: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        factor: Union[int, IntegerAttr],
        *,
        skip_analysis: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        # May be the type (explicit form) or the target (short form), so the
        # annotation has to cover both, not just ``Type``.
        transformed_type_or_target: Union[Operation, OpView, Type, Value],
        target_or_factor: Optional[
            Union[int, IntegerAttr, Operation, OpView, Value]
        ] = None,
        factor_or_none: Optional[Union[int, IntegerAttr]] = None,
        *,
        skip_analysis: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        """Resolve the overload in use and delegate to the generated builder.

        Args:
            transformed_type_or_target: Either the type passed first to the
                generated constructor (explicit form) or the target (short
                form).
            target_or_factor: The target in the explicit form, the factor in
                the short form.
            factor_or_none: The factor in the explicit form; not read in the
                short form.
            skip_analysis: Passed through unchanged to the generated
                constructor.
            loc: Passed through unchanged to the generated constructor.
            ip: Passed through unchanged to the generated constructor.
        """
        if isinstance(transformed_type_or_target, Type):
            # Explicit form: (transformed_type, target, factor).
            transformed_type = transformed_type_or_target
            target = target_or_factor
            factor = factor_or_none
        else:
            # Short form: (target, factor); the type falls back to the
            # generic transform "any op" type.
            transformed_type = transform.AnyOpType.get()
            target = transformed_type_or_target
            factor = target_or_factor

        super().__init__(
            transformed_type,
            target,
            factor,
            skip_analysis=skip_analysis,
            loc=loc,
            ip=ip,
        )