import asyncio
import json
import logging
import sys
from typing import Any, Callable, Coroutine, List, Optional, Mapping
import websockets
from websockets.exceptions import ConnectionClosed
# Module-level logger for the BiDi transport. Transport logs every outgoing
# ("→") and incoming ("←") raw JSON message on it at DEBUG level.
logger = logging.getLogger("webdriver.bidi")
def get_running_loop() -> asyncio.AbstractEventLoop:
    """Return the event loop the caller is executing on.

    On Python >= 3.7 this delegates to ``asyncio.get_running_loop()``, which
    raises ``RuntimeError`` when called with no loop running. Older
    interpreters fall back to ``asyncio.get_event_loop()``.
    """
    if sys.version_info < (3, 7):
        # Legacy fallback: unlike asyncio.get_running_loop(), this creates an
        # event loop when none exists, so behaviour can differ. Running the
        # tests on Python >= 3.7 should expose any such difference.
        # Both paths live in an explicit if/else on sys.version_info so that
        # mypy's version-based reachability analysis accepts each branch.
        return asyncio.get_event_loop()
    else:
        return asyncio.get_running_loop()
class Transport:
    """Low level message handler for the WebSockets connection.

    Messages passed to :meth:`send` before :meth:`start` has opened the
    connection are buffered and flushed once it is open. Incoming text
    messages are JSON-decoded and handed to ``msg_handler`` by a background
    task that :meth:`start` creates.
    """

    def __init__(self, url: str,
                 msg_handler: Callable[[Mapping[str, Any]], Coroutine[Any, Any, None]],
                 loop: Optional[asyncio.AbstractEventLoop] = None):
        """
        :param url: WebSocket URL that :meth:`start` connects to.
        :param msg_handler: Coroutine function awaited with each decoded
            incoming message.
        :param loop: Event loop used to schedule the reader task; defaults to
            the result of ``get_running_loop()``.
        """
        self.url = url
        # Open connection; set by start(), reset to None by end().
        self.connection: Optional[websockets.WebSocketClientProtocol] = None  # type: ignore
        self.msg_handler = msg_handler
        # Outgoing messages queued while there is no open connection.
        self.send_buf: List[Mapping[str, Any]] = []
        if loop is None:
            loop = get_running_loop()
        self.loop = loop
        # Background task running read_messages(); created by start().
        self.read_message_task: Optional[asyncio.Task[Any]] = None

    async def start(self) -> None:
        """Open the connection, start the reader task and flush buffered messages.

        The buffer is drained while it is flushed. A later call to start()
        therefore does not resend messages that were already delivered. If a
        send fails, that message and the ones after it stay buffered.
        """
        # Keep a local reference: end() may reset self.connection to None
        # while this coroutine is suspended in one of the awaits below.
        connection = await websockets.connect(self.url)  # type: ignore
        self.connection = connection
        self.read_message_task = self.loop.create_task(self.read_messages())
        while self.send_buf:
            await self._send(connection, self.send_buf[0])
            # Remove only after a successful send so nothing is silently lost.
            self.send_buf.pop(0)

    async def send(self, data: Mapping[str, Any]) -> None:
        """Send ``data`` as JSON now, or buffer it until start() connects."""
        if self.connection is not None:
            await self._send(self.connection, data)
        else:
            self.send_buf.append(data)

    @staticmethod
    async def _send(
        connection: websockets.WebSocketClientProtocol,  # type: ignore
        data: Mapping[str, Any]
    ) -> None:
        """Serialise ``data`` to JSON, log it and send it on ``connection``."""
        msg = json.dumps(data)
        logger.debug("→ %s", msg)
        await connection.send(msg)

    async def handle(self, msg: str) -> None:
        """Log a raw incoming message, decode it as JSON and pass it to msg_handler."""
        logger.debug("← %s", msg)
        data = json.loads(msg)
        await self.msg_handler(data)

    async def end(self) -> None:
        """Close the connection, if any, and forget it."""
        if self.connection:
            await self.connection.close()
            self.connection = None

    async def read_messages(self) -> None:
        """Dispatch every incoming message to :meth:`handle` until the connection closes.

        :raises ValueError: if a binary (non-text) message is received.
        """
        assert self.connection is not None
        try:
            async for msg in self.connection:
                # The BiDi protocol is JSON text; binary frames are unexpected.
                if not isinstance(msg, str):
                    raise ValueError("Got a binary message")
                await self.handle(msg)
        except ConnectionClosed:
            # An abnormal close simply ends the reader loop.
            logger.debug("connection closed while reading messages")

    async def wait_closed(self) -> None:
        """Wait until the current connection, if still open, has closed."""
        if self.connection and not self.connection.closed:
            await self.connection.wait_closed()