# chromium/third_party/wpt_tools/wpt/tools/webdriver/webdriver/bidi/transport.py

import asyncio
import json
import logging
import sys
from typing import Any, Callable, Coroutine, List, Optional, Mapping

import websockets

from websockets.exceptions import ConnectionClosed

logger = logging.getLogger("webdriver.bidi")


def get_running_loop() -> asyncio.AbstractEventLoop:
    """Return the event loop for the current context.

    Python 3.7 and newer use ``asyncio.get_running_loop()``; older
    interpreters fall back to ``asyncio.get_event_loop()``.
    """
    # Both calls live in explicit if/else branches because mypy needs that
    # shape to treat the version-gated code as reachable.
    if sys.version_info < (3, 7):
        # Unlike the modern call, this one will actually create an event
        # loop when there isn't one; hopefully running tests on
        # Python >= 3.7 surfaces any behaviour difference.
        current = asyncio.get_event_loop()
    else:
        current = asyncio.get_running_loop()
    return current


class Transport:
    """Low level message handler for the WebSockets connection.

    Messages passed to ``send`` while there is no connection are queued in
    ``send_buf`` and written out by ``start`` once the connection exists.
    Incoming text messages are JSON-decoded and passed to ``msg_handler``.
    """

    def __init__(self, url: str,
                 msg_handler: Callable[[Mapping[str, Any]], Coroutine[Any, Any, None]],
                 loop: Optional[asyncio.AbstractEventLoop] = None):
        """Record the connection parameters; no connection is opened here.

        :param url: WebSocket URL passed to ``websockets.connect`` by ``start``.
        :param msg_handler: Coroutine function awaited with each decoded
                            incoming message.
        :param loop: Event loop used to schedule the reader task. Defaults
                     to the loop returned by ``get_running_loop()``.
        """
        self.url = url
        self.connection: Optional[websockets.WebSocketClientProtocol] = None  # type: ignore
        self.msg_handler = msg_handler
        # Messages queued by send() while there is no connection.
        self.send_buf: List[Mapping[str, Any]] = []

        if loop is None:
            loop = get_running_loop()
        self.loop = loop

        # Task running read_messages(); created by start().
        self.read_message_task: Optional[asyncio.Task[Any]] = None

    async def start(self) -> None:
        """Connect, schedule the reader task, then send any queued messages."""
        self.connection = await websockets.connect(self.url)  # type: ignore
        self.read_message_task = self.loop.create_task(self.read_messages())

        await self._flush_send_buf(self.connection)

    async def _flush_send_buf(
        self,
        connection: "websockets.WebSocketClientProtocol",  # type: ignore
    ) -> None:
        """Send every queued message over ``connection``, oldest first.

        Each message is removed from ``send_buf`` only after it has been
        written. The buffer is therefore empty after a successful flush, so
        a later ``start()`` does not replay messages that were already
        delivered. If a send raises, the unsent messages stay queued.
        """
        while self.send_buf:
            await self._send(connection, self.send_buf[0])
            del self.send_buf[0]

    async def send(self, data: Mapping[str, Any]) -> None:
        """Send ``data`` now if connected, otherwise queue it for ``start``."""
        if self.connection is not None:
            await self._send(self.connection, data)
        else:
            self.send_buf.append(data)

    @staticmethod
    async def _send(
        connection: "websockets.WebSocketClientProtocol",  # type: ignore
        data: Mapping[str, Any]
    ) -> None:
        """Serialize ``data`` to JSON, log it at debug level and send it."""
        msg = json.dumps(data)
        logger.debug("→ %s", msg)
        await connection.send(msg)

    async def handle(self, msg: str) -> None:
        """Log ``msg``, decode it as JSON and await ``msg_handler`` with it."""
        logger.debug("← %s", msg)
        data = json.loads(msg)
        await self.msg_handler(data)

    async def end(self) -> None:
        """Close the connection, if there is one, and forget it.

        After this, ``send`` queues messages again until the next ``start``.
        """
        if self.connection:
            await self.connection.close()
            self.connection = None

    async def read_messages(self) -> None:
        """Pass each incoming message to ``handle`` until iteration ends.

        :raises ValueError: If a received message is not a ``str``.

        ``ConnectionClosed`` raised while reading is logged at debug level
        and ends the loop without propagating.
        """
        assert self.connection is not None
        try:
            async for msg in self.connection:
                if not isinstance(msg, str):
                    raise ValueError("Got a binary message")
                await self.handle(msg)
        except ConnectionClosed:
            logger.debug("connection closed while reading messages")

    async def wait_closed(self) -> None:
        """Wait for the connection to close, if one exists and is still open."""
        if self.connection and not self.connection.closed:
            await self.connection.wait_closed()