chromium/third_party/wpt_tools/wpt/tools/wptrunner/wptrunner/executors/base.py

# mypy: allow-untyped-defs

import base64
import hashlib
import io
import json
import os
import threading
import traceback
import socket
import sys
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, ClassVar, Tuple, Type
from urllib.parse import urljoin, urlsplit, urlunsplit

from . import pytestrunner
from .actions import actions
from .asyncactions import async_actions
from .protocol import Protocol, WdspecProtocol


here = os.path.dirname(__file__)


def executor_kwargs(test_type, test_environment, run_info_data, subsuite, **kwargs):
    timeout_multiplier = kwargs["timeout_multiplier"]
    if timeout_multiplier is None:
        timeout_multiplier = 1

    executor_kwargs = {"server_config": test_environment.config,
                       "timeout_multiplier": timeout_multiplier,
                       "debug_info": kwargs["debug_info"],
                       "subsuite": subsuite.name,
                       "target_platform": run_info_data["os"]}

    if test_type in ("reftest", "print-reftest"):
        executor_kwargs["screenshot_cache"] = test_environment.cache_manager.dict()
        executor_kwargs["reftest_screenshot"] = kwargs["reftest_screenshot"]

    if test_type == "wdspec":
        executor_kwargs["binary"] = kwargs["binary"]
        executor_kwargs["binary_args"] = kwargs["binary_args"].copy()
        executor_kwargs["webdriver_binary"] = kwargs["webdriver_binary"]
        executor_kwargs["webdriver_args"] = kwargs["webdriver_args"].copy()

    # By default the executor may try to cleanup windows after a test (to best
    # associate any problems with the test causing them). If the user might
    # want to view the results, however, the executor has to skip that cleanup.
    if kwargs["pause_after_test"] or kwargs["pause_on_unexpected"]:
        executor_kwargs["cleanup_after_test"] = False
    executor_kwargs["debug_test"] = kwargs["debug_test"]
    return executor_kwargs


def strip_server(url):
    """Remove the scheme and netloc from a url, leaving only the path and any query
    or fragment.

    url - the url to strip

    e.g. http://example.org:8000/tests?id=1#2 becomes /tests?id=1#2"""

    url_parts = list(urlsplit(url))
    url_parts[0] = ""
    url_parts[1] = ""
    return urlunsplit(url_parts)


def server_url(server_config, protocol, subdomain=False):
    scheme = "https" if protocol == "h2" else protocol
    host = server_config["browser_host"]
    if subdomain:
        # The only supported subdomain filename flag is "www".
        host = "{subdomain}.{host}".format(subdomain="www", host=host)
    return "{scheme}://{host}:{port}".format(scheme=scheme, host=host,
        port=server_config["ports"][protocol][0])


class TestharnessResultConverter:
    harness_codes = {0: "OK",
                     1: "ERROR",
                     2: "TIMEOUT",
                     3: "PRECONDITION_FAILED"}

    test_codes = {0: "PASS",
                  1: "FAIL",
                  2: "TIMEOUT",
                  3: "NOTRUN",
                  4: "PRECONDITION_FAILED"}

    def __call__(self, test, result, extra=None):
        """Convert a JSON result into a (TestResult, [SubtestResult]) tuple"""
        result_url, status, message, stack, subtest_results = result
        assert result_url == test.url, (f"Got results from {result_url}, expected {test.url}")
        harness_result = test.make_result(self.harness_codes[status], message, extra=extra, stack=stack)
        return (harness_result,
                [test.make_subtest_result(st_name, self.test_codes[st_status], st_message, st_stack)
                 for st_name, st_status, st_message, st_stack in subtest_results])


testharness_result_converter = TestharnessResultConverter()


def hash_screenshots(screenshots):
    """Computes the sha1 checksum of a list of base64-encoded screenshots."""
    return [hashlib.sha1(base64.b64decode(screenshot)).hexdigest()
            for screenshot in screenshots]


def _ensure_hash_in_reftest_screenshots(extra):
    """Make sure reftest_screenshots have hashes.

    Marionette internal reftest runner does not produce hashes.
    """
    log_data = extra.get("reftest_screenshots")
    if not log_data:
        return
    for item in log_data:
        if not isinstance(item, dict):
            # Skip relation strings.
            continue
        if "hash" not in item:
            item["hash"] = hash_screenshots([item["screenshot"]])[0]


