linux/tools/testing/selftests/net/tcp_ao/lib/proc.c

// SPDX-License-Identifier: GPL-2.0
#include <inttypes.h>
#include <pthread.h>
#include <stdio.h>
#include "../../../../../include/linux/compiler.h"
#include "../../../../../include/linux/kernel.h"
#include "aolib.h"

struct netstat_counter {
	uint64_t val;
	char *name;
};

struct netstat {
	char *header_name;
	struct netstat *next;
	size_t counters_nr;
	struct netstat_counter *counters;
};

static struct netstat *lookup_type(struct netstat *ns,
		const char *type, size_t len)
{
	while (ns != NULL) {
		size_t cmp = max(len, strlen(ns->header_name));

		if (!strncmp(ns->header_name, type, cmp))
			return ns;
		ns = ns->next;
	}
	return NULL;
}

static struct netstat *lookup_get(struct netstat *ns,
				  const char *type, const size_t len)
{
	struct netstat *ret;

	ret = lookup_type(ns, type, len);
	if (ret != NULL)
		return ret;

	ret = malloc(sizeof(struct netstat));
	if (!ret)
		test_error("malloc()");

	ret->header_name = strndup(type, len);
	if (ret->header_name == NULL)
		test_error("strndup()");
	ret->next = ns;
	ret->counters_nr = 0;
	ret->counters = NULL;

	return ret;
}

static struct netstat *lookup_get_column(struct netstat *ns, const char *line)
{
	char *column;

	column = strchr(line, ':');
	if (!column)
		test_error("can't parse netstat file");

	return lookup_get(ns, line, column - line);
}

static void netstat_read_type(FILE *fnetstat, struct netstat **dest, char *line)
{
	struct netstat *type = lookup_get_column(*dest, line);
	const char *pos = line;
	size_t i, nr_elems = 0;
	char tmp;

	while ((pos = strchr(pos, ' '))) {
		nr_elems++;
		pos++;
	}

	*dest = type;
	type->counters = reallocarray(type->counters,
				type->counters_nr + nr_elems,
				sizeof(struct netstat_counter));
	if (!type->counters)
		test_error("reallocarray()");

	pos = strchr(line, ' ') + 1;

	if (fscanf(fnetstat, "%[^ :]", type->header_name) == EOF)
		test_error("fscanf(%s)", type->header_name);
	if (fread(&tmp, 1, 1, fnetstat) != 1 || tmp != ':')
		test_error("Unexpected netstat format (%c)", tmp);

	for (i = type->counters_nr; i < type->counters_nr + nr_elems; i++) {
		struct netstat_counter *nc = &type->counters[i];
		const char *new_pos = strchr(pos, ' ');
		const char *fmt = " %" PRIu64;

		if (new_pos == NULL)
			new_pos = strchr(pos, '\n');

		nc->name = strndup(pos, new_pos - pos);
		if (nc->name == NULL)
			test_error("strndup()");

		if (unlikely(!strcmp(nc->name, "MaxConn")))
			fmt = " %" PRId64; /* MaxConn is signed, RFC 2012 */
		if (fscanf(fnetstat, fmt, &nc->val) != 1)
			test_error("fscanf(%s)", nc->name);
		pos = new_pos + 1;
	}
	type->counters_nr += nr_elems;

	if (fread(&tmp, 1, 1, fnetstat) != 1 || tmp != '\n')
		test_error("Unexpected netstat format");
}

static const char *snmp6_name = "Snmp6";
static void snmp6_read(FILE *fnetstat, struct netstat **dest)
{
	struct netstat *type = lookup_get(*dest, snmp6_name, strlen(snmp6_name));
	char *counter_name;
	size_t i;

	for (i = type->counters_nr;; i++) {
		struct netstat_counter *nc;
		uint64_t counter;

		if (fscanf(fnetstat, "%ms", &counter_name) == EOF)
			break;
		if (fscanf(fnetstat, "%" PRIu64, &counter) == EOF)
			test_error("Unexpected snmp6 format");
		type->counters = reallocarray(type->counters, i + 1,
					sizeof(struct netstat_counter));
		if (!type->counters)
			test_error("reallocarray()");
		nc = &type->counters[i];
		nc->name = counter_name;
		nc->val = counter;
	}
	type->counters_nr = i;
	*dest = type;
}

