chromium/third_party/wpt_tools/wpt/tools/wptrunner/wptrunner/environment.py

# mypy: allow-untyped-defs

import contextlib
import errno
import json
import os
import signal
import socket
import sys
import time
from typing import Optional

import mozprocess
from mozlog import get_default_logger, handlers
from mozlog.structuredlog import StructuredLogger

from . import mpcontext
from .wptlogging import LogLevelRewriter, QueueHandler, LogQueueThread

here = os.path.dirname(__file__)
repo_root = os.path.abspath(os.path.join(here, os.pardir, os.pardir, os.pardir))

sys.path.insert(0, repo_root)
from tools import localpaths  # noqa: F401

from wptserve.handlers import StringHandler

serve = None


def do_delayed_imports(logger, test_paths):
    """Import ``tools.serve.serve`` and bind it to the module-level ``serve``.

    The directory returned by ``serve_path(test_paths)`` is pushed to the
    front of ``sys.path`` before the import is attempted.  If the import
    fails, a critical message is logged through ``logger`` and the process
    exits with status 1.
    """
    global serve

    tests_root = serve_path(test_paths)
    sys.path.insert(0, tests_root)

    missing = []
    try:
        from tools.serve import serve
    except ImportError:
        missing.append("serve")

    # Nothing failed to import: the global is now populated.
    if not missing:
        return

    logger.critical(
        "Failed to import %s. Ensure that tests path %s contains web-platform-tests" %
        (", ".join(missing), tests_root))
    sys.exit(1)


def serve_path(test_paths):
    """Return the ``tests_path`` of the entry mounted at the root URL ("/")."""
    root_entry = test_paths["/"]
    return root_entry.tests_path


def webtranport_h3_server_is_running(host, port, timeout):
    """Delegate to ``server_is_running`` from the ``webtransport`` package.

    ``host``, ``port`` and ``timeout`` are forwarded positionally and the
    result is returned unchanged.
    """
    # TODO(bashi): Move the following import to the beginning of this file
    # once WebTransportH3Server is enabled by default.
    from webtransport.h3.webtransport_h3_server import (  # type: ignore
        server_is_running as h3_server_is_running,
    )
    running = h3_server_is_running(host, port, timeout)
    return running


class TestEnvironmentError(Exception):
    """Exception type for errors related to the test environment."""


def get_server_logger():
    """Return the default logger for the "wptserve" component.

    The logger's ``component_filter`` is set to a ``LogLevelFilter`` created
    with the "info" level, wrapped in a ``LogLevelRewriter`` that turns
    "error" messages into "warning".
    """
    server_logger = get_default_logger(component="wptserve")
    info_filter = handlers.LogLevelFilter(lambda x: x, "info")
    # Downgrade errors to warnings for the server
    server_logger.component_filter = LogLevelRewriter(info_filter, ["error"], "warning")
    return server_logger


class ProxyLoggingContext:
    """Context manager that sets up and tears down the queue used to carry
    log messages from wptserve.

    Entering starts the queue-draining thread and yields the queue handler;
    exiting posts a ``None`` sentinel to the queue and briefly waits for the
    thread.
    """

    def __init__(self, logger):
        self.log_queue = mpcontext.get_context().Queue()
        self.logging_thread = LogQueueThread(self.log_queue, logger)
        self.logger_handler = QueueHandler(self.log_queue)

    def __enter__(self):
        """Start the logging thread and return the queue-backed handler."""
        worker = self.logging_thread
        worker.start()
        return self.logger_handler

    def __exit__(self, *exc_info):
        """Post the ``None`` sentinel and wait up to one second for the thread."""
        self.log_queue.put(None)
        # Wait for thread to shut down but not for too long since it's a daemon
        worker = self.logging_thread
        worker.join(1)