def get_pages(ranges_value, total_pages):
    """Get a set of page numbers to include in a print reftest.

    :param ranges_value: Parsed page ranges as a list e.g. [[1,2], [4], [6,None]]
    :param total_pages: Integer total number of pages in the paginated output.
    :retval: Set containing integer page numbers to include in the comparison e.g.
             for the example ranges value and 10 total pages this would be
             {1,2,4,6,7,8,9,10}"""
    if not ranges_value:
        return set(range(1, total_pages + 1))

    rv = set()

    for range_limits in ranges_value:
        if len(range_limits) == 1:
            range_limits = [range_limits[0], range_limits[0]]

        if range_limits[0] is None:
            range_limits[0] = 1
        if range_limits[1] is None:
            range_limits[1] = total_pages

        if range_limits[0] > total_pages:
            continue
        rv |= set(range(range_limits[0], range_limits[1] + 1))
    return rv


def reftest_result_converter(self, test, result):
    extra = result.get("extra", {})
    _ensure_hash_in_reftest_screenshots(extra)
    return (test.make_result(
        result["status"],
        result["message"],
        extra=extra,
        stack=result.get("stack")), [])


def pytest_result_converter(self, test, data):
    harness_data, subtest_data = data

    if subtest_data is None:
        subtest_data = []

    harness_result = test.make_result(*harness_data)
    subtest_results = [test.make_subtest_result(*item) for item in subtest_data]

    return (harness_result, subtest_results)


def crashtest_result_converter(self, test, result):
    return test.make_result(**result), []


class ExecutorException(Exception):
    def __init__(self, status, message):
        self.status = status
        self.message = message


class TimedRunner:
    def __init__(self, logger, func, protocol, url, timeout, extra_timeout):
        self.func = func
        self.logger = logger
        self.result = None
        self.protocol = protocol
        self.url = url
        self.timeout = timeout
        self.extra_timeout = extra_timeout
        self.result_flag = threading.Event()

    def run(self):
        for setup_fn in [self.set_timeout, self.before_run]:
            err = setup_fn()
            if err:
                self.result = (False, err)
                return self.result

        executor = threading.Thread(target=self.run_func)
        executor.start()

        # Add twice the extra timeout since the called function is expected to
        # wait at least self.timeout + self.extra_timeout and this gives some leeway
        timeout = self.timeout + 2 * self.extra_timeout if self.timeout else None
        finished = self.result_flag.wait(timeout)
        if self.result is None:
            if finished:
                # flag is True unless we timeout; this *shouldn't* happen, but
                # it can if self.run_func fails to set self.result due to raising
                self.result = False, ("INTERNAL-ERROR", "%s.run_func didn't set a result" %
                                      self.__class__.__name__)
            else:
                if self.protocol.is_alive():
                    message = "Executor hit external timeout (this may indicate a hang)\n"
                    if executor.ident in sys._current_frames():
                        # get a traceback for the current stack of the executor thread
                        message += "".join(traceback.format_stack(
                            sys._current_frames()[executor.ident]))
                    self.result = False, ("EXTERNAL-TIMEOUT", message)
                else:
                    self.logger.info("Browser not responding, setting status to CRASH")
                    self.result = False, ("CRASH", None)
        elif self.result[1] is None:
            # We didn't get any data back from the test, so check if the
            # browser is still responsive
            if self.protocol.is_alive():
                self.result = False, ("INTERNAL-ERROR", None)
            else:
                self.logger.info("Browser not responding, setting status to CRASH")
                self.result = False, ("CRASH", None)

        return self.result

    def set_timeout(self):
        raise NotImplementedError

    def before_run(self):
        pass

    def run_func(self):
        raise NotImplementedError


