"""
Test suite for socketserver.
"""
import contextlib
import io
import os
import select
import signal
import socket
import threading
import unittest
import socketserver
import test.support
from test.support import reap_children, verbose
from test.support import os_helper
from test.support import socket_helper
from test.support import threading_helper
# Gate the whole module on the "network" test resource and on a working
# socket implementation (see test.support for the exact semantics).
test.support.requires("network")
test.support.requires_working_socket(module=True)
# Line sent by the example clients; the echo handlers are expected to send
# it back unchanged.
TEST_STR = b"hello world\n"
# Host the INET test servers bind to (always combined with port 0).
HOST = socket_helper.HOST
# Unix-domain socket tests are skipped when the platform lacks AF_UNIX.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
                                            'requires Unix sockets')
# Forking-server tests are skipped when fork is not supported.
HAVE_FORKING = test.support.has_fork_support
requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')
# Remember the real select() to avoid interference with mocking
_real_select = select.select
def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
    """Return up to *n* bytes read from *sock*, waiting at most *timeout*.

    The wait uses the select() captured at import time.  If *sock* is not
    reported readable within *timeout*, RuntimeError is raised.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
@test.support.requires_fork()
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Run the ``with`` body while an unrelated forked child is pending.

    Tests that a custom child process is not waited on (Issue 1540386):
    a child is forked that exits at once with status 72, and when the
    ``with`` body finishes (normally or by raising) that child is waited
    on here, expecting exit code 72.

    *testcase* is accepted but not used.  The context manager yields
    ``None``.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # Runs whether or not the body raised; an exception from the body
        # keeps propagating unchanged after the child has been checked.
        test.support.wait_process(pid, exitcode=72)
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        # Unix-socket paths handed out by pickaddr(); removed in tearDown().
        self.test_files = []
        self.port_seed = 0

    def tearDown(self):
        reap_children()

        # Best-effort removal: OSError from os.remove() is ignored.
        for path in self.test_files:
            with contextlib.suppress(OSError):
                os.remove(path)
        self.test_files.clear()

    def pickaddr(self, proto):
        """Return an address to bind or connect from for family *proto*.

        AF_INET gets ``(HOST, 0)``; any other family gets a fresh Unix
        domain socket name, which is recorded for removal in tearDown().
        """
        if proto != socket.AF_INET:
            # XXX: We need a way to tell AF_UNIX to pick its own name
            # like AF_INET provides port==0.
            path = socket_helper.create_unix_domain_name()
            self.test_files.append(path)
            return path
        return (HOST, 0)

    def make_server(self, addr, svrcls, hdlrbase):
        """Create an echo server of type *svrcls* bound to *addr*.

        The handler derives from *hdlrbase* and writes back the one line it
        reads.  The test is skipped if creating the server raises
        PermissionError.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                # Close the request, then re-raise the exception that is
                # currently being handled.
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                self.wfile.write(self.rfile.readline())

        if verbose:
            print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as exc:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, exc))
        bound_to = server.socket.getsockname()
        self.assertEqual(server.server_address, bound_to)
        return server

    @threading_helper.reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve from a background thread and run *testfunc* three times.

        *testfunc* is called as ``testfunc(address_family, address)``.
        Afterwards the server is shut down and closed, and its socket is
        checked to be closed.
        """
        family = svrcls.address_family
        server = self.make_server(self.pickaddr(family), svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        serving = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval': 0.01},
            # Daemonic in case this function raises.
            daemon=True,
        )
        serving.start()
        if verbose:
            print("server running")
        for round_no in range(3):
            if verbose:
                print("test client", round_no)
            testfunc(family, addr)
        if verbose:
            print("waiting for server")
        server.shutdown()
        serving.join()
        server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose:
            print("done")

    @staticmethod
    def _read_reply(sock):
        """Receive from *sock* until a newline arrives or the data ends."""
        reply = b''
        while True:
            chunk = receive(sock, 100)
            reply += chunk
            if not chunk or b'\n' in reply:
                return reply

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and expect it echoed back."""
        with socket.socket(proto, socket.SOCK_STREAM) as client:
            client.connect(addr)
            client.sendall(TEST_STR)
            self.assertEqual(self._read_reply(client), TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and expect it echoed back."""
        with socket.socket(proto, socket.SOCK_DGRAM) as client:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # Bind the client to a name of its own before sending.
                client.bind(self.pickaddr(proto))
            client.sendto(TEST_STR, addr)
            self.assertEqual(self._read_reply(client), TEST_STR)

    # --- Stream servers -------------------------------------------------

    def test_TCPServer(self):
        self.run_server(
            socketserver.TCPServer,
            socketserver.StreamRequestHandler,
            self.stream_examine,
        )

    def test_ThreadingTCPServer(self):
        self.run_server(
            socketserver.ThreadingTCPServer,
            socketserver.StreamRequestHandler,
            self.stream_examine,
        )

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(
                socketserver.ForkingTCPServer,
                socketserver.StreamRequestHandler,
                self.stream_examine,
            )

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(
            socketserver.UnixStreamServer,
            socketserver.StreamRequestHandler,
            self.stream_examine,
        )

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(
            socketserver.ThreadingUnixStreamServer,
            socketserver.StreamRequestHandler,
            self.stream_examine,
        )

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(
                socketserver.ForkingUnixStreamServer,
                socketserver.StreamRequestHandler,
                self.stream_examine,
            )

    # --- Datagram servers -----------------------------------------------

    def test_UDPServer(self):
        self.run_server(
            socketserver.UDPServer,
            socketserver.DatagramRequestHandler,
            self.dgram_examine,
        )

    def test_ThreadingUDPServer(self):
        self.run_server(
            socketserver.ThreadingUDPServer,
            socketserver.DatagramRequestHandler,
            self.dgram_examine,
        )

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(
                socketserver.ForkingUDPServer,
                socketserver.DatagramRequestHandler,
                self.dgram_examine,
            )

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(
            socketserver.UnixDatagramServer,
            socketserver.DatagramRequestHandler,
            self.dgram_examine,
        )

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(
            socketserver.ThreadingUnixDatagramServer,
            socketserver.DatagramRequestHandler,
            self.dgram_examine,
        )

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        # NOTE(review): unlike the other forking tests this one does not
        # run inside simple_subprocess() -- confirm that is intentional.
        self.run_server(
            socketserver.ForkingUnixDatagramServer,
            socketserver.DatagramRequestHandler,
            self.dgram_examine,
        )

    # --- Miscellaneous --------------------------------------------------

    @threading_helper.reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        pairs = []
        for _ in range(20):
            srv = MyServer((HOST, 0), MyHandler)
            worker = threading.Thread(
                name='MyServer serving',
                target=srv.serve_forever,
                kwargs={'poll_interval': 0.01},
                # Daemonic in case this function raises.
                daemon=True,
            )
            pairs.append((worker, srv))
        for worker, srv in pairs:
            worker.start()
            srv.shutdown()
        for worker, srv in pairs:
            worker.join()
            srv.server_close()

    def test_close_immediately(self):
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        # Closing a server that never served a request must work.
        MyServer((HOST, 0), lambda: None).server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this
        # results in FD exhaustion.
        for _ in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        handler = socketserver.StreamRequestHandler
        with socketserver.TCPServer((HOST, 0), handler) as server:
            pass
        # Leaving the with block must have closed the listening socket.
        self.assertEqual(-1, server.socket.fileno())
class ErrorHandlerTest(unittest.TestCase):
    """Check which handler exceptions reach handle_error().

    Normal exceptions raised by the handler must be passed to the server's
    handle_error(); exiting exceptions like SystemExit and
    KeyboardInterrupt must not be.
    """

    def tearDown(self):
        # Remove the log file the servers and handlers append to.
        os_helper.unlink(os_helper.TESTFN)

    def check_result(self, handled):
        """Assert the log shows the handler ran and, if *handled*, that
        handle_error() ran after it."""
        expected = 'Handler called\n' + 'Error handled\n' * handled
        with open(os_helper.TESTFN) as log:
            recorded = log.read()
        self.assertEqual(recorded, expected)

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        # SystemExit must propagate out of the server instead.
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_threading_not_handled(self):
        with threading_helper.catch_threading_exception() as caught:
            ThreadingErrorTestServer(SystemExit)
            self.check_result(handled=False)

            # Inspect the captured exception while still inside the
            # context manager.
            self.assertIs(caught.exc_type, SystemExit)

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that handles a single failing request on construction.

    Creating an instance binds to ``(HOST, 0)``, makes one client
    connection to itself, handles that request with BadHandler (which
    raises *exception*), closes the server and finally calls wait_done().
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        # Connect and immediately close: the connection only has to exist
        # so that handle_request() has something to handle.
        socket.create_connection(self.server_address).close()
        try:
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        # Record that the error was routed here.
        with open(os_helper.TESTFN, 'a') as logfile:
            logfile.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses that need to wait for the request to end."""
class BadHandler(socketserver.BaseRequestHandler):
    """Request handler that logs its invocation and then fails."""

    def handle(self):
        with open(os_helper.TESTFN, 'a') as logfile:
            logfile.write('Handler called\n')
        # The exception type to raise is configured on the server.
        failure = self.server.exception('Test error')
        raise failure
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """Threading variant whose wait_done() blocks until the request has
    been shut down."""

    def __init__(self, *args, **kwargs):
        # The event must exist before the base __init__ runs, because that
        # is where the request is handled and wait_done() is called.
        self.done = threading.Event()
        super().__init__(*args, **kwargs)

    def shutdown_request(self, *args, **kwargs):
        super().shutdown_request(*args, **kwargs)
        self.done.set()

    def wait_done(self):
        self.done.wait()
# Forking variant of the error test server.  It is defined only when
# HAVE_FORKING is true.
if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
        pass
class SocketWriterTest(unittest.TestCase):
    # Tests for the wfile object that StreamRequestHandler exposes.

    def test_basics(self):
        # wfile must be a BufferedIOBase whose fileno() matches the
        # request socket's.
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                # Stash what the assertions below need on the server,
                # which outlives the handler.
                self.server.wfile = self.wfile
                self.server.wfile_fileno = self.wfile.fileno()
                self.server.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        s = socket.socket(
            server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        with s:
            s.connect(server.server_address)
        # The client socket is already closed when the request is handled.
        server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        # NOTE(review): presumably get_attribute() skips the test when
        # signal.pthread_kill is unavailable -- confirm in test.support.
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                self.server.received = self.rfile.readline()
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                self.server.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        # Set by the SIGUSR1 handler so the client knows a signal has been
        # delivered.
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        # Install the handler and restore the previous one at cleanup.
        original = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, original)
        # Results filled in by run_client() through nonlocal.
        response1 = None
        received2 = None
        # The request is handled in this thread, so this is the thread
        # the client signals.
        main_thread = threading.get_ident()

        def run_client():
            s = socket.socket(server.address_family, socket.SOCK_STREAM,
                socket.IPPROTO_TCP)
            with s, s.makefile('rb') as reader:
                s.connect(server.server_address)
                nonlocal response1
                response1 = reader.readline()
                s.sendall(b'client response\n')

                # Consume the first 100 bytes of the big chunk before
                # starting to send signals.
                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                while True:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    if interrupted.wait(timeout=float(1)):
                        break
                nonlocal received2
                # Read the remainder of the big chunk until end of stream.
                received2 = len(reader.read())

        background = threading.Thread(target=run_client)
        background.start()
        server.handle_request()
        background.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        # 100 bytes of the big chunk were already consumed above.
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
class MiscTestCase(unittest.TestCase):
    # Miscellaneous checks on the socketserver module and BaseServer.

    def test_all(self):
        # objects defined in the module should be in __all__
        def defined_here(name):
            # True if the attribute reports socketserver as its module.
            obj = getattr(socketserver, name)
            return getattr(obj, '__module__', None) == 'socketserver'

        expected = [name for name in dir(socketserver)
                    if not name.startswith('_') and defined_here(name)]
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False
        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                return False

            # Class-level default; the first increment creates the
            # per-instance counter.
            shutdown_called = 0

            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        # Context managers ensure the server and the client socket are
        # closed even when connect(), handle_request() or the assertion
        # fails.
        with MyServer((HOST, 0), socketserver.StreamRequestHandler) as server:
            with socket.socket(server.address_family,
                               socket.SOCK_STREAM) as s:
                s.connect(server.server_address)
            server.handle_request()
            self.assertEqual(server.shutdown_called, 1)

    def test_threads_reaped(self):
        """
        In #37193, users reported a memory leak
        due to the saving of every request thread. Ensure that
        not all threads are kept forever.
        """
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        # The with block closes the server even if the assertion fails.
        with MyServer((HOST, 0), socketserver.StreamRequestHandler) as server:
            for _ in range(10):
                with socket.create_connection(server.server_address):
                    server.handle_request()
            self.assertLess(len(server._threads), 10)
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()