llvm/mlir/python/mlir/_mlir_libs/__init__.py

# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any, Sequence

import os

# Directory containing this package's __init__.py. The library and include
# directories reported by get_lib_dirs/get_include_dirs are derived from it.
_this_dir = os.path.split(__file__)[0]


def get_lib_dirs() -> Sequence[str]:
    """Return the directories to search when linking against shared libraries.

    The result holds a single entry: the directory of this package. Whether
    linkable development libraries are actually exported there depends on the
    platform; on some platforms the package must be built specially for that.
    """
    lib_dirs = [_this_dir]
    return lib_dirs


def get_include_dirs() -> Sequence[str]:
    """Return the directories to search for headers of the exported C libraries.

    The result holds a single entry: the ``include`` subdirectory of this
    package. Depending on how the package was built, the development C
    libraries may or may not be present.
    """
    include_dir = os.path.join(_this_dir, "include")
    return [include_dir]


# Perform Python level site initialization. This involves:
#   1. Attempting to load initializer modules, specific to the distribution.
#   2. Defining the concrete mlir.ir.Context that does site specific
#      initialization.
#
# Aside from just being far more convenient to do this at the Python level,
# it is actually quite hard/impossible to have such __init__ hooks, given
# the pybind memory model (i.e. there is not a Python reference to the object
# in the scope of the base class __init__).
#
# For #1, we:
#   a. Probe for modules named '_mlirRegisterEverything' and
#     '_site_initialize_{i}', where 'i' is a number starting at zero and
#     proceeding so long as a module with the name is found.
#   b. If the module has a 'register_dialects' attribute, it will be called
#     immediately with a DialectRegistry to populate.
#   c. If the module has a 'context_init_hook', it will be added to a list
#     of callbacks that are invoked as the last step of Context
#     initialization (and passed the Context under construction).
#   d. If the module has a 'disable_multithreading' attribute, it will be
#     taken as a boolean. If it is True for any initializer, then the
#     default behavior of enabling multithreading on the context
#     will be suppressed. This preserves the original behavior of all
#     contexts being created with multithreading enabled while allowing
#     this behavior to be changed if needed (e.g. if a context_init_hook
#     explicitly sets up multithreading).
#
# This facility allows downstreams to customize Context creation to their
# needs.

# Lazily created DialectRegistry. Initializer modules register their dialects
# into it, and every Context built by this package appends it.
_dialect_registry = None


def get_dialect_registry():
    """Return the module-level ``ir.DialectRegistry``, creating it on first use.

    The same registry object is returned on every call, so dialects
    registered into it are seen by all later callers.
    """
    global _dialect_registry

    # Fast path: the registry has already been created.
    if _dialect_registry is not None:
        return _dialect_registry

    # Deferred import: only pull in the native module when actually needed.
    from ._mlir import ir

    _dialect_registry = ir.DialectRegistry()
    return _dialect_registry


def _site_initialize():
    """Perform Python-level site initialization of the MLIR bindings.

    Imports the optional initializer modules ``_mlirRegisterEverything`` and
    ``_site_initialize_0``, ``_site_initialize_1``, ... (stopping at the first
    index that cannot be imported). It applies what each one provides, then
    installs two names on the native ``ir`` module:

    * ``ir.Context``: a subclass of ``ir._BaseContext``. Its constructor
      appends the shared dialect registry and runs every collected
      ``context_init_hook``. It then enables multithreading unless an
      initializer opted out, and loads all available dialects. When
      ``_mlirRegisterEverything`` was found, it also registers that module's
      LLVM translations.
    * ``ir.MLIRError``: an exception carrying diagnostic information.

    Called once, at import time of this package (see the end of this block).
    """
    import importlib
    import itertools
    import logging
    from ._mlir import ir

    logger = logging.getLogger(__name__)
    post_init_hooks = []
    disable_multithreading = False

    def process_initializer_module(module_name):
        """Import one initializer module and apply its contributions.

        Returns the imported module, or None if it could not be imported. An
        initializer that simply does not exist is skipped quietly. Any other
        import failure is logged as a warning. That includes a missing
        dependency of an initializer that does exist.
        """
        nonlocal disable_multithreading
        try:
            m = importlib.import_module(f".{module_name}", __name__)
        except ImportError as e:
            # A ModuleNotFoundError naming the initializer itself only means
            # it was not built, so skip it silently. A ModuleNotFoundError
            # naming any other module was raised from inside an initializer
            # that does exist, which is a real error and must be reported.
            if (
                isinstance(e, ModuleNotFoundError)
                and e.name == f"{__name__}.{module_name}"
            ):
                return None
            message = (
                f"Error importing mlir initializer {module_name}. This may "
                "happen in unclean incremental builds but is likely a real bug if "
                "encountered otherwise and the MLIR Python API may not function."
            )
            logger.warning(message, exc_info=True)
            return None

        logger.debug("Initializing MLIR with module: %s", module_name)
        if hasattr(m, "register_dialects"):
            logger.debug("Registering dialects from initializer %r", m)
            m.register_dialects(get_dialect_registry())
        if hasattr(m, "context_init_hook"):
            logger.debug("Adding context init hook from %r", m)
            post_init_hooks.append(m.context_init_hook)
        # The opt-out is sticky: one truthy initializer disables it for all.
        if getattr(m, "disable_multithreading", False):
            logger.debug("Disabling multi-threading for context")
            disable_multithreading = True
        return m

    # If _mlirRegisterEverything is built, include it as an initializer
    # module. Keep the returned handle (no second import needed), because
    # every new Context also calls its register_llvm_translations.
    init_module = process_initializer_module("_mlirRegisterEverything")

    # Load all _site_initialize_{i} modules, where 'i' is a number starting
    # at 0, stopping at the first one that is absent or fails to import.
    for i in itertools.count():
        if process_initializer_module(f"_site_initialize_{i}") is None:
            break

    class Context(ir._BaseContext):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.append_dialect_registry(get_dialect_registry())
            for hook in post_init_hooks:
                hook(self)
            if not disable_multithreading:
                self.enable_multithreading(True)
            # TODO: There is some debate about whether we should eagerly load
            # all dialects. It is being done here in order to preserve existing
            # behavior. See: https://github.com/llvm/llvm-project/issues/56037
            self.load_all_available_dialects()
            if init_module is not None:
                logger.debug(
                    "Registering translations from initializer %r", init_module
                )
                init_module.register_llvm_translations(self)

    ir.Context = Context

    class MLIRError(Exception):
        """
        An exception with diagnostic information. Has the following fields:
          message: str
          error_diagnostics: List[ir.DiagnosticInfo]
        """

        def __init__(self, message, error_diagnostics):
            self.message = message
            self.error_diagnostics = error_diagnostics
            super().__init__(message, error_diagnostics)

        def __str__(self):
            s = self.message
            if self.error_diagnostics:
                s += ":"
            for diag in self.error_diagnostics:
                # [4:-1] drops the first four characters and the last one of
                # the printed location (presumably the "loc(" ... ")" wrapper).
                s += (
                    "\nerror: "
                    + str(diag.location)[4:-1]
                    + ": "
                    + diag.message.replace("\n", "\n  ")
                )
                for note in diag.notes:
                    s += (
                        "\n note: "
                        + str(note.location)[4:-1]
                        + ": "
                        + note.message.replace("\n", "\n  ")
                    )
            return s

    ir.MLIRError = MLIRError


# Run site initialization once, when this package is imported.
_site_initialize()