class TestExecutor:
    """Abstract Base class for object that actually executes the tests in a
    specific browser. Typically there will be a different TestExecutor
    subclass for each test type and method of executing tests.

    :param browser: ExecutorBrowser instance providing properties of the
                    browser that will be tested.
    :param server_config: Dictionary of wptserve server configuration of the
                          form stored in TestEnvironment.config
    :param timeout_multiplier: Multiplier relative to base timeout to use
                               when setting test timeout.
    """
    __metaclass__ = ABCMeta

    test_type: ClassVar[str]
    # convert_result is a class variable set to a callable converter
    # (e.g. reftest_result_converter) converting from an instance of
    # URLManifestItem (e.g. RefTest) + type-dependent results object +
    # type-dependent extra data, returning a tuple of Result and list of
    # SubtestResult. For now, any callable is accepted. TODO: Make this type
    # stricter when more of the surrounding code is annotated.
    convert_result: ClassVar[Callable[..., Any]]
    supports_testdriver = False
    supports_jsshell = False
    # Extra timeout to use after internal test timeout at which the harness
    # should force a timeout
    extra_timeout = 5  # seconds


    def __init__(self, logger, browser, server_config, timeout_multiplier=1,
                 debug_info=None, subsuite=None, **kwargs):
        self.logger = logger
        self.runner = None
        self.browser = browser
        self.server_config = server_config
        self.timeout_multiplier = timeout_multiplier
        self.debug_info = debug_info
        self.subsuite = subsuite
        self.last_environment = {"protocol": "http",
                                 "prefs": {}}
        self.protocol = None  # This must be set in subclasses

    def setup(self, runner, protocol=None):
        """Run steps needed before tests can be started e.g. connecting to
        browser instance

        :param runner: TestRunner instance that is going to run the tests.
        :param protocol: protocol connection to reuse if not None"""
        self.runner = runner
        if protocol is not None:
            assert isinstance(protocol, self.protocol_cls)
            self.protocol = protocol
        elif self.protocol is not None:
            self.protocol.setup(runner)

    def teardown(self):
        """Run cleanup steps after tests have finished"""
        if self.protocol is not None:
            self.protocol.teardown()

    def reset(self):
        """Re-initialize internal state to facilitate repeated test execution
        as implemented by the `--rerun` command-line argument."""
        pass

    def run_test(self, test):
        """Run a particular test.

        :param test: The test to run"""
        try:
            if test.environment != self.last_environment:
                self.on_environment_change(test.environment)
            result = self.do_test(test)
        except Exception as e:
            exception_string = traceback.format_exc()
            message = f"Exception in TestExecutor.run:\n{exception_string}"
            self.logger.warning(message)
            result = self.result_from_exception(test, e, exception_string)

        # log result of parent test
        if result[0].status == "ERROR":
            self.logger.debug(result[0].message)

        self.last_environment = test.environment

        self.runner.send_message("test_ended", test, result)

    def server_url(self, protocol, subdomain=False):
        return server_url(self.server_config, protocol, subdomain)

    def test_url(self, test):
        return urljoin(self.server_url(test.environment["protocol"],
                                       test.subdomain), test.url)

    @abstractmethod
    def do_test(self, test):
        """Test-type and protocol specific implementation of running a
        specific test.

        :param test: The test to run."""
        pass

    def on_environment_change(self, new_environment):
        pass

    def result_from_exception(self, test, e, exception_string):
        if hasattr(e, "status") and e.status in test.result_cls.statuses:
            status = e.status
        else:
            status = "INTERNAL-ERROR"
        message = str(getattr(e, "message", ""))
        if message:
            message += "\n"
        message += exception_string
        return test.make_result(status, message), []

    def wait(self):
        return self.protocol.base.wait()


class TestharnessExecutor(TestExecutor):
    convert_result = testharness_result_converter


class RefTestExecutor(TestExecutor):
    convert_result = reftest_result_converter
    is_print = False

    def __init__(self, logger, browser, server_config, timeout_multiplier=1, screenshot_cache=None,
                 debug_info=None, reftest_screenshot="unexpected", **kwargs):
        TestExecutor.__init__(self, logger, browser, server_config,
                              timeout_multiplier=timeout_multiplier,
                              debug_info=debug_info)

        self.screenshot_cache = screenshot_cache
        self.reftest_screenshot = reftest_screenshot


class CrashtestExecutor(TestExecutor):
    convert_result = crashtest_result_converter


class PrintRefTestExecutor(TestExecutor):
    convert_result = reftest_result_converter
    is_print = True