struct netstat *netstat_read(void)
{
	struct netstat *ret = 0;
	size_t line_sz = 0;
	char *line = NULL;
	FILE *fnetstat;

	/*
	 * Opening thread-self instead of /proc/net/... as the latter
	 * points to /proc/self/net/ which instantiates thread-leader's
	 * net-ns, see:
	 * commit 155134fef2b6 ("Revert "proc: Point /proc/{mounts,net} at..")
	 */
	errno = 0;
	fnetstat = fopen("/proc/thread-self/net/netstat", "r");
	if (fnetstat == NULL)
		test_error("failed to open /proc/net/netstat");

	while (getline(&line, &line_sz, fnetstat) != -1)
		netstat_read_type(fnetstat, &ret, line);
	fclose(fnetstat);

	errno = 0;
	fnetstat = fopen("/proc/thread-self/net/snmp", "r");
	if (fnetstat == NULL)
		test_error("failed to open /proc/net/snmp");

	while (getline(&line, &line_sz, fnetstat) != -1)
		netstat_read_type(fnetstat, &ret, line);
	fclose(fnetstat);

	errno = 0;
	fnetstat = fopen("/proc/thread-self/net/snmp6", "r");
	if (fnetstat == NULL)
		test_error("failed to open /proc/net/snmp6");

	snmp6_read(fnetstat, &ret);
	fclose(fnetstat);

	free(line);
	return ret;
}

void netstat_free(struct netstat *ns)
{
	while (ns != NULL) {
		struct netstat *prev = ns;
		size_t i;

		free(ns->header_name);
		for (i = 0; i < ns->counters_nr; i++)
			free(ns->counters[i].name);
		free(ns->counters);
		ns = ns->next;
		free(prev);
	}
}

static inline void
__netstat_print_diff(uint64_t a, struct netstat *nsb, size_t i)
{
	if (unlikely(!strcmp(nsb->header_name, "MaxConn"))) {
		test_print("%8s %25s: %" PRId64 " => %" PRId64,
				nsb->header_name, nsb->counters[i].name,
				a, nsb->counters[i].val);
		return;
	}

	test_print("%8s %25s: %" PRIu64 " => %" PRIu64, nsb->header_name,
			nsb->counters[i].name, a, nsb->counters[i].val);
}

void netstat_print_diff(struct netstat *nsa, struct netstat *nsb)
{
	size_t i, j;

	while (nsb != NULL) {
		if (unlikely(strcmp(nsb->header_name, nsa->header_name))) {
			for (i = 0; i < nsb->counters_nr; i++)
				__netstat_print_diff(0, nsb, i);
			nsb = nsb->next;
			continue;
		}

		if (nsb->counters_nr < nsa->counters_nr)
			test_error("Unexpected: some counters disappeared!");

		for (j = 0, i = 0; i < nsb->counters_nr; i++) {
			if (strcmp(nsb->counters[i].name, nsa->counters[j].name)) {
				__netstat_print_diff(0, nsb, i);
				continue;
			}

			if (nsa->counters[j].val == nsb->counters[i].val) {
				j++;
				continue;
			}

			__netstat_print_diff(nsa->counters[j].val, nsb, i);
			j++;
		}
		if (j != nsa->counters_nr)
			test_error("Unexpected: some counters disappeared!");

		nsb = nsb->next;
		nsa = nsa->next;
	}
}

uint64_t netstat_get(struct netstat *ns, const char *name, bool *not_found)
{
	if (not_found)
		*not_found = false;

	while (ns != NULL) {
		size_t i;

		for (i = 0; i < ns->counters_nr; i++) {
			if (!strcmp(name, ns->counters[i].name))
				return ns->counters[i].val;
		}

		ns = ns->next;
	}

	if (not_found)
		*not_found = true;
	return 0;
}