llvm/mlir/python/mlir/dialects/transform/interpreter/__init__.py

#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ....ir import Operation
from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter

# Re-export of the options class provided by the native
# `_mlirTransformInterpreter` extension module. Presumably this is the type
# expected for the `transform_options` argument of `apply_named_sequence`
# below — NOTE(review): confirm against the C++ binding.
TransformOptions = _cextTransformInterpreter.TransformOptions


def _unpack_operation(op):
    """Return the `Operation` underlying `op`.

    An `Operation` instance is returned as-is. Any other object is returned
    through its `.operation` attribute (presumably an op-view style wrapper —
    confirm against callers).
    """
    return op if isinstance(op, Operation) else op.operation


def apply_named_sequence(
    payload_root, transform_root, transform_module, transform_options=None
):
    """Run the transform script rooted at `transform_root` on `payload_root`.

    `transform_module` is the module that contains `transform_root`. Each of
    the three operation arguments may be an `Operation` or an object that
    exposes one via its `.operation` attribute. `transform_options`, when not
    None, is forwarded to the native interpreter as a fourth argument.
    The transform operation must implement TransformOpInterface and the module
    must be a ModuleOp.
    """
    call_args = [
        _unpack_operation(item)
        for item in (payload_root, transform_root, transform_module)
    ]
    # Only forward options the caller actually supplied; otherwise the native
    # binding is invoked with just the three operations.
    if transform_options is not None:
        call_args.append(transform_options)
    _cextTransformInterpreter.apply_named_sequence(*call_args)


def copy_symbols_and_merge_into(target, other):
    """Merge the symbols of `other` into `target`.

    Both arguments may be an `Operation` or an object exposing one via
    `.operation`. Private symbols are renamed to avoid duplicates; an error is
    raised if copying would lead to duplicate public symbols.
    """
    target_op = _unpack_operation(target)
    source_op = _unpack_operation(other)
    _cextTransformInterpreter.copy_symbols_and_merge_into(target_op, source_op)