class RefTestImplementation:
    def __init__(self, executor):
        self.timeout_multiplier = executor.timeout_multiplier
        self.executor = executor
        self.subsuite = executor.subsuite
        # Cache of url:(screenshot hash, screenshot). Typically the
        # screenshot is None, but we set this value if a test fails
        # and the screenshot was taken from the cache so that we may
        # retrieve the screenshot from the cache directly in the future
        self.screenshot_cache = self.executor.screenshot_cache
        self.message = None
        self.reftest_screenshot = executor.reftest_screenshot

    def setup(self):
        pass

    def teardown(self):
        pass

    @property
    def logger(self):
        return self.executor.logger

    def get_hash(self, test, viewport_size, dpi, page_ranges):
        key = (self.subsuite, test.url, viewport_size, dpi)

        if key not in self.screenshot_cache:
            success, data = self.get_screenshot_list(test, viewport_size, dpi, page_ranges)

            if not success:
                return False, data

            screenshots = data
            hash_values = hash_screenshots(data)
            self.screenshot_cache[key] = (hash_values, screenshots)

            rv = (hash_values, screenshots)
        else:
            rv = self.screenshot_cache[key]

        self.message.append(f"{test.url} {rv[0]}")
        return True, rv

    def reset(self):
        self.screenshot_cache.clear()

    def check_pass(self, hashes, screenshots, urls, relation, fuzzy):
        """Check if a test passes, and return a tuple of (pass, page_idx),
        where page_idx is the zero-based index of the first page on which a
        difference occurs if any, or None if there are no differences"""

        assert relation in ("==", "!=")
        lhs_hashes, rhs_hashes = hashes
        lhs_screenshots, rhs_screenshots = screenshots

        if len(lhs_hashes) != len(rhs_hashes):
            self.logger.info("Got different number of pages")
            return relation == "!=", -1

        assert len(lhs_screenshots) == len(lhs_hashes) == len(rhs_screenshots) == len(rhs_hashes)

        for (page_idx, (lhs_hash,
                        rhs_hash,
                        lhs_screenshot,
                        rhs_screenshot)) in enumerate(zip(lhs_hashes,
                                                          rhs_hashes,
                                                          lhs_screenshots,
                                                          rhs_screenshots)):
            comparison_screenshots = (lhs_screenshot, rhs_screenshot)
            if not fuzzy or fuzzy == ((0, 0), (0, 0)):
                equal = lhs_hash == rhs_hash
                # sometimes images can have different hashes, but pixels can be identical.
                if not equal:
                    self.logger.info("Image hashes didn't match%s, checking pixel differences" %
                                     ("" if len(hashes) == 1 else " on page %i" % (page_idx + 1)))
                    max_per_channel, pixels_different = self.get_differences(comparison_screenshots,
                                                                             urls)
                    equal = pixels_different == 0 and max_per_channel == 0
            else:
                max_per_channel, pixels_different = self.get_differences(comparison_screenshots,
                                                                         urls,
                                                                         page_idx if len(hashes) > 1 else None)
                allowed_per_channel, allowed_different = fuzzy
                self.logger.info("Allowed %s pixels different, maximum difference per channel %s" %
                                 ("-".join(str(item) for item in allowed_different),
                                  "-".join(str(item) for item in allowed_per_channel)))
                equal = ((pixels_different == 0 and allowed_different[0] == 0) or
                         (max_per_channel == 0 and allowed_per_channel[0] == 0) or
                         (allowed_per_channel[0] <= max_per_channel <= allowed_per_channel[1] and
                          allowed_different[0] <= pixels_different <= allowed_different[1]))
            if not equal:
                return (False if relation == "==" else True, page_idx)
        # All screenshots were equal within the fuzziness
        return (True if relation == "==" else False, -1)

    def get_differences(self, screenshots, urls, page_idx=None):
        from PIL import Image, ImageChops, ImageStat

        lhs = Image.open(io.BytesIO(base64.b64decode(screenshots[0]))).convert("RGB")
        rhs = Image.open(io.BytesIO(base64.b64decode(screenshots[1]))).convert("RGB")
        self.check_if_solid_color(lhs, urls[0])
        self.check_if_solid_color(rhs, urls[1])
        diff = ImageChops.difference(lhs, rhs)
        minimal_diff = diff.crop(diff.getbbox())
        mask = minimal_diff.convert("L", dither=None)
        stat = ImageStat.Stat(minimal_diff, mask)
        per_channel = max(item[1] for item in stat.extrema)
        count = stat.count[0]
        self.logger.info("Found %s pixels different, maximum difference per channel %s%s" %
                         (count,
                          per_channel,
                          "" if page_idx is None else " on page %i" % (page_idx + 1)))
        return per_channel, count

    def check_if_solid_color(self, image, url):
        extrema = image.getextrema()
        if all(min == max for min, max in extrema):
            color = ''.join('%02X' % value for value, _ in extrema)
            self.message.append(f"Screenshot is solid color 0x{color} for {url}\n")

    def run_test(self, test):
        viewport_size = test.viewport_size
        dpi = test.dpi
        page_ranges = test.page_ranges
        self.message = []


        # Depth-first search of reference tree, with the goal
        # of reachings a leaf node with only pass results

        stack = list(((test, item[0]), item[1]) for item in reversed(test.references))

        while stack:
            hashes = [None, None]
            screenshots = [None, None]
            urls = [None, None]

            nodes, relation = stack.pop()
            fuzzy = self.get_fuzzy(test, nodes, relation)

            for i, node in enumerate(nodes):
                success, data = self.get_hash(node, viewport_size, dpi, page_ranges)
                if success is False:
                    return {"status": data[0], "message": data[1]}

                hashes[i], screenshots[i] = data
                urls[i] = node.url

            is_pass, page_idx = self.check_pass(hashes, screenshots, urls, relation, fuzzy)
            log_data = [
                {"url": urls[0], "screenshot": screenshots[0][page_idx],
                 "hash": hashes[0][page_idx]},
                relation,
                {"url": urls[1], "screenshot": screenshots[1][page_idx],
                 "hash": hashes[1][page_idx]}
            ]

            if is_pass:
                fuzzy = self.get_fuzzy(test, nodes, relation)
                if nodes[1].references:
                    stack.extend(list(((nodes[1], item[0]), item[1])
                                      for item in reversed(nodes[1].references)))
                else:
                    test_result = {"status": "PASS", "message": None}
                    if (self.reftest_screenshot == "always" or
                        self.reftest_screenshot == "unexpected" and
                        test.expected() != "PASS"):
                        test_result["extra"] = {"reftest_screenshots": log_data}
                    # We passed
                    return test_result

        # We failed, so construct a failure message

        for i, (node, screenshot) in enumerate(zip(nodes, screenshots)):
            if screenshot is None:
                success, screenshot = self.retake_screenshot(node, viewport_size, dpi, page_ranges)
                if success:
                    screenshots[i] = screenshot

        test_result = {"status": "FAIL",
                       "message": "\n".join(self.message)}
        if (self.reftest_screenshot in ("always", "fail") or
            self.reftest_screenshot == "unexpected" and
            test.expected() != "FAIL"):
            test_result["extra"] = {"reftest_screenshots": log_data}
        return test_result

    def get_fuzzy(self, root_test, test_nodes, relation):
        full_key = tuple([item.url for item in test_nodes] + [relation])
        ref_only_key = test_nodes[1].url

        fuzzy_override = root_test.fuzzy_override
        fuzzy = test_nodes[0].fuzzy

        sources = [fuzzy_override, fuzzy]
        keys = [full_key, ref_only_key, None]
        value = None
        for source in sources:
            for key in keys:
                if key in source:
                    value = source[key]
                    break
            if value:
                break
        return value

    def retake_screenshot(self, node, viewport_size, dpi, page_ranges):
        success, data = self.get_screenshot_list(node,
                                                 viewport_size,
                                                 dpi,
                                                 page_ranges)
        if not success:
            return False, data

        key = (node.url, viewport_size, dpi)
        hash_val, _ = self.screenshot_cache[key]
        self.screenshot_cache[key] = hash_val, data
        return True, data

    def get_screenshot_list(self, node, viewport_size, dpi, page_ranges):
        success, data = self.executor.screenshot(node, viewport_size, dpi, page_ranges)
        if success and not isinstance(data, list):
            return success, [data]
        return success, data


