llvm/mlir/utils/mbr/mbr/main.py

"""This file contains the main function that's called by the CLI of the library.
"""

import os
import sys
import time

import numpy as np

from discovery import discover_benchmark_modules, get_benchmark_functions
from stats import has_enough_measurements


def main(top_level_path, stop_on_error):
    """Top level function called when the CLI is invoked."""
    if "::" in top_level_path:
        if top_level_path.count("::") > 1:
            raise AssertionError(f"Invalid path {top_level_path}")
        top_level_path, benchmark_function_name = top_level_path.split("::")
    else:
        benchmark_function_name = None

    if not os.path.exists(top_level_path):
        raise AssertionError(f"The top-level path {top_level_path} doesn't exist")

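    # Discover every benchmark module under the top-level path, then collect
    # and run the benchmark functions defined in each one.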
    modules = list(discover_benchmark_modules(top_level_path))
    benchmark_dicts = []
    for module in modules:
        benchmark_functions = list(
            get_benchmark_functions(module, benchmark_function_name)
        )
        for benchmark_function in benchmark_functions:
            try:
                compiler, runner = benchmark_function()
            except (TypeError, ValueError) as e:
                error_message = (
                    f"Obtaining compiler and runner failed because of {e}."
                    f" Benchmark function '{benchmark_function.__name__}'"
                    f" must return a two-tuple value (compiler, runner)."
                )
                if stop_on_error is False:
                    print(error_message, file=sys.stderr)
                    continue
                else:
                    raise AssertionError(error_message) from e
            measurements_ns = np.array([])
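            # When the benchmark provides a compiler, time the compilation
            # separately and hand the compiled callable to the runner;
            # otherwise the runner is called with no arguments.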
            if compiler:
                start_compile_time_s = time.time()
                try:
                    compiled_callable = compiler()
                except Exception as e:
                    error_message = (
                        f"Compilation of {benchmark_function.__name__} failed"
                        f" because of {e}"
                    )
                    if stop_on_error is False:
                        print(error_message, file=sys.stderr)
                        continue
                    else:
                        raise AssertionError(error_message) from e
                total_compile_time_s = time.time() - start_compile_time_s
                runner_args = (compiled_callable,)
            else:
                total_compile_time_s = 0
                runner_args = ()
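            # Invoke the runner repeatedly until the stats module reports that
            # enough samples have been collected; each successful call returns
            # the measured execution time in nanoseconds.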
            while not has_enough_measurements(measurements_ns):
                try:
                    measurement_ns = runner(*runner_args)
                except Exception as e:
                    error_message = (
                        f"Runner of {benchmark_function.__name__} failed"
                        f" because of {e}"
                    )
                    if stop_on_error is False:
                        print(error_message, file=sys.stderr)
                        # Recover from runner error by breaking out of this loop
                        # and continuing forward.
                        break
                    else:
                        raise AssertionError(error_message) from e
                if not isinstance(measurement_ns, int):
                    error_message = (
                        f"Expected benchmark runner function"
                        f" to return an int, got {measurement_ns!r}"
                    )
                    if stop_on_error is False:
                        print(error_message, file=sys.stderr)
                        # A runner that returns a non-int value is unlikely to
                        # start returning ints, so break out of the measurement
                        # loop instead of retrying forever.
                        break
                    else:
                        raise AssertionError(error_message)
                measurements_ns = np.append(measurements_ns, measurement_ns)

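            # Record the benchmark only if at least one measurement succeeded,
            # converting nanoseconds to seconds and naming the result
            # "<module name>:<function name>".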
            if len(measurements_ns) > 0:
                measurements_s = [t * 1e-9 for t in measurements_ns]
                benchmark_identifier = ":".join(
                    [module.__name__, benchmark_function.__name__]
                )
                benchmark_dicts.append(
                    {
                        "name": benchmark_identifier,
                        "compile_time": total_compile_time_s,
                        "execution_time": list(measurements_s),
                    }
                )

    return benchmark_dicts
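
# Illustrative usage (the path and numbers are made up; only the dict keys and
# the "module:function" name format come from the code above):
#
#     results = main("path/to/benchmarks", stop_on_error=False)
#     # results == [{"name": "some_module:benchmark_something",
#     #              "compile_time": 0.25,
#     #              "execution_time": [0.0012, 0.0011, ...]}]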