// SPDX-License-Identifier: GPL-2.0
#include <linux/ptrace.h>
#include <stddef.h>
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>
#include "bpf_misc.h"
char _license[] SEC("license") = "GPL";
/* typically virtio scsi has max SGs of 6 */
#define VIRTIO_MAX_SGS 6
/* Verifier will fail with SG_MAX = 128. The failure can be
* workarounded with a smaller SG_MAX, e.g. 10.
*/
#define WORKAROUND
#ifdef WORKAROUND
#define SG_MAX 10
#else
/* typically virtio blk has max SEG of 128 */
#define SG_MAX 128
#endif
#define SG_CHAIN 0x01UL
#define SG_END 0x02UL
struct scatterlist {
unsigned long page_link;
unsigned int offset;
unsigned int length;
};
#define sg_is_chain(sg) ((sg)->page_link & SG_CHAIN)
#define sg_is_last(sg) ((sg)->page_link & SG_END)
#define sg_chain_ptr(sg) \
((struct scatterlist *) ((sg)->page_link & ~(SG_CHAIN | SG_END)))
static inline struct scatterlist *__sg_next(struct scatterlist *sgp)
{
struct scatterlist sg;
bpf_probe_read_kernel(&sg, sizeof(sg), sgp);
if (sg_is_last(&sg))
return NULL;
sgp++;
bpf_probe_read_kernel(&sg, sizeof(sg), sgp);
if (sg_is_chain(&sg))
sgp = sg_chain_ptr(&sg);
return sgp;
}
static inline struct scatterlist *get_sgp(struct scatterlist **sgs, int i)
{
struct scatterlist *sgp;
bpf_probe_read_kernel(&sgp, sizeof(sgp), sgs + i);
return sgp;
}
int config = 0;
int result = 0;
SEC("kprobe/virtqueue_add_sgs")
int BPF_KPROBE(trace_virtqueue_add_sgs, void *unused, struct scatterlist **sgs,
unsigned int out_sgs, unsigned int in_sgs)
{
struct scatterlist *sgp = NULL;
__u64 length1 = 0, length2 = 0;
unsigned int i, n, len;
if (config != 0)
return 0;
for (i = 0; (i < VIRTIO_MAX_SGS) && (i < out_sgs); i++) {
__sink(out_sgs);
for (n = 0, sgp = get_sgp(sgs, i); sgp && (n < SG_MAX);
sgp = __sg_next(sgp)) {
bpf_probe_read_kernel(&len, sizeof(len), &sgp->length);
length1 += len;
n++;
}
}
for (i = 0; (i < VIRTIO_MAX_SGS) && (i < in_sgs); i++) {
__sink(in_sgs);
for (n = 0, sgp = get_sgp(sgs, i); sgp && (n < SG_MAX);
sgp = __sg_next(sgp)) {
bpf_probe_read_kernel(&len, sizeof(len), &sgp->length);
length2 += len;
n++;
}
}
config = 1;
result = length2 - length1;
return 0;
}