class WdspecExecutor(TestExecutor):
    convert_result = pytest_result_converter
    protocol_cls: ClassVar[Type[Protocol]] = WdspecProtocol

    def __init__(self, logger, browser, server_config, webdriver_binary,
                 webdriver_args, target_platform, timeout_multiplier=1, capabilities=None,
                 debug_info=None, binary=None, binary_args=None, **kwargs):
        super().__init__(logger, browser, server_config,
                         timeout_multiplier=timeout_multiplier,
                         debug_info=debug_info)
        self.webdriver_binary = webdriver_binary
        self.webdriver_args = webdriver_args
        self.timeout_multiplier = timeout_multiplier
        self.capabilities = capabilities
        self.binary = binary
        self.binary_args = binary_args

        # Map OS to WebDriver specific platform names
        os_map = {"win": "windows"}
        self.target_platform = os_map.get(target_platform, target_platform)

    def setup(self, runner, protocol=None):
        assert protocol is None, "Switch executor not allowed for wdspec tests."
        self.protocol = self.protocol_cls(self, self.browser)
        super().setup(runner)

    def is_alive(self):
        return self.protocol.is_alive()

    def on_environment_change(self, new_environment):
        pass

    def do_test(self, test):
        timeout = test.timeout * self.timeout_multiplier + self.extra_timeout

        success, data = WdspecRun(self.do_wdspec,
                                  test.abs_path,
                                  timeout).run()

        if success:
            return self.convert_result(test, data)

        return (test.make_result(*data), [])

    def do_wdspec(self, path, timeout):
        session_config = {"host": self.browser.host,
                          "port": self.browser.port,
                          "capabilities": self.capabilities,
                          "target_platform": self.target_platform,
                          "timeout_multiplier": self.timeout_multiplier,
                          "browser": {
                              "binary": self.binary,
                              "args": self.binary_args,
                              "env": self.browser.env,
                          },
                          "webdriver": {
                              "binary": self.webdriver_binary,
                              "args": self.webdriver_args
                          }}

        return pytestrunner.run(path,
                                self.server_config,
                                session_config,
                                timeout=timeout)


