# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._bufferization_transform_ops_gen import *
from .._bufferization_transform_ops_gen import _Dialect
try:
from ...ir import *
from ...dialects import transform
from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from enum import Enum
from typing import Optional, overload, Union
@_ods_cext.register_operation(_Dialect, replace=True)
class EmptyTensorToAllocTensorOp(EmptyTensorToAllocTensorOp):
"""Specialization for EmptyTensorToAllocTensorOp class."""
@overload
def __init__(
self,
transformed_type: Type,
target: Union[Operation, OpView, Value],
*,
loc=None,
ip=None,
):
...
@overload
def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
...
def __init__(
self,
transformed_type_or_target: Type,
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
loc=None,
ip=None,
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
target = target_or_none
else:
transformed_type = transform.OperationType.get("bufferization.alloc_tensor")
target = transformed_type_or_target
super().__init__(
transformed_type,
target,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class OneShotBufferizeOp(OneShotBufferizeOp):
"""Specialization for OneShotBufferizeOp class."""
@overload
def __init__(
self,
transformed_type: Type,
target: Union[Operation, OpView, Value],
*,
allow_return_allocs_from_loops: Optional[bool] = None,
allow_unknown_ops: Optional[bool] = None,
bufferize_function_boundaries: Optional[bool] = None,
function_boundary_type_conversion: Optional[Enum] = None,
memcpy_op: Optional[str] = None,
print_conflicts: Optional[bool] = None,
test_analysis_only: Optional[bool] = None,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
allow_return_allocs_from_loops: Optional[bool] = None,
allow_unknown_ops: Optional[bool] = None,
bufferize_function_boundaries: Optional[bool] = None,
function_boundary_type_conversion: Optional[Enum] = None,
memcpy_op: Optional[str] = None,
print_conflicts: Optional[bool] = None,
test_analysis_only: Optional[bool] = None,
loc=None,
ip=None,
):
...
def __init__(
self,
transformed_type_or_target: Type,
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
allow_return_allocs_from_loops: Optional[bool] = None,
allow_unknown_ops: Optional[bool] = None,
bufferize_function_boundaries: Optional[bool] = None,
function_boundary_type_conversion: Optional[Enum] = None,
memcpy_op: Optional[str] = None,
print_conflicts: Optional[bool] = None,
test_analysis_only: Optional[bool] = None,
loc=None,
ip=None,
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
target = target_or_none
else:
transformed_type = transform.AnyOpType.get()
target = transformed_type_or_target
super().__init__(
transformed_type,
target,
allow_return_allocs_from_loops=allow_return_allocs_from_loops,
allow_unknown_ops=allow_unknown_ops,
bufferize_function_boundaries=bufferize_function_boundaries,
function_boundary_type_conversion=function_boundary_type_conversion,
memcpy_op=memcpy_op,
print_conflicts=print_conflicts,
test_analysis_only=test_analysis_only,
loc=loc,
ip=ip,
)