# chromium/chrome/test/chromedriver/test/webserver.py

# Copyright 2013 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import http.server
import os
import socketserver
import ssl
import sys
import threading


class Responder(object):
  """Sends an HTTP response through a request handler.

  Every method forwards to the handler given to the constructor: it calls the
  handler's send_response/send_header/end_headers/send_error methods or writes
  to the handler's wfile.
  """

  def __init__(self, handler):
    """Stores the request handler that all responses are sent through."""
    self._handler = handler

  def SendResponse(self, headers, body):
    """Sends OK response with body.

    Args:
      headers: dict mapping header field names to values.
      body: payload written to the handler's wfile; len(body) is sent as the
            Content-Length header.
    """
    self.SendHeaders(headers, len(body))
    self.SendBody(body)

  def SendResponseFromFile(self, path):
    """Sends OK response with the given file as the body.

    Paths ending in ".json" get a content-type of application/json; no
    content-type header is added for any other file.

    Args:
      path: path of the file to read (in binary mode) and send.
    """
    headers = {}
    if path.endswith(".json"):
      headers["content-type"] = "application/json"
    with open(path, 'rb') as f:
      self.SendResponse(headers, f.read())

  def SendHeaders(self, headers=None, content_length=None):
    """Sends headers for OK response.

    Args:
      headers: optional dict mapping header field names to values.
      content_length: value for the Content-Length header, or None to omit
                      the header.
    """
    self._handler.send_response(200)
    # A None default avoids sharing one mutable dict between calls.
    for field, value in (headers or {}).items():
      self._handler.send_header(field, value)
    # Compare against None so that an empty body still announces
    # "Content-Length: 0" instead of silently omitting the header.
    if content_length is not None:
      self._handler.send_header('Content-Length', content_length)
    self._handler.end_headers()

  def SendError(self, code):
    """Sends response for the given HTTP error code."""
    self._handler.send_error(code)

  def SendBody(self, body):
    """Just sends the body, no headers."""
    self._handler.wfile.write(body)


class Request(object):
  """Read-only view of an incoming HTTP request.

  Wraps the request handler given to the constructor and exposes its path
  and headers.
  """

  def __init__(self, handler):
    """Stores the request handler whose path and headers are exposed."""
    self._handler = handler

  def GetPath(self):
    """Returns the handler's path attribute for this request."""
    request_path = self._handler.path
    return request_path

  def GetHeader(self, name):
    """Returns the result of looking up |name| in the handler's headers."""
    all_headers = self._handler.headers
    return all_headers.get(name)


class _BaseServer(http.server.HTTPServer):
  """Internal server that throws if timed out waiting for a request."""

  def __init__(self, on_request, server_cert_and_key_path=None):
    """Starts the server.

    The server binds to port 0 on 127.0.0.1, i.e. a port picked by the OS.
    It is an HTTP server if parameter server_cert_and_key_path is not provided.
    Otherwise, it is an HTTPS server.

    Args:
      on_request: callable invoked as on_request(Request, Responder) for every
                  POST request and every GET request whose path does not end
                  with favicon.ico.
      server_cert_and_key_path: path to a PEM file containing the cert and key.
                                if it is None, start the server as an HTTP one.
    """
    class _Handler(http.server.BaseHTTPRequestHandler):
      """Internal handler that just asks the server to handle the request."""

      def do_GET(self):
        # Favicon requests are answered directly and never reach on_request.
        if self.path.endswith('favicon.ico'):
          self.send_error(404)
          return
        on_request(Request(self), Responder(self))

      def do_POST(self):
        on_request(Request(self), Responder(self))

      def log_message(self, *args, **kwargs):
        """Overrides base class method to disable logging."""
        pass

      def handle(self):
        # Best effort: ignore socket errors. Exception (rather than a bare
        # except) lets KeyboardInterrupt and SystemExit propagate.
        try:
          http.server.BaseHTTPRequestHandler.handle(self)
        except Exception:
          pass

      def finish(self):
        # Best effort, same as handle(): ignore socket errors.
        try:
          http.server.BaseHTTPRequestHandler.finish(self)
        except Exception:
          pass

    http.server.HTTPServer.__init__(self, ('127.0.0.1', 0), _Handler)

    if server_cert_and_key_path is not None:
      self._is_https_enabled = True
      # ssl.wrap_socket() was deprecated in Python 3.7 and removed in 3.12;
      # an explicit server-side SSLContext is the supported replacement.
      context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
      context.load_cert_chain(certfile=server_cert_and_key_path)
      self.socket = context.wrap_socket(self.socket, server_side=True)
    else:
      self._is_https_enabled = False

  def handle_timeout(self):
    """Overridden from SocketServer.

    Raises:
      RuntimeError: always.
    """
    raise RuntimeError('Timed out waiting for http request')

  def GetUrl(self, host=None):
    """Returns the base URL of the server.

    Args:
      host: host name to put in the URL; defaults to 127.0.0.1.

    Returns:
      '<scheme>://<host>:<port>' where scheme is https if the server was
      created with a certificate and http otherwise.
    """
    postfix = '://%s:%s' % (host or '127.0.0.1', self.server_port)
    if self._is_https_enabled:
      return 'https' + postfix
    return 'http' + postfix


class _ThreadingServer(socketserver.ThreadingMixIn, _BaseServer):
  """A _BaseServer combined with socketserver.ThreadingMixIn.

  The mix-in is what lets the server work on several requests at the same
  time; this class adds nothing of its own.
  """