class WdspecRun:
    def __init__(self, func, path, timeout):
        self.func = func
        self.result = (None, None)
        self.path = path
        self.timeout = timeout
        self.result_flag = threading.Event()

    def run(self):
        """Runs function in a thread and interrupts it if it exceeds the
        given timeout.  Returns (True, (Result, [SubtestResult ...])) in
        case of success, or (False, (status, extra information)) in the
        event of failure.
        """

        executor = threading.Thread(target=self._run)
        executor.start()

        self.result_flag.wait(self.timeout)
        if self.result[1] is None:
            self.result = False, ("EXTERNAL-TIMEOUT", None)

        return self.result

    def _run(self):
        try:
            self.result = True, self.func(self.path, self.timeout)
        except (socket.timeout, OSError):
            self.result = False, ("CRASH", None)
        except Exception as e:
            message = getattr(e, "message", "")
            if message:
                message += "\n"
            message += traceback.format_exc()
            self.result = False, ("INTERNAL-ERROR", message)
        finally:
            self.result_flag.set()


class CallbackHandler:
    """Handle callbacks from testdriver-using tests.

    The default implementation here makes sense for things that are roughly like
    WebDriver. Things that are more different to WebDriver may need to create a
    fully custom implementation."""

    unimplemented_exc: ClassVar[Tuple[Type[Exception], ...]] = (NotImplementedError,)
    expected_exc: ClassVar[Tuple[Type[Exception], ...]] = ()

    def __init__(self, logger, protocol, test_window):
        self.protocol = protocol
        self.test_window = test_window
        self.logger = logger
        self.callbacks = {
            "action": self.process_action,
            "complete": self.process_complete
        }

        self.actions = {cls.name: cls(self.logger, self.protocol) for cls in actions}

    def __call__(self, result):
        url, command, payload = result
        self.logger.debug("Got async callback: %s" % result[1])
        try:
            callback = self.callbacks[command]
        except KeyError as e:
            raise ValueError("Unknown callback type %r" % result[1]) from e
        return callback(url, payload)

    def process_complete(self, url, payload):
        rv = [strip_server(url)] + payload
        return True, rv

    def process_action(self, url, payload):
        action = payload["action"]
        cmd_id = payload["id"]
        self.logger.debug(f"Got action: {action}")
        try:
            action_handler = self.actions[action]
        except KeyError as e:
            raise ValueError(f"Unknown action {action}") from e
        try:
            with ActionContext(self.logger, self.protocol, payload.get("context")):
                try:
                    result = action_handler(payload)
                except AttributeError as e:
                    # If we fail to get an attribute from the protocol presumably that's a
                    # ProtocolPart we don't implement
                    # AttributeError got an obj property in Python 3.10, for older versions we
                    # fall back to looking at the error message.
                    if ((hasattr(e, "obj") and getattr(e, "obj") == self.protocol) or
                        f"'{self.protocol.__class__.__name__}' object has no attribute" in str(e)):
                        raise NotImplementedError from e
                    raise
        except self.unimplemented_exc:
            self.logger.warning("Action %s not implemented" % action)
            self._send_message(cmd_id, "complete", "error", f"Action {action} not implemented")
        except self.expected_exc:
            self.logger.debug(f"Action {action} failed with an expected exception")
            self._send_message(cmd_id, "complete", "error", f"Action {action} failed")
        except Exception:
            self.logger.warning(f"Action {action} failed")
            self._send_message(cmd_id, "complete", "error")
            raise
        else:
            self.logger.debug(f"Action {action} completed with result {result}")
            return_message = {"result": result}
            self._send_message(cmd_id, "complete", "success", json.dumps(return_message))

        return False, None

    def _send_message(self, cmd_id, message_type, status, message=None):
        self.protocol.testdriver.send_message(cmd_id, message_type, status, message=message)


