chromium/content/test/gpu/gpu_tests/util/websocket_server.py

# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Code to allow tests to communicate via a websocket server."""

import logging
import threading
from typing import Optional

import websockets  # pylint: disable=import-error
import websockets.sync.server as sync_server  # pylint: disable=import-error

# Upper bounds, in seconds, on the blocking waits performed by WebsocketServer.
# StartServer(): time allowed for the server thread to report its port.
WEBSOCKET_PORT_TIMEOUT_SECONDS: int = 10
# WaitForConnection(): default time allowed for a client to connect.
WEBSOCKET_SETUP_TIMEOUT_SECONDS: int = 5
# ClearCurrentConnection(): time allowed for the connection handler to finish.
WEBSOCKET_CLOSE_TIMEOUT_SECONDS: int = 2
# StopServer(): time allowed for the server thread to exit after shutdown.
SERVER_SHUTDOWN_TIMEOUT_SECONDS: int = 5

# The client (Chrome) should never be closing the connection. If it does, it's
# indicative of something going wrong like a renderer crash.
# NOTE(review): websockets raises ConnectionClosedOK only for a normal closure;
# an abnormal closure raises ConnectionClosedError instead, which this alias
# does not cover - confirm callers handle that case as well.
ClientClosedConnectionError = websockets.exceptions.ConnectionClosedOK

# Alias for readability: this is what WebsocketServer.Receive() raises when no
# message arrives within its timeout.
WebsocketReceiveMessageTimeoutError = TimeoutError


class WebsocketServer():
  """Server that abstracts the websocket library under the hood.

  Only supports one active connection at a time. Serving happens on a daemon
  thread started by StartServer(); the attributes set in __init__ are shared
  with the connection handler running on that thread.
  """

  def __init__(self):
    """Creates a server that is not yet running."""
    # Port the server listens on. Filled in by the server thread before it sets
    # port_set_event.
    self.server_port = None
    # The active connection, or None when there is none.
    self.websocket = None
    # Per-connection events created by the connection handler. Setting the
    # stopper event lets the handler return; the handler then sets the closed
    # event.
    self.connection_stopper_event = None
    self.connection_closed_event = None
    self.port_set_event = threading.Event()
    self.connection_received_event = threading.Event()
    self._server_thread = None

  def StartServer(self) -> None:
    """Starts the websocket server on a separate thread.

    Blocks until the server thread reports the port it is listening on.

    Raises:
      RuntimeError: The port was not reported within
          WEBSOCKET_PORT_TIMEOUT_SECONDS.
    """
    assert self._server_thread is None, 'Server already running'
    self._server_thread = _ServerThread(self)
    self._server_thread.daemon = True
    self._server_thread.start()
    got_port = self.port_set_event.wait(WEBSOCKET_PORT_TIMEOUT_SECONDS)
    if not got_port:
      raise RuntimeError('Websocket server did not provide a port')
    # Note: We don't need to set up any port forwarding for remote platforms
    # after this point due to Telemetry's use of --proxy-server to send all
    # traffic through the TsProxyServer. This causes network traffic to pop out
    # on the host, which means that using the websocket server's port directly
    # works.

  def ClearCurrentConnection(self) -> None:
    """Ends the active connection, if any, and resets per-connection state.

    Raises:
      RuntimeError: The connection handler did not finish within
          WEBSOCKET_CLOSE_TIMEOUT_SECONDS.
    """
    if self.connection_stopper_event:
      self.connection_stopper_event.set()
      closed = self.connection_closed_event.wait(
          WEBSOCKET_CLOSE_TIMEOUT_SECONDS)
      if not closed:
        raise RuntimeError('Websocket connection did not close')
    self.connection_stopper_event = None
    self.connection_closed_event = None
    self.websocket = None
    self.connection_received_event.clear()

  def WaitForConnection(self, timeout: Optional[float] = None) -> None:
    """Blocks until a client has connected.

    Args:
      timeout: Maximum number of seconds to wait. None means
          WEBSOCKET_SETUP_TIMEOUT_SECONDS; 0 checks without blocking.

    Raises:
      RuntimeError: No connection was established within the timeout.
    """
    if self.websocket:
      return
    # Compare against None explicitly so an explicit 0 is not replaced by the
    # default.
    if timeout is None:
      timeout = WEBSOCKET_SETUP_TIMEOUT_SECONDS
    self.connection_received_event.wait(timeout)
    if not self.websocket:
      raise RuntimeError('Websocket connection was not established')

  def StopServer(self) -> None:
    """Ends any active connection and shuts down the server thread.

    Safe to call when the server was never started. Afterwards StartServer()
    may be called again.
    """
    self.ClearCurrentConnection()
    server_thread = self._server_thread
    if server_thread is None:
      return
    server_thread.shutdown()
    server_thread.join(SERVER_SHUTDOWN_TIMEOUT_SECONDS)
    if server_thread.is_alive():
      logging.error(
          'Websocket server did not shut down properly - this might be '
          'indicative of an issue in the test harness')
    # Reset so that a later StartServer() neither trips its "already running"
    # assert nor returns immediately with this server's stale port.
    self._server_thread = None
    self.server_port = None
    self.port_set_event.clear()

  def Send(self, message: str) -> None:
    """Sends message over the active connection (see WaitForConnection())."""
    self.websocket.send(message)

  def Receive(self, timeout: int) -> str:
    """Returns the next message received on the active connection.

    Args:
      timeout: Maximum number of seconds to wait for a message.

    Raises:
      WebsocketReceiveMessageTimeoutError: No message arrived in time.
    """
    try:
      return self.websocket.recv(timeout)
    except TimeoutError as e:
      raise WebsocketReceiveMessageTimeoutError(
          'Timed out after %d seconds waiting for websocket message' %
          timeout) from e


