linux/tools/testing/selftests/net/rds/test.py

#! /usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0

import argparse
import ctypes
import errno
import hashlib
import os
import select
import signal
import socket
import subprocess
import sys
import atexit
from pwd import getpwuid
from os import stat
from lib.py import ip


# Load the C library so that setns(2) can be called directly.
# use_errno=True makes ctypes keep a private copy of errno for calls
# made through this handle, so a failing setns() can be diagnosed with
# ctypes.get_errno().
libc = ctypes.CDLL('libc.so.6', use_errno=True)
setns = libc.setns
# int setns(int fd, int nstype)
setns.argtypes = (ctypes.c_int, ctypes.c_int)
setns.restype = ctypes.c_int

# Names of the two network namespaces used by the test.
net0 = 'net0'
net1 = 'net1'

# Names of the two veth interfaces, one per namespace.
veth0 = 'veth0'
veth1 = 'veth1'

# Helper function for creating a socket inside a network namespace.
# We need this because otherwise RDS will detect that the two TCP
# sockets are on the same interface and use the loop transport instead
# of the TCP transport.
def netns_socket(netns, *args):
    """Create a socket inside the network namespace *netns*.

    A forked child enters the namespace with setns(), creates the socket
    there and passes its file descriptor back over a UNIX socket pair,
    so the calling process never changes its own namespace.

    Args:
        netns: name of a namespace file under /var/run/netns.
        *args: positional arguments for socket.socket() (family, type,
            and optionally proto).

    Returns:
        A socket.socket wrapping the descriptor created by the child.

    Raises:
        RuntimeError: if the child could not enter the namespace or
            create and hand over the socket.
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        # The child must never fall back into the caller's code, and it
        # must not flush stdio buffers inherited from the parent, so it
        # always leaves through os._exit().
        status = 1
        try:
            u1.close()

            # change network namespace
            with open(f'/var/run/netns/{netns}') as f:
                # ctypes reports failure through the return value, not
                # by raising an exception.
                if setns(f.fileno(), 0) != 0:
                    # get_errno() is only meaningful if libc was loaded
                    # with use_errno=True; otherwise it stays 0.
                    err = ctypes.get_errno()
                    reason = os.strerror(err) if err else 'unknown error'
                    raise RuntimeError(f'setns() failed: {reason}')

            # create socket in target namespace
            s = socket.socket(*args)

            # send resulting socket to parent
            socket.send_fds(u0, [], [s.fileno()])
            status = 0
        except Exception as e:
            print(f'netns_socket({netns}): {e}', file=sys.stderr, flush=True)
        finally:
            os._exit(status)

    # Drop our copy of the sending end; otherwise the receive below
    # would block forever if the child died without sending anything.
    u0.close()

    # receive socket from child
    try:
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        _, status = os.waitpid(child, 0)

    if status != 0 or not fds:
        for fd in fds:
            os.close(fd)
        raise RuntimeError(f'could not create a socket in netns {netns}')

    # socket.fromfd() duplicates the descriptor, so close the one that
    # was received to avoid leaking it.
    try:
        return socket.fromfd(fds[0], *args)
    finally:
        os.close(fds[0])

def signal_handler(sig, frame):
    """Report that the test timed out and terminate with exit status 1.

    The signature matches what the signal module passes to a handler;
    neither argument is used.
    """
    print('Test timed out')
    raise SystemExit(1)

# Parse the command line.  Every option is optional: a log output
# directory, a timeout for hung tests, and packet loss / corruption /
# duplication settings (integers, turned into percentage strings below).
parser = argparse.ArgumentParser(
    description="init script args",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("-d", "--logdir", action="store", default="/tmp",
                    help="directory to store logs")
parser.add_argument('--timeout', type=int, default=0,
                    help="timeout to terminate hung test")
parser.add_argument('-l', '--loss', type=int, default=0,
                    help="Simulate tcp packet loss")
parser.add_argument('-c', '--corruption', type=int, default=0,
                    help="Simulate tcp packet corruption")
parser.add_argument('-u', '--duplicate', type=int, default=0,
                    help="Simulate tcp packet duplication")
args = parser.parse_args()

logdir = args.logdir
# Percentage strings, e.g. 5 -> '5%'.
packet_loss = f'{args.loss}%'
packet_corruption = f'{args.corruption}%'
packet_duplicate = f'{args.duplicate}%'

# Create one network namespace per endpoint, plus a veth pair.  The
# rest of the script refers to the two ends of the pair as veth0/veth1.
for ns in (net0, net1):
    ip(f"netns add {ns}")
ip("link add type veth")

addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# (namespace, interface) for each endpoint, in the same order as addrs
endpoints = ((net0, veth0), (net1, veth1))

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
for ns, dev in endpoints:
    ip(f"link set {dev} netns {ns} up")

# add addresses
for (ns, dev), (host, _port) in zip(endpoints, addrs):
    ip(f"-n {ns} addr add {host}/32 dev {dev}")

# add routes: each namespace gets a host route to the other endpoint
for (ns, dev), (peer, _port) in zip(endpoints, reversed(addrs)):
    ip(f"-n {ns} route add {peer}/32 dev {dev}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {net0} ping -c 1 {addrs[1][0]}")

# Start a packet capture on each network.  Each capture runs in its own
# forked child so that the main script can carry on with the test.
for net in [net0, net1]:
    tcpdump_pid = os.fork()
    if tcpdump_pid != 0:
        # parent: go on to the next namespace
        continue

    # child: make sure the capture file exists, then run tcpdump with
    # -Z set to the name of the user who owns that file
    pcap = f'{logdir}/{net}.pcap'
    subprocess.check_call(['touch', pcap])
    owner = getpwuid(stat(pcap).st_uid).pw_name
    ip(f"netns exec {net} /usr/sbin/tcpdump -Z {owner} -i any -w {pcap}")
    sys.exit(0)

# simulate packet loss, duplication and corruption by attaching a netem
# qdisc to the interface in each namespace
for net, iface in zip((net0, net1), (veth0, veth1)):
    netem_cmd = (
        f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem  "
        f"         corrupt {packet_corruption} loss {packet_loss} duplicate  "
        f"         {packet_duplicate}"
    )
    ip(netem_cmd)

# add a timeout (in seconds; 0 or less disables it)
if args.timeout > 0:
    # Install the handler before arming the alarm, so that SIGALRM can
    # never be delivered while the default action (terminate the
    # process without the 'Test timed out' message) is still in place.
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)

# One RDS socket per namespace, in the same order as addrs.
sockets = [
    netns_socket(ns_name, socket.AF_RDS, socket.SOCK_SEQPACKET)
    for ns_name in (net0, net1)
]

# Bind each socket to its address and make it non-blocking.
for sock, bind_addr in zip(sockets, addrs):
    sock.bind(bind_addr)
    sock.setblocking(False)

# Lookup tables used by the send/receive loop.
fileno_to_socket = {sock.fileno(): sock for sock in sockets}
addr_to_socket = dict(zip(addrs, sockets))
socket_to_addr = dict(zip(sockets, addrs))

# Running SHA-256 digests of everything sent/received, keyed by
# (sender fileno, receiver fileno).
send_hashes = {}
recv_hashes = {}

ep = select.epoll()

for sock in sockets:
    ep.register(sock, select.EPOLLRDNORM)

# Total number of packets to send, and running send/receive counters.
n = 50000
nr_send = 0
nr_recv = 0

while nr_send < n:
    # Send as much as we can without blocking
    print("sending...", nr_send, nr_recv)
    while nr_send < n:
        # The payload is the hex SHA-256 of the packet's sequence label.
        label = f'packet {nr_send}'.encode('utf-8')
        send_data = hashlib.sha256(label).hexdigest().encode('utf-8')

        # pseudo-random send/receive pattern
        sender = sockets[nr_send % 2]
        receiver = sockets[1 - (nr_send % 3) % 2]

        try:
            sender.sendto(send_data, socket_to_addr[receiver])
        except BlockingIOError:
            break
        except OSError as e:
            # These errors end the current send burst; anything else
            # is unexpected and propagates.
            if e.errno in (errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE):
                break
            raise
        else:
            # Only packets that were actually sent go into the digest.
            pair = (sender.fileno(), receiver.fileno())
            send_hashes.setdefault(pair, hashlib.sha256()).update(
                f'<{send_data}>'.encode('utf-8'))
            nr_send += 1

    # Receive as much as we can without blocking
    print("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        for fileno, eventmask in ep.poll():
            receiver = fileno_to_socket[fileno]

            if not eventmask & select.EPOLLRDNORM:
                continue

            # Drain this socket until it would block.
            while True:
                try:
                    recv_data, address = receiver.recvfrom(1024)
                except BlockingIOError:
                    break

                sender = addr_to_socket[address]
                pair = (sender.fileno(), receiver.fileno())
                recv_hashes.setdefault(pair, hashlib.sha256()).update(
                    f'<{recv_data}>'.encode('utf-8'))
                nr_recv += 1

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for net in [net0, net1]:
        for knob in ('rds_tcp_rcvbuf', 'rds_tcp_sndbuf'):
            ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.{knob}=10000")

print("done", nr_send, nr_recv)

# the Python socket module doesn't know these
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

nr_success = 0
nr_error = 0

# Call getsockopt() for every option number in the RDS_INFO range on
# both sockets and count how many calls succeed and how many fail.
for s in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        # Sigh, the Python socket module doesn't allow us to pass
        # buffer lengths greater than 1024 for some reason. RDS
        # wants multiple pages.
        try:
            s.getsockopt(socket.SOL_RDS, optname, 1024)
        except OSError:
            # Every failure is counted, ENOSPC included; none of them
            # is treated as fatal.
            nr_error += 1
        else:
            nr_success += 1

print(f"getsockopt(): {nr_success}/{nr_error}")

# Stop the captures by killing every process named tcpdump; a non-zero
# exit status from killall raises CalledProcessError.
print("Stopping network packet captures")
subprocess.run(['killall', '-q', 'tcpdump'], check=True)

# We're done sending and receiving stuff, now let's check if what
# we received is what we sent.  Every (sender, receiver) pair that sent
# something must have a receive digest equal to its send digest.
for pair, send_hash in send_hashes.items():
    sender, receiver = pair
    recv_hash = recv_hashes.get(pair)

    if recv_hash is None:
        print("FAIL: No data received")
        sys.exit(1)

    expected = send_hash.hexdigest()
    received = recv_hash.hexdigest()
    if expected != received:
        print("FAIL: Send/recv mismatch")
        print("hash expected:", expected)
        print("hash received:", received)
        sys.exit(1)

    print(f"{sender}/{receiver}: ok")

print("Success")
sys.exit(0)