class AsyncCallbackHandler(CallbackHandler):
    """
    Handle synchronous and asynchronous actions. Extends `CallbackHandler` with support of async actions.
    """

    def __init__(self, logger, protocol, test_window, loop):
        super().__init__(logger, protocol, test_window)
        self.loop = loop
        self.async_actions = {cls.name: cls(self.logger, self.protocol) for cls in async_actions}

    def process_action(self, url, payload):
        action = payload["action"]
        if action in self.async_actions:
            # Schedule async action to be processed in the event loop and return immediately.
            self.logger.debug(f"Scheduling async action processing: {action}, {payload}")
            self.loop.create_task(self._process_async_action(action, payload))
            return False, None
        else:
            # Fallback to the default action processing, which will fail if the action is not implemented.
            self.logger.debug(f"Processing synchronous action: {action}, {payload}")
            return super().process_action(url, payload)

    async def _process_async_action(self, action, payload):
        """
        Process async action and send the result back to the test driver.

        This method is analogous to `process_action` but is intended to be used with async actions in a task, so it does
        not raise unexpected exceptions. However, the unexpected exceptions are logged and the error message is sent
        back to the test driver.
        """
        async_action_handler = self.async_actions[action]
        cmd_id = payload["id"]
        try:
            result = await async_action_handler(payload)
        except AttributeError as e:
            # If we fail to get an attribute from the protocol presumably that's a
            # ProtocolPart we don't implement
            # AttributeError got an obj property in Python 3.10, for older versions we
            # fall back to looking at the error message.
            if ((hasattr(e, "obj") and getattr(e, "obj") == self.protocol) or
                    f"'{self.protocol.__class__.__name__}' object has no attribute" in str(e)):
                raise NotImplementedError from e
        except self.unimplemented_exc:
            self.logger.warning("Action %s not implemented" % action)
            self._send_message(cmd_id, "complete", "error", f"Action {action} not implemented")
        except self.expected_exc as e:
            self.logger.debug(f"Action {action} failed with an expected exception: {e}")
            self._send_message(cmd_id, "complete", "error", f"Action {action} failed: {e}")
        except Exception as e:
            self.logger.warning(f"Action {action} failed with an unexpected exception: {e}")
            self._send_message(cmd_id, "complete", "error", f"Unexpected exception: {e}")
        else:
            self.logger.debug(f"Action {action} completed with result {result}")
            return_message = {"result": result}
            self._send_message(cmd_id, "complete", "success", json.dumps(return_message))


class ActionContext:
    def __init__(self, logger, protocol, context):
        self.logger = logger
        self.protocol = protocol
        self.context = context
        self.initial_window = None

    def __enter__(self):
        if self.context is None:
            return

        self.initial_window = self.protocol.base.current_window
        self.logger.debug("Switching to window %s" % self.context)
        self.protocol.testdriver.switch_to_window(self.context, self.initial_window)

    def __exit__(self, *args):
        if self.context is None:
            return

        self.logger.debug("Switching back to initial window")
        self.protocol.base.set_window(self.initial_window)
        self.initial_window = None