llvm/mlir/python/mlir/dialects/transform/gpu.py

#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._gpu_transform_ops_gen import *
from .._gpu_transform_ops_gen import _Dialect

try:
    from ...ir import *
    from ...dialects import transform
    from .._ods_common import _cext as _ods_cext
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union, overload


@_ods_cext.register_operation(_Dialect, replace=True)
class MapForallToBlocks(MapForallToBlocks):
    """Specialization for MapForallToBlocks class."""

    @overload
    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
        generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
        generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        result_type_or_target: Union[Operation, OpView, Type, Value],
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
        generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        if isinstance(result_type_or_target, Type):
            # First positional argument is the explicit result handle type.
            result_type = result_type_or_target
            target = target_or_none
        else:
            # Only the target was given; default the result handle type to
            # `!transform.any_op`.
            result_type = transform.AnyOpType.get()
            target = result_type_or_target

        super().__init__(
            result_type,
            target,
            grid_dims=grid_dims,
            generate_gpu_launch=generate_gpu_launch,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class MapNestedForallToThreads(MapNestedForallToThreads):
    """Specialization for MapNestedForallToThreads class."""

    @overload
    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        block_dims: Optional[Union[Sequence[int], Attribute]] = None,
        warp_size: Optional[Union[Sequence[int], Attribute]] = None,
        sync_after_distribute: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        block_dims: Optional[Union[Sequence[int], Attribute]] = None,
        warp_size: Optional[Union[Sequence[int], Attribute]] = None,
        sync_after_distribute: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        result_type_or_target: Union[Operation, OpView, Type, Value],
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        block_dims: Optional[Union[Sequence[int], Attribute]] = None,
        warp_size: Optional[Union[Sequence[int], Attribute]] = None,
        sync_after_distribute: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        if isinstance(result_type_or_target, Type):
            # Explicit result handle type given.
            result_type = result_type_or_target
            target = target_or_none
        else:
            target = result_type_or_target
            # Reuse the handle type of a `Value` target; `Operation` and
            # `OpView` targets carry no type, so fall back to
            # `!transform.any_op`.
            result_type = (
                target.type if isinstance(target, Value) else transform.AnyOpType.get()
            )
        super().__init__(
            result_type,
            target,
            block_dims=block_dims,
            warp_size=warp_size,
            sync_after_distribute=sync_after_distribute,
            loc=loc,
            ip=ip,
        )
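

# A combined usage sketch (hypothetical names: assumes `target` is the block
# argument of an enclosing `transform.sequence` pointing at a payload op that
# contains nested `scf.forall` loops): first map the outer `scf.forall` to
# blocks while generating a `gpu.launch`, then map the nested `scf.forall`
# ops to threads within each block.
#
#   launch = MapForallToBlocks(
#       target, grid_dims=[8, 8, 1], generate_gpu_launch=True
#   )
#   MapNestedForallToThreads(launch, block_dims=[32, 8, 1])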