llvm/mlir/utils/jupyter/mlir_opt_kernel/kernel.py

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from subprocess import Popen
import os
import subprocess
import tempfile
import traceback
from ipykernel.kernelbase import Kernel

__version__ = "0.0.1"


def _get_executable():
    """Find the mlir-opt executable."""

    def is_exe(fpath):
        """Returns whether executable file."""
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    program = os.environ.get("MLIR_OPT_EXECUTABLE", "mlir-opt")
    path, name = os.path.split(program)
    # Attempt to get the executable
    if path:
        if is_exe(program):
            return program
    else:
        for path in os.environ["PATH"].split(os.pathsep):
            file = os.path.join(path, name)
            if is_exe(file):
                return file
    raise OSError("mlir-opt not found, please see README")


class MlirOptKernel(Kernel):
    """Kernel using mlir-opt inside jupyter.

    The reproducer syntax (`// configuration:`) is used to run passes. The
    previous result can be referenced to by using `_` (this variable is reset
    upon error). E.g.,

    ```mlir
    // configuration: --pass
    func.func @foo(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> { ... }
    ```

    ```mlir
    // configuration: --next-pass
    _
    ```
    """

    implementation = "mlir"
    implementation_version = __version__

    language_version = __version__
    language = "mlir"
    language_info = {
        "name": "mlir",
        "codemirror_mode": {"name": "mlir"},
        "mimetype": "text/x-mlir",
        "file_extension": ".mlir",
        "pygments_lexer": "text",
    }

    @property
    def banner(self):
        """Returns kernel banner."""
        # Just a placeholder.
        return "mlir-opt kernel %s" % __version__

    def __init__(self, **kwargs):
        Kernel.__init__(self, **kwargs)
        self._ = None
        self.executable = None
        self.silent = False

    def get_executable(self):
        """Returns the mlir-opt executable path."""
        if not self.executable:
            self.executable = _get_executable()
        return self.executable

    def process_output(self, output):
        """Reports regular command output."""
        if not self.silent:
            # Send standard output
            stream_content = {"name": "stdout", "text": output}
            self.send_response(self.iopub_socket, "stream", stream_content)

    def process_error(self, output):
        """Reports error response."""
        if not self.silent:
            # Send standard error
            stream_content = {"name": "stderr", "text": output}
            self.send_response(self.iopub_socket, "stream", stream_content)

    def do_execute(
        self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
    ):
        """Execute user code using mlir-opt binary."""

        def ok_status():
            """Returns OK status."""
            return {
                "status": "ok",
                "execution_count": self.execution_count,
                "payload": [],
                "user_expressions": {},
            }

        def run(code):
            """Run the code by pipeing via filesystem."""
            try:
                inputmlir = tempfile.NamedTemporaryFile(delete=False)
                command = [
                    # Specify input and output file to error out if also
                    # set as arg.
                    self.get_executable(),
                    "--color",
                    inputmlir.name,
                    "-o",
                    "-",
                ]
                # Simple handling of repeating last line.
                if code.endswith("\n_"):
                    if not self._:
                        raise NameError("No previous result set")
                    code = code[:-1] + self._
                inputmlir.write(code.encode("utf-8"))
                inputmlir.close()
                pipe = Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                output, errors = pipe.communicate()
                exitcode = pipe.returncode
            finally:
                os.unlink(inputmlir.name)

            # Replace temporary filename with placeholder. This takes the very
            # remote chance where the full input filename (generated above)
            # overlaps with something in the dump unrelated to the file.
            fname = inputmlir.name.encode("utf-8")
            output = output.replace(fname, b"<<input>>")
            errors = errors.replace(fname, b"<<input>>")
            return output, errors, exitcode

        self.silent = silent
        if not code.strip():
            return ok_status()

        try:
            output, errors, exitcode = run(code)

            if exitcode:
                self._ = None
            else:
                self._ = output.decode("utf-8")
        except KeyboardInterrupt:
            return {"status": "abort", "execution_count": self.execution_count}
        except Exception as error:
            # Print traceback for local debugging.
            traceback.print_exc()
            self._ = None
            exitcode = 255
            errors = repr(error).encode("utf-8")

        if exitcode:
            content = {"ename": "", "evalue": str(exitcode), "traceback": []}

            self.send_response(self.iopub_socket, "error", content)
            self.process_error(errors.decode("utf-8"))

            content["execution_count"] = self.execution_count
            content["status"] = "error"
            return content

        if not silent:
            data = {}
            data["text/x-mlir"] = self._
            content = {
                "execution_count": self.execution_count,
                "data": data,
                "metadata": {},
            }
            self.send_response(self.iopub_socket, "execute_result", content)
            self.process_output(self._)
            self.process_error(errors.decode("utf-8"))
        return ok_status()