chromium/third_party/wpt_tools/wpt/tools/webdriver/webdriver/bidi/client.py

# mypy: allow-untyped-defs

import asyncio
from collections import defaultdict
from typing import (Any, Awaitable, Callable, List, Optional, Mapping, MutableMapping,
                    Set)
from urllib.parse import urljoin, urlparse

from . import modules
from .error import from_error_details
from .transport import get_running_loop, Transport


class BidiSession:
    """A WebDriver BiDi session.

    This is the main representation of a BiDi session and provides the
    interface for running commands in the session, and for attaching
    event handlers to the session. For example:

    async def on_log(method, data):
        print(data)

    session = BidiSession("ws://localhost:4445",
                          session_id=session_id,
                          capabilities=capabilities)
    remove_listener = session.add_event_listener("log.entryAdded", on_log)
    await session.start()
    await session.subscribe("log.entryAdded")

    # Do some stuff with the session

    remove_listener()
    await session.end()

    If the session id is provided it's assumed that the underlying
    WebDriver session was already created, and the WebSocket URL was
    taken from the new session response. If no session id is provided, it's
    assumed that a BiDi-only session should be created when start() is called.

    It can also be used as an async context manager, with the WebSocket
    transport implicitly being created when the context is entered, and
    closed when the context is exited.

    :param websocket_url: WebSocket URL on which to connect to the session.
                          Either just the server URL (no path, or "/"), in
                          which case the "session[/<session_id>]" path is
                          appended, or the full URL including that path.
    :param session_id: String id of existing HTTP session
    :param capabilities: Capabilities response of existing session
    :param requested_capabilities: Dictionary representing the capabilities request.

    :raises ValueError: If the URL path does not match the session id, or if
                        capabilities are given without a session id.
    """

    def __init__(self,
                 websocket_url: str,
                 session_id: Optional[str] = None,
                 capabilities: Optional[Mapping[str, Any]] = None,
                 requested_capabilities: Optional[Mapping[str, Any]] = None):
        self.transport: Optional[Transport] = None

        # The full URL for a websocket looks like
        # ws://<host>:<port>/session when we're creating a session and
        # ws://<host>:<port>/session/<sessionid> when we're connecting to an existing session.
        # To be user friendly, handle the case where the class was created with either a
        # full URL including the path, and also the case where just a server url is passed in.
        parsed_url = urlparse(websocket_url)
        if parsed_url.path in ("", "/"):
            if session_id is None:
                websocket_url = urljoin(websocket_url, "session")
            else:
                websocket_url = urljoin(websocket_url, f"session/{session_id}")
        else:
            if session_id is not None:
                if parsed_url.path != f"/session/{session_id}":
                    raise ValueError(f"WebSocket URL {websocket_url} doesn't match session id")
            else:
                if parsed_url.path != "/session":
                    raise ValueError(f"WebSocket URL {websocket_url} doesn't match session url")

        if session_id is None and capabilities is not None:
            raise ValueError("Tried to create BiDi-only session with existing capabilities")

        self.websocket_url = websocket_url
        self.requested_capabilities = requested_capabilities
        self.capabilities = capabilities
        self.session_id = session_id

        self.command_id = 0
        # Futures for commands that have been sent but not yet answered, keyed
        # by command id. Entries are removed when the response arrives.
        self.pending_commands: MutableMapping[int, "asyncio.Future[Any]"] = {}
        self.event_listeners: MutableMapping[
            Optional[str],
            List[Callable[[str, Mapping[str, Any]], Any]]
        ] = defaultdict(list)
        # Strong references to running event-listener tasks; the event loop
        # only keeps weak references, so an unreferenced task could be
        # garbage-collected before it finishes.
        self._event_tasks: Set["asyncio.Task[Any]"] = set()

        # Modules.
        # For each module, have a property representing that module
        self.browser = modules.Browser(self)
        self.browsing_context = modules.BrowsingContext(self)
        self.input = modules.Input(self)
        self.network = modules.Network(self)
        self.permissions = modules.Permissions(self)
        self.script = modules.Script(self)
        self.session = modules.Session(self)
        self.storage = modules.Storage(self)

    @property
    def event_loop(self):
        """The transport's event loop, or None if the transport isn't started."""
        if self.transport:
            return self.transport.loop

        return None

    @classmethod
    def from_http(cls,
                  session_id: str,
                  capabilities: Mapping[str, Any]) -> "BidiSession":
        """Create a BiDi session from an existing HTTP session

        :param session_id: String id of the session
        :param capabilities: Capabilities returned in the New Session HTTP response.
        :raises ValueError: If capabilities has no string "webSocketUrl" entry."""
        websocket_url = capabilities.get("webSocketUrl")
        if websocket_url is None:
            raise ValueError("No webSocketUrl found in capabilities")
        if not isinstance(websocket_url, str):
            raise ValueError("webSocketUrl is not a string")
        return cls(websocket_url, session_id=session_id, capabilities=capabilities)

    @classmethod
    def bidi_only(cls,
                  websocket_url: str,
                  requested_capabilities: Optional[Mapping[str, Any]] = None) -> "BidiSession":
        """Create a BiDi session where there is no existing HTTP session

        :param websocket_url: URL to the WebSocket server listening for BiDi connections
        :param requested_capabilities: Capabilities request for establishing the session."""
        return cls(websocket_url, requested_capabilities=requested_capabilities)

    async def __aenter__(self) -> "BidiSession":
        await self.start()
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.end()

    async def start_transport(self,
                              loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        """Create and start the WebSocket transport, unless one already exists.

        :param loop: Event loop for the transport; defaults to the running loop.
        """
        if self.transport is None:
            if loop is None:
                loop = get_running_loop()

            self.transport = Transport(self.websocket_url, self.on_message, loop=loop)
            await self.transport.start()

    async def start(self,
                    loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        """Connect to the WebDriver BiDi remote via WebSockets.

        If no session id was supplied, a new BiDi-only session is created using
        the requested capabilities, and its id and capabilities are stored.
        """

        await self.start_transport(loop)

        if self.session_id is None:
            self.session_id, self.capabilities = await self.session.new(  # type: ignore
                capabilities=self.requested_capabilities)

    async def send_command(
        self,
        method: str,
        params: Mapping[str, Any]
    ) -> Awaitable[Mapping[str, Any]]:
        """Send a command to the remote server.

        :param method: Command method name.
        :param params: Command parameters.
        :return: A future that on_message() resolves with the response's
                 "result", or fails with the exception built from an error
                 response.
        """
        # this isn't threadsafe
        self.command_id += 1
        command_id = self.command_id

        body = {
            "id": command_id,
            "method": method,
            "params": params
        }
        assert command_id not in self.pending_commands
        assert self.transport is not None
        future = self.transport.loop.create_future()
        self.pending_commands[command_id] = future
        await self.transport.send(body)

        # Return the local reference: the response may already have been
        # handled (and its entry removed) while send() was awaiting.
        return future

    async def on_message(self, data: Mapping[str, Any]) -> None:
        """Handle a message from the remote server.

        "success"/"error" messages resolve (and unregister) the pending command
        future with the matching id; "event" messages are dispatched to the
        listeners registered for that method, or to the default (None)
        listeners if there are none.

        :raises ValueError: For a response with an unknown id, or a message of
                            an unexpected type.
        """
        if data["type"] in ["error", "success"]:
            # This is a command response or error. Remove the entry so that
            # answered commands don't accumulate over the session's lifetime.
            future = self.pending_commands.pop(data["id"], None)
            if future is None:
                raise ValueError(f"No pending command with id {data['id']}")
            if future.done():
                # The caller stopped waiting (e.g. the future was cancelled by
                # a timeout); resolving it again would raise InvalidStateError.
                return
            if data["type"] == "success":
                assert isinstance(data["result"], dict)
                future.set_result(data["result"])
            else:
                assert isinstance(data["error"], str)
                assert isinstance(data["message"], str)
                exception = from_error_details(data["error"],
                                               data["message"],
                                               data.get("stacktrace"))
                future.set_exception(exception)
        elif data["type"] == "event":
            # This is an event
            assert isinstance(data["method"], str)
            assert isinstance(data["params"], dict)

            listeners = self.event_listeners.get(data["method"], [])
            if not listeners:
                listeners = self.event_listeners.get(None, [])
            for listener in listeners:
                task = asyncio.create_task(listener(data["method"], data["params"]))
                # Keep the task alive until it completes, then forget it.
                self._event_tasks.add(task)
                task.add_done_callback(self._event_tasks.discard)
        else:
            raise ValueError(f"Unexpected message: {data!r}")

    async def end(self) -> None:
        """Close websocket connection."""
        if self.transport is not None:
            await self.transport.end()
            self.transport = None

    def add_event_listener(
        self,
        name: Optional[str],
        fn: Callable[[str, Mapping[str, Any]], Awaitable[Any]]
    ) -> Callable[[], None]:
        """Add a listener for the event with a given name.

        If name is None, the listener is called for any event that has no
        listener registered under its own name.

        :param name: Name of event to listen for or None to register a default handler
        :param fn: Async callback function that receives the event method and params

        :return: Function to remove the added listener. Calling it more than
                 once is harmless.
        """
        self.event_listeners[name].append(fn)

        def remove_listener() -> None:
            registered = self.event_listeners[name]
            if fn in registered:
                registered.remove(fn)

        return remove_listener