class _ServerThread(threading.Thread):
  def __init__(self, server_instance: WebsocketServer, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._server_instance = server_instance
    self.websocket_server = None

  def run(self) -> None:
    StartWebsocketServer(self, self._server_instance)

  def shutdown(self) -> None:
    self.websocket_server.shutdown()


def StartWebsocketServer(server_thread: _ServerThread,
                         server_instance: WebsocketServer) -> None:
  """Runs a websocket server until it is shut down.

  This is the body of _ServerThread.run(), so it blocks the server thread. The
  server binds to 127.0.0.1 on port 0 (an OS-assigned port); the actual port
  is reported back through server_instance.server_port and
  server_instance.port_set_event.

  Args:
    server_thread: The thread running this function. Its websocket_server
        attribute is set so that _ServerThread.shutdown() can stop the server.
    server_instance: The WebsocketServer whose connection state is updated as
        clients connect.
  """
  def HandleWebsocketConnection(
      websocket: sync_server.ServerConnection) -> None:
    """Publishes websocket as the active connection and blocks until cleared.

    Returns once WebsocketServer.ClearCurrentConnection() sets the stopper
    event, setting connection_closed_event just before returning.
    """
    # We only allow one active connection - if there are multiple, something is
    # wrong.
    assert server_instance.connection_stopper_event is None
    assert server_instance.connection_closed_event is None
    assert server_instance.websocket is None
    server_instance.connection_stopper_event = threading.Event()
    server_instance.connection_closed_event = threading.Event()
    # Keep our own reference in case the server clears its reference before the
    # wait finishes.
    connection_stopper_event = server_instance.connection_stopper_event
    connection_closed_event = server_instance.connection_closed_event
    # Publish the connection before signaling: WaitForConnection() checks
    # server_instance.websocket right after the event wakes it.
    server_instance.websocket = websocket
    server_instance.connection_received_event.set()
    # Block until ClearCurrentConnection() asks this connection to end.
    connection_stopper_event.wait()
    connection_closed_event.set()

  # Port 0 lets the OS pick a free port; the chosen one is read back below.
  with sync_server.serve(HandleWebsocketConnection, '127.0.0.1', 0) as server:
    server_thread.websocket_server = server
    server_instance.server_port = server.socket.getsockname()[1]
    server_instance.port_set_event.set()
    # Serves until _ServerThread.shutdown() calls server.shutdown().
    server.serve_forever()