class TestEnvironment:
    """Context manager that owns the test environment i.e. the http and
    websockets servers.

    Entering the context builds the server configuration, enters the stash
    server, the cache manager and every ``env_extras`` context manager, and
    then starts the servers via ``serve.start``.  Leaving the context asks
    every server to shut down, waits for them, and unwinds everything that
    was entered.
    """
    def __init__(self, test_paths, testharness_timeout_multipler,
                 pause_after_test, debug_test, debug_info, options, ssl_config, env_extras,
                 enable_webtransport=False, mojojs_path=None, inject_script=None,
                 suppress_handler_traceback=None):
        """Store the settings and create the (not yet entered) helper objects.

        ``options`` may be ``None``, in which case an empty dict is used.
        The "test_server_port" key is removed from ``options`` (default
        ``True``) and kept as ``self.test_server_port``.
        """

        self.test_paths = test_paths
        self.server = None
        self.config_ctx = None
        self.config = None
        self.server_logger = get_server_logger()
        self.server_logging_ctx = ProxyLoggingContext(self.server_logger)
        self.testharness_timeout_multipler = testharness_timeout_multipler
        self.pause_after_test = pause_after_test
        self.debug_test = debug_test
        self.debug_info = debug_info
        # Normalise ``options`` *before* reading from it; popping first would
        # raise AttributeError for ``options=None`` and make the fallback dead.
        self.options = options if options is not None else {}
        self.test_server_port = self.options.pop("test_server_port", True)

        mp_context = mpcontext.get_context()
        self._stack = contextlib.ExitStack()
        self.cache_manager = mp_context.Manager()
        self.stash = serve.stash.StashServer(mp_context=mp_context)
        self.env_extras = env_extras
        # ``None`` means "not currently entered"; see the assertion in __enter__.
        self.env_extras_cms = None
        self.ssl_config = ssl_config
        self.enable_webtransport = enable_webtransport
        self.mojojs_path = mojojs_path
        self.inject_script = inject_script
        self.suppress_handler_traceback = suppress_handler_traceback

    def __enter__(self):
        """Enter all helper contexts, start the servers and return ``self``."""
        server_log_handler = self._stack.enter_context(self.server_logging_ctx)
        self.config_ctx = self.build_config()

        self.config = self._stack.enter_context(self.config_ctx)

        self._stack.enter_context(self.stash)
        self._stack.enter_context(self.cache_manager)

        assert self.env_extras_cms is None, (
            "A TestEnvironment object cannot be nested")

        self.env_extras_cms = []

        for env in self.env_extras:
            cm = env(self.options, self.config)
            self._stack.enter_context(cm)
            self.env_extras_cms.append(cm)

        self.servers = serve.start(self.server_logger,
                                   self.config,
                                   self.get_routes(),
                                   mp_context=mpcontext.get_context(),
                                   log_handlers=[server_log_handler],
                                   webtransport_h3=self.enable_webtransport)

        if self.options.get("supports_debugger") and self.debug_info and self.debug_info.interactive:
            self._stack.enter_context(self.ignore_interrupts())
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Shut down every server, then unwind the entered contexts.

        The context stack is unwound even if shutting down or waiting for a
        server raises, so the helper contexts are not leaked; the original
        exception still propagates.
        """
        try:
            # Request shutdown of all servers first so they stop in parallel,
            # and only then block on each of them.
            for servers in self.servers.values():
                for _, server in servers:
                    server.request_shutdown()
            for servers in self.servers.values():
                for _, server in servers:
                    server.wait()
        finally:
            self._stack.__exit__(exc_type, exc_val, exc_tb)
            self.env_extras_cms = None

    @contextlib.contextmanager
    def ignore_interrupts(self):
        """Ignore SIGINT inside the ``with`` body, restoring the previous
        handler afterwards."""
        prev_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
        try:
            yield
        finally:
            signal.signal(signal.SIGINT, prev_handler)

    def build_config(self):
        """Create and return the ``serve.ConfigBuilder`` for this environment.

        Default ports are set first, then an optional ``config.json`` located
        in the root tests directory is applied on top, followed by the
        settings derived from ``self.options`` and the constructor arguments.
        """
        override_path = os.path.join(serve_path(self.test_paths), "config.json")

        config = serve.ConfigBuilder(self.server_logger)

        ports = {
            "http": [8000, 8001],
            "http-private": [8002],
            "http-public": [8003],
            "https": [8443, 8444],
            "https-private": [8445],
            "https-public": [8446],
            "ws": [8888],
            "wss": [8889],
            "h2": [9000],
            "webtransport-h3": [11000],
        }
        config.ports = ports

        if os.path.exists(override_path):
            with open(override_path) as f:
                override_obj = json.load(f)
            config.update(override_obj)

        config.check_subdomains = False

        # Copy so that the caller's ssl_config dict is not modified.
        ssl_config = self.ssl_config.copy()
        ssl_config["encrypt_after_connect"] = self.options.get("encrypt_after_connect", False)
        config.ssl = ssl_config

        if "browser_host" in self.options:
            config.browser_host = self.options["browser_host"]

        if "bind_address" in self.options:
            config.bind_address = self.options["bind_address"]

        config.server_host = self.options.get("server_host", None)
        config.doc_root = serve_path(self.test_paths)
        config.inject_script = self.inject_script

        if self.suppress_handler_traceback is not None:
            config.logging["suppress_handler_traceback"] = self.suppress_handler_traceback

        return config

    def get_routes(self):
        """Build and return the routes passed to ``serve.start``.

        Adds static routes for the runner pages, the pdf.js files and the
        testharness report script, a handler for ``/resources/testdriver.js``,
        one mount point per non-root entry of ``self.test_paths`` and, when
        ``self.mojojs_path`` is set, a ``/gen/`` mount point.
        """
        route_builder = serve.get_route_builder(
            self.server_logger,
            self.config.aliases,
            self.config)

        for path, format_args, content_type, route in [
                ("testharness_runner.html", {}, "text/html", "/testharness_runner.html"),
                ("print_pdf_runner.html", {}, "text/html", "/print_pdf_runner.html"),
                (os.path.join(here, "..", "..", "third_party", "pdf_js", "pdf.js"), None,
                 "text/javascript", "/_pdf_js/pdf.js"),
                (os.path.join(here, "..", "..", "third_party", "pdf_js", "pdf.worker.js"), None,
                 "text/javascript", "/_pdf_js/pdf.worker.js"),
                (self.options.get("testharnessreport", "testharnessreport.js"),
                 {"output": self.pause_after_test,
                  "timeout_multiplier": self.testharness_timeout_multipler,
                  "explicit_timeout": "true" if self.debug_info is not None else "false",
                  "debug": "true" if self.debug_test else "false"},
                 "text/javascript;charset=utf8",
                 "/resources/testharnessreport.js")]:
            # Relative paths are resolved against this module's directory.
            path = os.path.normpath(os.path.join(here, path))
            # Note that .headers. files don't apply to static routes, so we need to
            # readd any static headers here.
            headers = {"Cache-Control": "max-age=3600"}
            route_builder.add_static(path, format_args, content_type, route,
                                     headers=headers)

        route_builder.add_handler("GET", "/resources/testdriver.js", TestdriverLoader())

        for url_base, test_root in self.test_paths.items():
            if url_base == "/":
                continue
            route_builder.add_mount_point(url_base, test_root.tests_path)

        if "/" not in self.test_paths:
            del route_builder.mountpoint_routes["/"]

        if self.mojojs_path:
            route_builder.add_mount_point("/gen/", self.mojojs_path)

        return route_builder.get_routes()

    def ensure_started(self):
        """Block until all servers are ready, polling for up to 60 seconds.

        Raises:
            OSError: if any server is reported as failed, or if some servers
                are still pending when the time budget is exhausted.
        """
        # Pause for a while to ensure that the server has a chance to start
        total_sleep_secs = 60
        each_sleep_secs = 0.5
        end_time = time.time() + total_sleep_secs
        while time.time() < end_time:
            failed, pending = self.test_servers()
            if failed:
                break
            if not pending:
                return
            time.sleep(each_sleep_secs)
        if failed:
            failures = ", ".join(f"{scheme}:{port}" for scheme, port in failed)
            msg = f"Servers failed to start: {failures}"
        else:
            pending = ", ".join(f"{scheme}:{port}" for scheme, port in pending)
            msg = f"Timed out wait for servers to start: {pending}"
        raise OSError(msg)

    def test_servers(self):
        """Check the state of every started server.

        Returns:
            A ``(failed, pending)`` pair of lists of ``(scheme, port)`` tuples.
            ``failed`` holds servers whose ``is_alive()`` is false.
            ``pending`` holds servers that could not be reached yet; it is
            only populated when nothing failed and ``self.test_server_port``
            is truthy.
        """
        failed = []
        pending = []
        host = self.config["server_host"]
        for scheme, servers in self.servers.items():
            for port, server in servers:
                if not server.is_alive():
                    failed.append((scheme, port))

        if not failed and self.test_server_port:
            # The webtransport-h3 server test blocks (i.e., doesn't fail quickly
            # with "Connection refused" like the sockets do), so testing these
            # first improves the likelihood the non-webtransport-h3 servers are
            # ready by the time they're checked.
            for port, server in self.servers.get("webtransport-h3", []):
                if not webtranport_h3_server_is_running(host, port, timeout=5):
                    # Record (scheme, port) like every other entry so that
                    # ensure_started() reports "webtransport-h3:<port>".
                    pending.append(("webtransport-h3", port))

            for scheme, servers in self.servers.items():
                if scheme == "webtransport-h3":
                    continue
                for port, server in servers:
                    s = socket.socket()
                    s.settimeout(0.1)
                    try:
                        s.connect((host, port))
                    except OSError:
                        pending.append((scheme, port))
                    finally:
                        s.close()

        return failed, pending


class TestdriverLoader:
    """Handler that serves `/resources/testdriver.js` from lazily read files.

    The contents of `testdriver.js` and `testdriver-extra.js` are read on the
    first request rather than at construction time, so that wptrunner doesn't
    need to pass the entire file contents to child `wptserve` processes, which
    can slow `wptserve` startup by several seconds (crbug.com/1479850).
    """
    def __init__(self):
        # Built on first call; see __call__.
        self._handler = None

    def __call__(self, request, response):
        """Serve the request, building the underlying handler on first use."""
        if not self._handler:
            sources = (
                os.path.join(repo_root, "resources", "testdriver.js"),
                os.path.join(here, "testdriver-extra.js"),
            )
            chunks = []
            for source in sources:
                with open(source, "rb") as fp:
                    chunks.append(fp.read())
            # The response body is the two files concatenated in order.
            self._handler = StringHandler(b"".join(chunks), "text/javascript")
        return self._handler(request, response)


def wait_for_service(logger: StructuredLogger,
                     host: str,
                     port: int,
                     timeout: float = 60,
                     server_process: Optional[mozprocess.ProcessHandler] = None) -> bool:
    """Poll ``host:port`` until a TCP connection succeeds.

    Returns ``True`` as soon as a connection is established.  Connection
    attempts are retried every half second while they time out or are
    refused.

    Raises:
        OSError: when ``timeout`` seconds elapse without a successful
            connection, when ``server_process`` is given and has exited, or
            when connecting fails with an errno other than ``ECONNREFUSED``.
    """
    logger.debug(f"Trying to connect to {host}:{port}")
    deadline = time.time() + timeout
    while time.time() < deadline:
        if server_process is not None and server_process.poll() is not None:
            returncode = server_process.poll()
            logger.debug(
                f"Server process {server_process.pid} exited with "
                f"{returncode}, giving up trying to connect"
            )
            break

        # The socket is closed on every path out of the ``with`` block.
        with socket.socket() as conn:
            try:
                conn.connect((host, port))
            except socket.timeout:
                pass
            except OSError as err:
                # Only "connection refused" means "not up yet, try again".
                if err.errno != errno.ECONNREFUSED:
                    raise
            else:
                logger.debug(f"Connected to {host}:{port}")
                return True
        time.sleep(0.5)
    raise OSError("Service is unavailable: %s:%i" % (host, port))