# llvm/mlir/python/mlir/dialects/transform/structured.py

#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._structured_transform_ops_gen import *
from .._structured_transform_ops_gen import _Dialect
from .._structured_transform_enum_gen import *

try:
    from ...ir import *
    from ...dialects import transform
    from .._ods_common import (
        DynamicIndexList,
        IntOrAttrList,
        MixedValues,
        OptionalBoolList,
        OptionalIntList,
        _cext as _ods_cext,
        _dispatch_dynamic_index_list,
        _dispatch_mixed_values,
        _get_int_array_array_attr,
        _get_int_array_attr,
        _get_value_list,
        _get_value_or_attribute_value,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import List, Optional, Sequence, Union, overload


@_ods_cext.register_operation(_Dialect, replace=True)
class BufferizeToAllocationOp(BufferizeToAllocationOp):
    """Specialization for BufferizeToAllocationOp class."""

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        memory_space: Optional[Union[int, str, Attribute]] = None,
        memcpy_op: Optional[str] = None,
        alloc_op: Optional[str] = None,
        bufferize_destination_only: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        # No other result types are allowed, so hard-code them here.
        allocated_buffer_type = transform.AnyValueType.get()
        new_ops_type = transform.AnyOpType.get()

        if isinstance(memory_space, int):
            memory_space = str(memory_space)
        if isinstance(memory_space, str):
            memory_space = Attribute.parse(memory_space)

        super().__init__(
            allocated_buffer_type,
            new_ops_type,
            target,
            memory_space=memory_space,
            memcpy_op=memcpy_op,
            alloc_op=alloc_op,
            bufferize_destination_only=bufferize_destination_only,
            loc=loc,
            ip=ip,
        )
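

# Usage sketch: `target` is assumed to be a transform handle produced by an
# enclosing transform sequence; the `_example_` helper is illustrative only.
# An integer memory space is accepted directly and parsed into an Attribute
# by the constructor above.
def _example_bufferize_to_allocation(target):
    return BufferizeToAllocationOp(
        target,
        memory_space=3,  # coerced to "3" and parsed into an Attribute
        bufferize_destination_only=True,
    )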


@_ods_cext.register_operation(_Dialect, replace=True)
class DecomposeOp(DecomposeOp):
    """Specialization for DecomposeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        transformed_type = transform.AnyOpType.get()
        super().__init__(transformed_type, target, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class FuseIntoContainingOp(FuseIntoContainingOp):
    """Specialization for FuseIntoContainingOp class."""

    @overload
    def __init__(
        self,
        fused_op_type: Type,
        new_containing_op_type: Type,
        producer_op: Union[Operation, OpView, Value],
        containing_op: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        producer_op: Union[Operation, OpView, Value],
        containing_op: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
        new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
        producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
        containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        if isinstance(fused_op_type_or_producer_op, Type):
            if not isinstance(new_containing_op_type_or_containing_op, Type):
                raise TypeError(
                    "If 'fused_op_type_or_producer_op' is a type, then "
                    "'new_containing_op_type_or_containing_op' is expected "
                    "to be one as well."
                )
            fused_op_type = fused_op_type_or_producer_op
            new_containing_op_type = new_containing_op_type_or_containing_op
            producer_op = producer_op_or_none
            containing_op = containing_op_or_none
        else:
            fused_op_type = transform.AnyOpType.get()
            new_containing_op_type = transform.AnyOpType.get()
            producer_op = fused_op_type_or_producer_op
            containing_op = new_containing_op_type_or_containing_op

        super().__init__(
            fused_op_type,
            new_containing_op_type,
            producer_op,
            containing_op,
            loc=loc,
            ip=ip,
        )
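

# Usage sketch showing both overloads; `producer` and `containing` are
# assumed transform handles. When no result types are given, both results
# default to !transform.any_op.
def _example_fuse_into_containing(producer, containing):
    any_op = transform.AnyOpType.get()
    typed = FuseIntoContainingOp(any_op, any_op, producer, containing)
    untyped = FuseIntoContainingOp(producer, containing)
    return typed, untyped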


@_ods_cext.register_operation(_Dialect, replace=True)
class GeneralizeOp(GeneralizeOp):
    """Specialization for GeneralizeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        transformed_type = transform.AnyOpType.get()
        super().__init__(transformed_type, target, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class InterchangeOp(InterchangeOp):
    """Specialization for InterchangeOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        iterator_interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        transformed_type = transform.AnyOpType.get()
        super().__init__(
            transformed_type,
            target,
            iterator_interchange=iterator_interchange,
            loc=loc,
            ip=ip,
        )
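

# Usage sketch: permute the iterators of the payload op; `target` is an
# assumed transform handle and the result type is always !transform.any_op.
def _example_interchange(target):
    return InterchangeOp(target, iterator_interchange=[1, 0])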


@_ods_cext.register_operation(_Dialect, replace=True)
class MapCopyToThreadsOp(MapCopyToThreadsOp):
    """Specialization for MapCopyToThreadsOp class."""

    @overload
    def __init__(
        self,
        forall_op_type: Type,
        tiled_op_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        forall_op_type_or_target: Union[Operation, OpView, Type, Value],
        tiled_op_type_or_none: Optional[Type] = None,
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        if isinstance(forall_op_type_or_target, Type):
            forall_op_type = forall_op_type_or_target
            tiled_op_type = tiled_op_type_or_none
            target = target_or_none
        else:
            forall_op_type = transform.AnyOpType.get()
            tiled_op_type = transform.AnyOpType.get()
            target = forall_op_type_or_target

        super().__init__(
            forall_op_type,
            tiled_op_type,
            target,
            total_num_threads=total_num_threads,
            desired_bit_alignment=desired_bit_alignment,
            loc=loc,
            ip=ip,
        )
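

# Usage sketch: both keyword arguments are required; without the typed
# overload, the forall and tiled-op results default to !transform.any_op.
# `target` is an assumed transform handle to a copy-like payload op.
def _example_map_copy_to_threads(target):
    return MapCopyToThreadsOp(
        target, total_num_threads=64, desired_bit_alignment=128
    )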


@_ods_cext.register_operation(_Dialect, replace=True)
class VectorizeOp(VectorizeOp):
    """Specialization for VectorizeOp class."""

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        *,
        vectorize_nd_extract: Optional[bool] = None,
        scalable_sizes: OptionalBoolList = None,
        static_vector_sizes: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        if (
            scalable_sizes is None
            and static_vector_sizes is None
            and vector_sizes is None
        ):
            dynamic_vector_sizes = []
        elif scalable_sizes is None and static_vector_sizes is None:
            (
                dynamic_vector_sizes,
                static_vector_sizes,
                scalable_sizes,
            ) = _dispatch_dynamic_index_list(vector_sizes)
        elif scalable_sizes is None or static_vector_sizes is None:
            raise TypeError(
                "'scalable_sizes' and 'static_vector_sizes' must either both "
                "be given explicitly or both be given as part of 'vector_sizes'."
            )
        else:
            dynamic_vector_sizes = vector_sizes

        super().__init__(
            target,
            vector_sizes=dynamic_vector_sizes,
            static_vector_sizes=static_vector_sizes,
            scalable_sizes=scalable_sizes,
            vectorize_nd_extract=vectorize_nd_extract,
            loc=loc,
            ip=ip,
        )
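

# Usage sketch: static and dynamic vector sizes may be mixed in one list.
# `target` and `size_handle` are assumed transform handles; the dispatch
# above splits the list into static sizes and dynamic operands.
def _example_vectorize(target, size_handle):
    return VectorizeOp(target, vector_sizes=[8, size_handle, 16])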


@_ods_cext.register_operation(_Dialect, replace=True)
class MatchOp(MatchOp):
    """Specialization for MatchOp class."""

    @overload
    @classmethod
    def match_op_names(
        cls,
        target: Union[Operation, Value],
        names: Union[str, Sequence[str]],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    @classmethod
    def match_op_names(
        cls,
        result_type: Type,
        target: Union[Operation, Value],
        names: Union[str, Sequence[str]],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @classmethod
    def match_op_names(
        cls,
        result_type_or_target: Union[Type, Operation, Value],
        target_or_names: Union[Operation, Value, Sequence[str], str],
        names_or_none: Optional[Union[Sequence[str], str]] = None,
        *,
        loc=None,
        ip=None,
    ):
        if isinstance(result_type_or_target, Type):
            result_type = result_type_or_target
            target = target_or_names
            names = names_or_none
        else:
            result_type = transform.AnyOpType.get()
            target = result_type_or_target
            names = target_or_names

        if isinstance(names, str):
            names = [names]

        return cls(
            result_type,
            target,
            ops=ArrayAttr.get([StringAttr.get(name) for name in names]),
            loc=loc,
            ip=ip,
        )
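

# Usage sketch: a single op name or a sequence of names is accepted; a bare
# string is promoted to a one-element list. `target` is an assumed handle.
def _example_match_op_names(target):
    return MatchOp.match_op_names(target, ["linalg.matmul", "linalg.generic"])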


@_ods_cext.register_operation(_Dialect, replace=True)
class MultiTileSizesOp(MultiTileSizesOp):
    """Specialization for MultiTileSizesOp class."""

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        dimension: Union[int, IntegerAttr],
        target_size: Union[int, IntegerAttr],
        divisor: Optional[Union[int, IntegerAttr]] = None,
        loc=None,
        ip=None,
    ):
        super().__init__(
            result_type,
            result_type,
            result_type,
            target,
            dimension=dimension,
            target_size=target_size,
            divisor=divisor,
            loc=loc,
            ip=ip,
        )
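

# Usage sketch: one result type is replicated for all three results
# (low_size, high_size, split_point). `target` is an assumed handle.
def _example_multitile_sizes(target):
    return MultiTileSizesOp(
        transform.AnyOpType.get(), target, dimension=0, target_size=32
    )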


@_ods_cext.register_operation(_Dialect, replace=True)
class PadOp(PadOp):
    """Specialization for PadOp class."""

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        pad_to_multiple_of: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
        padding_dimensions: OptionalIntList = None,
        pack_paddings: OptionalIntList = None,
        transpose_paddings: Optional[
            Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
        ] = None,
        copy_back_op: Optional[Union[str, StringAttr]] = None,
        loc=None,
        ip=None,
    ):
        if pad_to_multiple_of is None:
            dynamic_pad_to_multiple_of = []
            static_pad_to_multiple_of = None
        else:
            (
                dynamic_pad_to_multiple_of,
                static_pad_to_multiple_of,
                _,
            ) = _dispatch_dynamic_index_list(pad_to_multiple_of)

        transpose_paddings = _get_int_array_array_attr(transpose_paddings)

        any_op_type = transform.AnyOpType.get()
        super().__init__(
            any_op_type,
            any_op_type,
            any_op_type,
            target,
            pad_to_multiple_of=dynamic_pad_to_multiple_of,
            padding_values=padding_values,
            padding_dimensions=padding_dimensions,
            static_pad_to_multiple_of=static_pad_to_multiple_of,
            pack_paddings=pack_paddings,
            transpose_paddings=transpose_paddings,
            copy_back_op=copy_back_op,
            loc=loc,
            ip=ip,
        )
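

# Usage sketch: `pad_to_multiple_of` takes a mixed static/dynamic index list
# (fully static here) and `transpose_paddings` is normalized into an array of
# integer arrays. `target` is an assumed transform handle.
def _example_pad(target):
    return PadOp(
        target,
        padding_dimensions=[0, 1],
        pad_to_multiple_of=[32, 32],
        transpose_paddings=[[1, 0]],
    )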


@_ods_cext.register_operation(_Dialect, replace=True)
class ScalarizeOp(ScalarizeOp):
    """Specialization for ScalarizeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        result_type = transform.AnyOpType.get()
        super().__init__(result_type, target, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class SplitOp(SplitOp):
    """Specialization for SplitOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        dimension: Union[int, Attribute],
        chunk_sizes: Union[int, Operation, Value, Attribute],
        *,
        loc=None,
        ip=None,
    ):
        if isinstance(chunk_sizes, int):
            static_chunk_sizes = chunk_sizes
            dynamic_chunk_sizes = None
        else:
            static_chunk_sizes = ShapedType.get_dynamic_size()
            dynamic_chunk_sizes = chunk_sizes

        super().__init__(
            target.type,
            target.type,
            target,
            dimension=dimension,
            static_chunk_sizes=static_chunk_sizes,
            dynamic_chunk_sizes=dynamic_chunk_sizes,
            loc=loc,
            ip=ip,
        )
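

# Usage sketch: an integer chunk size is recorded statically, while a handle
# becomes a dynamic operand with the static size set to the dynamic marker.
# `target` and `chunk_handle` are assumed transform handles.
def _example_split(target, chunk_handle):
    static = SplitOp(target, dimension=0, chunk_sizes=16)
    dynamic = SplitOp(target, dimension=0, chunk_sizes=chunk_handle)
    return static, dynamic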


@_ods_cext.register_operation(_Dialect, replace=True)
class TileUsingForOp(TileUsingForOp):
    """Specialization for TileUsingForOp class."""

    @overload
    def __init__(
        self,
        loop_types: Union[Type, List[Type]],
        target: Union[Operation, Value],
        *,
        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        loop_types_or_target: Union[Type, List[Type], Operation, Value],
        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
        *,
        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        (
            dynamic_sizes,
            static_sizes,
            scalable_sizes,
        ) = _dispatch_dynamic_index_list(sizes)

        # A zero static size means "do not tile this dimension", so only
        # non-zero entries produce loops.
        num_loops = sum(1 for size in static_sizes if size != 0)

        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
            loop_types = [transform.AnyOpType.get()] * num_loops
            target = loop_types_or_target
            assert (
                target_or_none is None
            ), "Cannot construct TileUsingForOp with two targets."
        else:
            loop_types = (
                ([loop_types_or_target] * num_loops)
                if isinstance(loop_types_or_target, Type)
                else loop_types_or_target
            )
            target = target_or_none

        super().__init__(
            target.type,
            loop_types,
            target,
            dynamic_sizes=dynamic_sizes,
            static_sizes=static_sizes,
            interchange=interchange,
            scalable_sizes=scalable_sizes,
            loc=loc,
            ip=ip,
        )
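

# Usage sketch: a zero size means "do not tile that dimension" and produces
# no loop, so this call yields two loop results. `target` is an assumed
# transform handle.
def _example_tile_using_for(target):
    return TileUsingForOp(target, sizes=[4, 0, 8])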


@_ods_cext.register_operation(_Dialect, replace=True)
class TileUsingForallOp(TileUsingForallOp):
    """Specialization for TileUsingForallOp class."""

    @overload
    def __init__(
        self,
        loops_type: Type,
        tiled_op_type: Type,
        target: Union[Operation, Value, OpView],
        *,
        num_threads: Optional[MixedValues] = None,
        tile_sizes: Optional[MixedValues] = None,
        mapping=None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        num_threads: Optional[MixedValues] = None,
        tile_sizes: Optional[MixedValues] = None,
        mapping=None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        loops_type_or_target: Union[Type, Operation, Value, OpView],
        tiled_op_type_or_none: Optional[Type] = None,
        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
        *,
        num_threads: Optional[MixedValues] = None,
        tile_sizes: Optional[MixedValues] = None,
        mapping=None,
        loc=None,
        ip=None,
    ):
        # The leading `Type` arguments are optional: fall back to default
        # types when they are omitted.
        if isinstance(loops_type_or_target, Type):
            # First overload: type arguments provided.
            if not isinstance(tiled_op_type_or_none, Type):
                raise TypeError(
                    "If 'loops_type_or_target' is a type, then "
                    "'tiled_op_type_or_none' is expected to be one as well."
                )
            loops_type = loops_type_or_target
            tiled_op_type = tiled_op_type_or_none
            target = target_or_none
        else:
            # Last overload: type arguments missing.
            loops_type = transform.AnyOpType.get()
            tiled_op_type = transform.AnyOpType.get()
            target = loops_type_or_target

        # Unpack mixed num_threads.
        (
            dynamic_num_threads,
            packed_num_threads,
            num_threads_attr,
        ) = _dispatch_mixed_values(num_threads)

        # Unpack mixed tile_sizes.
        (
            dynamic_tile_sizes,
            packed_tile_sizes,
            tile_sizes_attr,
        ) = _dispatch_mixed_values(tile_sizes)

        super().__init__(
            loops_type,
            tiled_op_type,
            target=target,
            tile_sizes=dynamic_tile_sizes,
            packed_tile_sizes=packed_tile_sizes,
            static_tile_sizes=tile_sizes_attr,
            num_threads=dynamic_num_threads,
            packed_num_threads=packed_num_threads,
            static_num_threads=num_threads_attr,
            mapping=mapping,
            loc=loc,
            ip=ip,
        )
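

# Usage sketch: typically exactly one of `num_threads` or `tile_sizes` is
# given; each is dispatched into dynamic values, packed handles, and a static
# attribute. `target` is an assumed transform handle.
def _example_tile_using_forall(target):
    return TileUsingForallOp(target, num_threads=[8, 8])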


@_ods_cext.register_operation(_Dialect, replace=True)
class VectorizeChildrenAndApplyPatternsOp(VectorizeChildrenAndApplyPatternsOp):
    """Specialization for VectorizeChildrenAndApplyPatternsOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        disable_multi_reduction_to_contract_patterns: bool = False,
        disable_transfer_permutation_map_lowering_patterns: bool = False,
        vectorize_nd_extract: bool = False,
        vectorize_padding: bool = False,
        loc=None,
        ip=None,
    ):
        transformed_type = transform.AnyOpType.get()
        super().__init__(
            transformed_type,
            target,
            disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
            disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
            vectorize_nd_extract=vectorize_nd_extract,
            vectorize_padding=vectorize_padding,
            loc=loc,
            ip=ip,
        )
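

# Usage sketch: the pattern toggles are plain booleans defaulting to False;
# `target` is an assumed transform handle to a function-like payload op.
def _example_vectorize_children(target):
    return VectorizeChildrenAndApplyPatternsOp(target, vectorize_padding=True)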