# llvm/llvm/utils/update_any_test_checks.py

#!/usr/bin/env python3

"""Dispatch to update_*_test_checks.py scripts automatically in bulk

Given a list of test files, this script will invoke the correct
update_test_checks-style script, skipping any tests which have not previously
had assertions autogenerated. If a test name starts with '@', it is treated as
the name of a file containing a list of tests, one per line.
"""

from __future__ import print_function

import argparse
import os
import re
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor

# Matches the "autogenerated" note searched for in the first line of each test.
# Group 1 is the name of the script recorded in the note; the optional group 2
# is the trailing " UTC_ARGS:..." text, if present.
RE_ASSERTIONS = re.compile(r"NOTE: Assertions have been autogenerated by ([^\s]+)( UTC_ARGS:.*)?$")


def find_utc_tool(search_path, utc_name):
    """
    Locate the UTC tool `utc_name` in the directories of `search_path`.

    Directories are tried in order; the first joined path that is an existing
    regular file is returned. Returns None when no directory contains it.
    """
    candidates = (os.path.join(directory, utc_name) for directory in search_path)
    # The generators are lazy, so probing stops at the first hit.
    return next((path for path in candidates if os.path.isfile(path)), None)


def run_utc_tool(utc_name, utc_tool, testname):
    """
    Run `utc_tool` with `testname` as its only argument, capturing its output.

    `utc_name` is accepted but not used by this function. Returns a tuple of
    (return code, captured stdout bytes, captured stderr bytes).
    """
    command = [utc_tool, testname]
    completed = subprocess.run(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    return completed.returncode, completed.stdout, completed.stderr


def read_arguments_from_file(filename):
    """
    Read a list of arguments from `filename`, one per line.

    Trailing whitespace (including the newline) is removed from every line.
    Lines that are empty after stripping are skipped, so blank lines in the
    list file do not turn into empty test names.

    If the file does not exist, an error is printed to stderr and the script
    exits with status 1.
    """
    try:
        with open(filename, "r") as file:
            stripped = (line.rstrip() for line in file)
            # Materialize the list while the file is still open.
            return [line for line in stripped if line]
    except FileNotFoundError:
        print(f"Error: File '{filename}' not found.", file=sys.stderr)
        sys.exit(1)


def expand_listfile_args(arg_list):
    """
    Expand '@file' entries of `arg_list` into the arguments listed in the file.

    An argument starting with '@' is replaced, in place, by the entries that
    read_arguments_from_file() returns for the rest of the argument; every
    other argument is passed through unchanged. Returns a new list.
    """
    expanded = []
    for item in arg_list:
        if not item.startswith("@"):
            expanded.append(item)
            continue
        expanded.extend(read_arguments_from_file(item[1:]))
    return expanded


def _read_header(testname):
    """Return the first line of the test file, stripped of surrounding whitespace."""
    with open(testname, "r") as f:
        return f.readline().strip()


def _print_captured(raw):
    """Decode captured tool output and print it, ending it with a newline."""
    text = raw.decode(errors="replace")
    if not text:
        return
    print(text, end="")
    if not text.endswith("\n"):
        print()


def main():
    """
    Parse the command line and run the matching UTC script on every test.

    For each test, the first line is searched for the "autogenerated" note;
    tests without it are collected and listed at the end. Tests with the note
    are dispatched to the named script, located via --utc-dir directories
    followed by the parent of this script's directory. Jobs run on a thread
    pool sized by --jobs and their output is printed in submission order.

    Exits with status 1 if any UTC script could not be found or any invoked
    script returned a non-zero exit code.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "--jobs",
        "-j",
        default=1,
        type=int,
        help="Run the given number of jobs in parallel",
    )
    parser.add_argument(
        "--utc-dir",
        nargs="*",
        help="Additional directories to scan for update_*_test_checks scripts",
    )
    parser.add_argument("tests", nargs="+")
    config = parser.parse_args()

    # User-supplied directories take precedence over the default location.
    utc_search_path = list(config.utc_dir) if config.utc_dir else []
    script_name = os.path.abspath(__file__)
    utc_search_path.append(os.path.join(os.path.dirname(script_name), os.path.pardir))

    not_autogenerated = []
    # Cache of tool name -> resolved path, or None when the tool was not found.
    utc_tools = {}
    have_error = False

    tests = expand_listfile_args(config.tests)

    with ThreadPoolExecutor(max_workers=config.jobs) as executor:
        jobs = []

        for testname in tests:
            # The file is closed again before the job is submitted, so the
            # tool never runs while this script still holds the test open.
            m = RE_ASSERTIONS.search(_read_header(testname))
            if m is None:
                not_autogenerated.append(testname)
                continue

            utc_name = m.group(1)
            if utc_name not in utc_tools:
                utc_tools[utc_name] = find_utc_tool(utc_search_path, utc_name)
                if not utc_tools[utc_name]:
                    # Reported once, on first lookup of the missing tool.
                    print(
                        f"{utc_name}: not found (used in {testname})",
                        file=sys.stderr,
                    )
                    have_error = True

            utc_tool = utc_tools[utc_name]
            if not utc_tool:
                # Also skip later tests that need a tool already known to be
                # missing; submitting them would run a command of None.
                continue

            future = executor.submit(run_utc_tool, utc_name, utc_tool, testname)
            jobs.append((testname, future))

        for testname, future in jobs:
            return_code, stdout, stderr = future.result()

            print(f"Update {testname}")
            _print_captured(stdout)
            # The tool's stderr is echoed on this script's stdout as well.
            _print_captured(stderr)
            if return_code != 0:
                print(f"Return code: {return_code}")
                have_error = True

    if have_error:
        sys.exit(1)

    if not_autogenerated:
        print("Tests without autogenerated assertions:")
        for testname in not_autogenerated:
            print(f"  {testname}")


if __name__ == "__main__":
    main()