class WebServer(object):
  """An HTTP or HTTPS server that serves on its own thread.

  Serves files from given directory but may use custom data for specific paths.
  """

  def __init__(self, root_dir, server_cert_and_key_path=None):
    """Starts the server.

    It is an HTTP server if parameter server_cert_and_key_path is not provided.
    Otherwise, it is an HTTPS server.

    Args:
      root_dir: root path to serve files from. This parameter is required.
      server_cert_and_key_path: path to a PEM file containing the cert and key.
                                if it is None, start the server as an HTTP one.
    """
    self._root_dir = os.path.abspath(root_dir)
    # Everything _OnRequest reads must exist before the serving thread starts,
    # otherwise an early request would hit missing attributes.
    self._path_data_map = {}
    self._path_callback_map = {}
    self._path_maps_lock = threading.Lock()
    self._server = _ThreadingServer(self._OnRequest, server_cert_and_key_path)
    self._thread = threading.Thread(target=self._server.serve_forever)
    self._thread.daemon = True
    self._thread.start()

  def _OnRequest(self, request, responder):
    """Answers one request from the callback map, the data map or a file.

    Lookup order for the request path (query string stripped):
      1. A registered callback: its (headers, body) result is sent, or a 503
         if the body is falsy.
      2. Registered data: sent with no extra headers.
      3. A file under the root directory: 403 if the path resolves outside
         the root, 404 if it is not an existing regular file.

    Args:
      request: object providing GetPath(); passed through to callbacks.
      responder: object providing SendResponse, SendError and
                 SendResponseFromFile.
    """
    path = request.GetPath().split('?')[0]

    # Serve from path -> callback and data maps.
    # NOTE: the lock stays held while the callback runs and while the response
    # is sent, so a callback must not call SetDataForPath/SetCallbackForPath.
    with self._path_maps_lock:
      if path in self._path_callback_map:
        headers, body = self._path_callback_map[path](request)
        if body:
          responder.SendResponse(headers, body)
        else:
          responder.SendError(503)
        return

      if path in self._path_data_map:
        responder.SendResponse({}, self._path_data_map[path])
        return

    # Serve from file.
    path = os.path.normpath(
        os.path.join(self._root_dir, *path.split('/')))
    # Compare against the root *with a trailing separator*: a bare prefix
    # check would also accept sibling directories such as "<root>_other".
    root_prefix = os.path.join(self._root_dir, '')
    if path != self._root_dir and not path.startswith(root_prefix):
      responder.SendError(403)
      return
    # isfile() rather than exists(): a directory cannot be opened and sent.
    if not os.path.isfile(path):
      responder.SendError(404)
      return
    responder.SendResponseFromFile(path)

  def SetDataForPath(self, path, data):
    """Registers |data| as the response body for requests to |path|."""
    with self._path_maps_lock:
      self._path_data_map[path] = data

  def SetCallbackForPath(self, path, func):
    """Registers or removes the callback that answers requests to |path|.

    Args:
      path: request path (without query string) the callback applies to.
      func: callable taking the request and returning (headers, body), or
            None to remove the callback. Removing a path that has no
            callback is a no-op.
    """
    with self._path_maps_lock:
      if func is None:
        self._path_callback_map.pop(path, None)
      else:
        self._path_callback_map[path] = func


  def GetUrl(self, host=None):
    """Returns the base URL of the server."""
    return self._server.GetUrl(host)

  def Shutdown(self):
    """Shuts down the server synchronously."""
    self._server.shutdown()
    self._thread.join()
    # Release the listening socket. The socket is closed directly instead of
    # via server_close(), which on a threading server may additionally wait
    # for in-flight handler threads.
    self._server.socket.close()


class SyncWebServer(object):
  """Web server for tests that answers requests only on demand.

  A request is processed only while Respond() (or RespondWithContent()) is
  running. Intended to be driven from a single thread: every request should
  be handled on that same thread.
  """

  def __init__(self):
    """Creates the underlying server with a 10 second request timeout."""
    # The handler for the request currently being waited on, if any.
    self._on_request = None
    self._server = _BaseServer(self._OnRequest)
    # The timeout attribute is recognized by SocketServer.
    self._server.timeout = 10

  def _OnRequest(self, request, responder):
    """Runs the pending handler with |responder|, then clears it."""
    pending_handler = self._on_request
    pending_handler(responder)
    self._on_request = None

  def Respond(self, on_request):
    """Waits for a request and lets the given function handle it.

    Args:
      on_request: Function that handles the request. It is called with one
          argument, a Responder instance.

    Raises:
      RuntimeError: if another handler is still pending.
    """
    already_pending = bool(self._on_request)
    if already_pending:
      raise RuntimeError('Must handle 1 request at a time.')

    self._on_request = on_request
    server = self._server
    # Keep serving until _OnRequest has cleared the pending handler.
    # handle_request is used instead of handle_one_request because the
    # latter won't work with the timeout.
    while self._on_request:
      server.handle_request()

  def RespondWithContent(self, content):
    """Waits for a request, then answers it with |content| and no headers."""
    self.Respond(lambda responder: responder.SendResponse({}, content))

  def GetUrl(self, host=None):
    """Returns the base URL reported by the underlying server."""
    base_url = self._server.GetUrl(host)
    return base_url