diff --git a/cover-letter b/cover-letter
new file mode 100644
index 0000000000000..aab9dc2aef957
--- /dev/null
+++ b/cover-letter
@@ -0,0 +1,160 @@
+I was recently looking into using BPF socket iterators in conjunction
+with the bpf_sock_destroy() kfunc as a means to forcefully destroy a
+set of UDP sockets connected to a deleted backend [1]. The intent is to
+use BPF iterators + kfuncs in lieu of the INET_DIAG infrastructure to
+destroy sockets, in order to simplify Cilium's system requirements.
+Aditi describes the scenario in [2], the patch series that introduced
+bpf_sock_destroy() for this very purpose:
+
+> This patch set adds the capability to destroy sockets in BPF. We plan
+> to use the capability in Cilium to force client sockets to reconnect
+> when their remote load-balancing backends are deleted. The other use
+> case is on-the-fly policy enforcement where existing socket
+> connections prevented by policies need to be terminated.
+
+One would want and expect an iterator to visit every socket that
+existed before the iterator was created, if not exactly once, then at
+least once; otherwise, we could accidentally skip a socket that we
+intended to destroy. With the iterator implementation as it exists
+today, this is the behavior you would observe in the vast majority of
+cases.
+
+However, in the process of reviewing [2] and some follow-up fixes to
+bpf_iter_udp_batch() ([3] [4]) by Martin, it occurred to me that there
+are situations where BPF socket iterators may repeat, or worse, skip
+sockets altogether, even if those sockets existed prior to iterator
+creation, making BPF iterators a slightly buggy mechanism for achieving
+the goal stated above.
+
+This RFC highlights some of these scenarios, extending
+prog_tests/sock_iter_batch.c to illustrate conditions under which
+sockets can be skipped or repeated, and proposes a solution for
+achieving exactly-once semantics in all cases for sockets that existed
+prior to the start of iteration.
+
+I'm hoping to raise awareness of this issue generally, if it's not
+already common knowledge, and to get some feedback on the viability of
+the proposed improvement.
+
+THE PROBLEM
+===========
+Both the UDP and TCP socket iterators use iter->offset to track
+progress through a bucket: it counts the number of matching sockets
+from the current bucket that have been seen or processed by the
+iterator. On subsequent iterations, if the current bucket has
+unprocessed items, we skip at least iter->offset matching items in the
+bucket before adding any remaining items to the next batch. The intent
+seems to be to skip any items we've already seen, but iter->offset
+isn't always an accurate measure of "things already seen". There are a
+variety of scenarios where the underlying bucket changes between reads,
+leading to either repeated or skipped sockets. Two such scenarios are
+illustrated below and reproduced by the selftests.
+
+Skip A Socket
++------+--------------------+--------------+---------------+
+| Time | Event              | Bucket State | Bucket Offset |
++------+--------------------+--------------+---------------+
+| 1    | read(iter_fd) -> A | A->B->C->D   | 1             |
+| 2    | close(A)           | B->C->D      | 1             |
+| 3    | read(iter_fd) -> C | B->C->D      | 2             |
+| 4    | read(iter_fd) -> D | B->C->D      | 3             |
+| 5    | read(iter_fd) -> 0 | B->C->D      | -             |
++------+--------------------+--------------+---------------+
+
+Iteration sees these sockets: [A, C, D]
+B is skipped.
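+
+To make this first scenario concrete, below is a small, self-contained
+C sketch of the offset-based resume logic described above. It is
+illustrative only (struct snode and bucket_resume() are invented for
+the example and are not kernel code), but it mirrors how skipping
+iter->offset entries on resume loses B once A is closed:
+
+  #include <stdio.h>
+
+  struct snode {
+          const char *name;
+          struct snode *next;
+  };
+
+  /* Mimics the resume step: skip "offset" entries from the head of
+   * the bucket, the way the iterator skips iter->offset matching
+   * sockets before batching the rest.
+   */
+  static struct snode *bucket_resume(struct snode *head, int offset)
+  {
+          while (head && offset--)
+                  head = head->next;
+          return head;
+  }
+
+  int main(void)
+  {
+          struct snode d = { "D", NULL }, c = { "C", &d };
+          struct snode b = { "B", &c }, a = { "A", &b };
+          struct snode *bucket = &a;      /* bucket: A->B->C->D */
+          int offset = 1;                 /* first read returned A */
+
+          bucket = &b;            /* close(A): bucket is now B->C->D */
+
+          /* Resuming with offset == 1 skips B, which was never seen;
+           * this prints "resume at C".
+           */
+          printf("resume at %s\n", bucket_resume(bucket, offset)->name);
+          return 0;
+  }
+
+(The real iterators add batching and bucket locking on top of this, but
+the resume step is the part that goes wrong.)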
+
+Repeat A Socket
++------+--------------------+---------------+---------------+
+| Time | Event              | Bucket State  | Bucket Offset |
++------+--------------------+---------------+---------------+
+| 1    | read(iter_fd) -> A | A->B->C->D    | 1             |
+| 2    | connect(E)         | E->A->B->C->D | 1             |
+| 3    | read(iter_fd) -> A | E->A->B->C->D | 2             |
+| 4    | read(iter_fd) -> B | E->A->B->C->D | 3             |
+| 5    | read(iter_fd) -> C | E->A->B->C->D | 4             |
+| 6    | read(iter_fd) -> D | E->A->B->C->D | 5             |
+| 7    | read(iter_fd) -> 0 | E->A->B->C->D | -             |
++------+--------------------+---------------+---------------+
+
+Iteration sees these sockets: [A, A, B, C, D]
+A is repeated.
+
+If we consider corner cases like these, the semantics are neither
+at-most-once, nor at-least-once, nor exactly-once. Repeating a socket
+during iteration is perhaps less problematic than skipping it
+altogether, as long as the BPF program is aware that duplicates are
+possible; however, in an ideal world, we could process each socket
+exactly once. There are some constraints that make this a bit more
+difficult:
+
+1) Despite batch resize attempts inside both bpf_iter_udp_batch() and
+   bpf_iter_tcp_batch(), we have to deal with the possibility that our
+   batch size cannot contain all items in a bucket at once.
+2) We cannot hold a lock on the bucket between iterations, meaning that
+   the structure can change in lots of interesting ways.
+
+PROPOSAL
+========
+Can we achieve exactly-once semantics for socket iterators even in the
+face of concurrent additions to or removals from the current bucket?
+If we ignore the possibility of signed 64-bit rollover, then yes. This
+series replaces the current offset-based scheme used for progress
+tracking with a scheme based on a monotonically increasing version
+number. It works as follows:
+
+* Assign index numbers to sockets in the bucket's linked list such that
+  they are monotonically increasing as you read from head to tail.
+
+  * Every time a socket is added to a bucket, increment the hash
+    table's version number, ver.
+  * If the socket is being added to the head of the bucket's linked
+    list, set sk->idx to -1*ver.
+  * If the socket is being added to the tail of the bucket's linked
+    list, set sk->idx to ver.
+
+  Ex: append_head(C), append_head(B), append_tail(D), append_head(A),
+      append_tail(E) results in the following state.
+
+       A  ->  B  ->  C  ->  D  ->  E
+      -4     -2     -1      3      5
+* As we iterate through a bucket, keep track of the last index number
+  we've seen for that bucket, iter->prev_idx.
+* On subsequent iterations, skip ahead in the bucket until we see a
+  socket whose index, sk->idx, is greater than iter->prev_idx.
+
+Since we always iterate from head to tail and indexes are always
+increasing in that direction, we can be sure that any socket whose
+index is greater than iter->prev_idx has not yet been seen. Any socket
+whose index is less than or equal to iter->prev_idx has either been
+seen before or was added since we last saw that bucket. In either case,
+it is safe to skip it (any new sockets did not exist when we created
+the iterator).
+
+SOME ALTERNATIVES
+=================
+1. One alternative I considered was simply counting the number of
+   removals that have occurred per bucket, remembering this count
+   between calls to bpf_iter_(tcp|udp)_batch() as part of the iterator
+   state, and using it to detect whether any removals occurred since
+   the last batch. If any have, we would need to walk iter->offset back
+   by at least that much to avoid skips. This approach is simpler but
+   may repeat sockets.
+2. 
Don't allow partial batches; always make sure we capture all sockets + in a bucket in a batch. bpf_iter_(tcp|udp)_batch() already have some + logic to try one time to resize the batch, but as far as I know, + this isn't viable since we have to contend with the fact that + bpf_iter_(tcp|udp)_realloc_batch() may not be able to grab more + memory. + +Anyway, maybe everyone already knows this can happen and isn't +overly concerned, since the possibility of skips or repeats is small, +but I thought I'd highlight the possibility just in case. It certainly +seems like something we'd want to avoid if we can help it, and with a +few adjustments, we can. + +-Jordan + +[1]: https://github.com/cilium/cilium/issues/37907 +[2]: https://lore.kernel.org/bpf/20230519225157.760788-1-aditi.ghag@isovalent.com/ +[3]: https://lore.kernel.org/netdev/20240112190530.3751661-1-martin.lau@linux.dev/ +[4]: https://lore.kernel.org/netdev/20240112190530.3751661-2-martin.lau@linux.dev/ + diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h index 5eea47f135a42..c95d3b1da1990 100644 --- a/include/net/inet_hashtables.h +++ b/include/net/inet_hashtables.h @@ -172,6 +172,8 @@ struct inet_hashinfo { struct inet_listen_hashbucket *lhash2; bool pernet; + + atomic64_t ver; } ____cacheline_aligned_in_smp; static inline struct inet_hashinfo *tcp_or_dccp_get_hashinfo(const struct sock *sk) diff --git a/include/net/sock.h b/include/net/sock.h index 8036b3b79cd8b..b11f43e8e7ec7 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -228,6 +228,7 @@ struct sock_common { u32 skc_window_clamp; u32 skc_tw_snd_nxt; /* struct tcp_timewait_sock */ }; + __s64 skc_idx; /* public: */ }; @@ -378,6 +379,7 @@ struct sock { #define sk_incoming_cpu __sk_common.skc_incoming_cpu #define sk_flags __sk_common.skc_flags #define sk_rxhash __sk_common.skc_rxhash +#define sk_idx __sk_common.skc_idx __cacheline_group_begin(sock_write_rx); diff --git a/include/net/tcp.h b/include/net/tcp.h index 2d08473a6dc00..499acd6da35f8 100644 --- a/include/net/tcp.h +++ b/include/net/tcp.h @@ -2202,7 +2202,8 @@ struct tcp_iter_state { struct seq_net_private p; enum tcp_seq_states state; struct sock *syn_wait_sk; - int bucket, offset, sbucket, num; + int bucket, sbucket, num; + __s64 prev_idx; loff_t last_pos; }; diff --git a/include/net/udp.h b/include/net/udp.h index 6e89520e100dc..9398561addc68 100644 --- a/include/net/udp.h +++ b/include/net/udp.h @@ -102,6 +102,7 @@ struct udp_table { #endif unsigned int mask; unsigned int log; + atomic64_t ver; }; extern struct udp_table udp_table; void udp_table_init(struct udp_table *, const char *); diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c index 9bfcfd016e182..bc9f58172790f 100644 --- a/net/ipv4/inet_hashtables.c +++ b/net/ipv4/inet_hashtables.c @@ -534,6 +534,12 @@ struct sock *__inet_lookup_established(const struct net *net, } EXPORT_SYMBOL_GPL(__inet_lookup_established); +static inline __s64 inet_hashinfo_next_idx(struct inet_hashinfo *hinfo, + bool pos) +{ + return (pos ? 
1 : -1) * atomic64_inc_return(&hinfo->ver); +} + /* called with local bh disabled */ static int __inet_check_established(struct inet_timewait_death_row *death_row, struct sock *sk, __u16 lport, @@ -581,6 +587,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row, sk->sk_hash = hash; WARN_ON(!sk_unhashed(sk)); __sk_nulls_add_node_rcu(sk, &head->chain); + sk->sk_idx = inet_hashinfo_next_idx(hinfo, false); if (tw) { sk_nulls_del_node_init_rcu((struct sock *)tw); __NET_INC_STATS(net, LINUX_MIB_TIMEWAITRECYCLED); @@ -678,8 +685,10 @@ bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) ret = false; } - if (ret) + if (ret) { __sk_nulls_add_node_rcu(sk, list); + sk->sk_idx = inet_hashinfo_next_idx(hashinfo, false); + } spin_unlock(lock); @@ -729,6 +738,7 @@ int __inet_hash(struct sock *sk, struct sock *osk) { struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); struct inet_listen_hashbucket *ilb2; + bool add_tail; int err = 0; if (sk->sk_state != TCP_LISTEN) { @@ -747,11 +757,13 @@ int __inet_hash(struct sock *sk, struct sock *osk) goto unlock; } sock_set_flag(sk, SOCK_RCU_FREE); - if (IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport && - sk->sk_family == AF_INET6) + add_tail = IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport && + sk->sk_family == AF_INET6; + if (add_tail) __sk_nulls_add_node_tail_rcu(sk, &ilb2->nulls_head); else __sk_nulls_add_node_rcu(sk, &ilb2->nulls_head); + sk->sk_idx = inet_hashinfo_next_idx(hashinfo, add_tail); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); unlock: spin_unlock(&ilb2->lock); diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 285678d8ce077..63693af0c05c8 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -5147,6 +5147,7 @@ void __init tcp_init(void) cnt = tcp_hashinfo.ehash_mask + 1; sysctl_tcp_max_orphans = cnt / 2; + atomic64_set(&tcp_hashinfo.ver, 0); tcp_init_mem(); /* Set per-socket limits to no more than 1/128 the pressure threshold */ diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index 2632844d2c356..27d124266c7b0 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -2602,7 +2602,7 @@ static void *listening_get_first(struct seq_file *seq) struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct tcp_iter_state *st = seq->private; - st->offset = 0; + st->prev_idx = 0; for (; st->bucket <= hinfo->lhash2_mask; st->bucket++) { struct inet_listen_hashbucket *ilb2; struct hlist_nulls_node *node; @@ -2637,7 +2637,7 @@ static void *listening_get_next(struct seq_file *seq, void *cur) struct sock *sk = cur; ++st->num; - ++st->offset; + st->prev_idx = sk->sk_idx; sk = sk_nulls_next(sk); sk_nulls_for_each_from(sk, node) { @@ -2658,7 +2658,6 @@ static void *listening_get_idx(struct seq_file *seq, loff_t *pos) void *rc; st->bucket = 0; - st->offset = 0; rc = listening_get_first(seq); while (rc && *pos) { @@ -2683,7 +2682,7 @@ static void *established_get_first(struct seq_file *seq) struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct tcp_iter_state *st = seq->private; - st->offset = 0; + st->prev_idx = 0; for (; st->bucket <= hinfo->ehash_mask; ++st->bucket) { struct sock *sk; struct hlist_nulls_node *node; @@ -2714,7 +2713,6 @@ static void *established_get_next(struct seq_file *seq, void *cur) struct sock *sk = cur; ++st->num; - ++st->offset; sk = sk_nulls_next(sk); @@ -2763,8 +2761,8 @@ static void *tcp_seek_last_pos(struct seq_file *seq) { struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct 
tcp_iter_state *st = seq->private; + __s64 prev_idx = st->prev_idx; int bucket = st->bucket; - int offset = st->offset; int orig_num = st->num; void *rc = NULL; @@ -2773,18 +2771,21 @@ static void *tcp_seek_last_pos(struct seq_file *seq) if (st->bucket > hinfo->lhash2_mask) break; rc = listening_get_first(seq); - while (offset-- && rc && bucket == st->bucket) + while (rc && bucket == st->bucket && prev_idx && + ((struct sock *)rc)->sk_idx <= prev_idx) rc = listening_get_next(seq, rc); if (rc) break; st->bucket = 0; + prev_idx = 0; st->state = TCP_SEQ_STATE_ESTABLISHED; fallthrough; case TCP_SEQ_STATE_ESTABLISHED: if (st->bucket > hinfo->ehash_mask) break; rc = established_get_first(seq); - while (offset-- && rc && bucket == st->bucket) + while (rc && bucket == st->bucket && prev_idx && + ((struct sock *)rc)->sk_idx <= prev_idx) rc = established_get_next(seq, rc); } @@ -2807,7 +2808,7 @@ void *tcp_seq_start(struct seq_file *seq, loff_t *pos) st->state = TCP_SEQ_STATE_LISTENING; st->num = 0; st->bucket = 0; - st->offset = 0; + st->prev_idx = 0; rc = *pos ? tcp_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; out: @@ -2832,7 +2833,7 @@ void *tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos) if (!rc) { st->state = TCP_SEQ_STATE_ESTABLISHED; st->bucket = 0; - st->offset = 0; + st->prev_idx = 0; rc = established_get_first(seq); } break; @@ -3124,7 +3125,7 @@ static struct sock *bpf_iter_tcp_batch(struct seq_file *seq) * it has to advance to the next bucket. */ if (iter->st_bucket_done) { - st->offset = 0; + st->prev_idx = 0; st->bucket++; if (st->state == TCP_SEQ_STATE_LISTENING && st->bucket > hinfo->lhash2_mask) { @@ -3192,8 +3193,9 @@ static void *bpf_iter_tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos) * the future start() will resume at st->offset in * st->bucket. See tcp_seek_last_pos(). */ - st->offset++; - sock_gen_put(iter->batch[iter->cur_sk++]); + sk = iter->batch[iter->cur_sk++]; + st->prev_idx = sk->sk_idx; + sock_gen_put(sk); } if (iter->cur_sk < iter->end_sk) diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index a9bb9ce5438ea..d7e9b33469838 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -229,6 +229,11 @@ static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot) return reuseport_alloc(sk, inet_rcv_saddr_any(sk)); } +static inline __s64 udp_table_next_idx(struct udp_table *udptable, bool pos) +{ + return (pos ? 
1 : -1) * atomic64_inc_return(&udptable->ver); +} + /** * udp_lib_get_port - UDP/-Lite port lookup for IPv4 and IPv6 * @@ -244,6 +249,7 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum, struct udp_hslot *hslot, *hslot2; struct net *net = sock_net(sk); int error = -EADDRINUSE; + bool add_tail; if (!snum) { DECLARE_BITMAP(bitmap, PORTS_PER_CHAIN); @@ -335,14 +341,16 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum, hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash); spin_lock(&hslot2->lock); - if (IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport && - sk->sk_family == AF_INET6) + add_tail = IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport && + sk->sk_family == AF_INET6; + if (add_tail) hlist_add_tail_rcu(&udp_sk(sk)->udp_portaddr_node, &hslot2->head); else hlist_add_head_rcu(&udp_sk(sk)->udp_portaddr_node, &hslot2->head); hslot2->count++; + sk->sk_idx = udp_table_next_idx(udptable, add_tail); spin_unlock(&hslot2->lock); } @@ -2250,6 +2258,8 @@ void udp_lib_rehash(struct sock *sk, u16 newhash, u16 newhash4) hlist_add_head_rcu(&udp_sk(sk)->udp_portaddr_node, &nhslot2->head); nhslot2->count++; + sk->sk_idx = udp_table_next_idx(udptable, + false); spin_unlock(&nhslot2->lock); } @@ -3390,9 +3400,9 @@ struct bpf_udp_iter_state { unsigned int cur_sk; unsigned int end_sk; unsigned int max_sk; - int offset; struct sock **batch; bool st_bucket_done; + __s64 prev_idx; }; static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter, @@ -3402,14 +3412,13 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq) struct bpf_udp_iter_state *iter = seq->private; struct udp_iter_state *state = &iter->state; struct net *net = seq_file_net(seq); - int resume_bucket, resume_offset; struct udp_table *udptable; unsigned int batch_sks = 0; bool resized = false; + int resume_bucket; struct sock *sk; resume_bucket = state->bucket; - resume_offset = iter->offset; /* The current batch is done, so advance the bucket. */ if (iter->st_bucket_done) @@ -3436,18 +3445,19 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq) if (hlist_empty(&hslot2->head)) continue; - iter->offset = 0; spin_lock_bh(&hslot2->lock); + /* Reset prev_idx if this is a new bucket. */ + if (!resume_bucket || state->bucket != resume_bucket) + iter->prev_idx = 0; udp_portaddr_for_each_entry(sk, &hslot2->head) { if (seq_sk_match(seq, sk)) { - /* Resume from the last iterated socket at the - * offset in the bucket before iterator was stopped. + /* Resume from the first socket that we didn't + * see last time around. */ if (state->bucket == resume_bucket && - iter->offset < resume_offset) { - ++iter->offset; + iter->prev_idx && + sk->sk_idx <= iter->prev_idx) continue; - } if (iter->end_sk < iter->max_sk) { sock_hold(sk); iter->batch[iter->end_sk++] = sk; @@ -3492,8 +3502,9 @@ static void *bpf_iter_udp_seq_next(struct seq_file *seq, void *v, loff_t *pos) * done with seq_show(), so unref the iter->cur_sk. 
*/ if (iter->cur_sk < iter->end_sk) { - sock_put(iter->batch[iter->cur_sk++]); - ++iter->offset; + sk = iter->batch[iter->cur_sk++]; + iter->prev_idx = sk->sk_idx; + sock_put(sk); } /* After updating iter->cur_sk, check if there are more sockets @@ -3740,6 +3751,7 @@ static struct udp_table __net_init *udp_pernet_table_alloc(unsigned int hash_ent udptable->hash2 = (void *)(udptable->hash + hash_entries); udptable->mask = hash_entries - 1; udptable->log = ilog2(hash_entries); + atomic64_set(&udptable->ver, 0); for (i = 0; i < hash_entries; i++) { INIT_HLIST_HEAD(&udptable->hash[i].head); diff --git a/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c b/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c index d56e18b255280..414c623f1fa06 100644 --- a/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c +++ b/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c @@ -6,15 +6,275 @@ #include "sock_iter_batch.skel.h" #define TEST_NS "sock_iter_batch_netns" +#define nr_soreuse 4 -static const int nr_soreuse = 4; +static const __u16 reuse_port = 10001; + +struct iter_out { + int idx; + __u64 cookie; +} __attribute__((__packed__)); + +struct sock_count { + __u64 cookie; + int count; +}; + +static int insert(__u64 cookie, struct sock_count counts[], int counts_len) +{ + int insert = -1; + int i = 0; + + for (; i < counts_len; i++) { + if (!counts[i].cookie) { + insert = i; + } else if (counts[i].cookie == cookie) { + insert = i; + break; + } + } + if (insert < 0) + return insert; + + counts[insert].cookie = cookie; + counts[insert].count++; + + return counts[insert].count; +} + +static int read_n(int iter_fd, int n, struct sock_count counts[], + int counts_len) +{ + struct iter_out out; + int nread = 1; + int i = 0; + + for (; nread > 0 && (n < 0 || i < n); i++) { + nread = read(iter_fd, &out, sizeof(out)); + if (!nread || !ASSERT_GE(nread, 1, "nread")) + break; + ASSERT_GE(insert(out.cookie, counts, counts_len), 0, "insert"); + } + + ASSERT_TRUE(n < 0 || i == n, "n < 0 || i == n"); + + return i; +} + +static __u64 socket_cookie(int fd) +{ + __u64 cookie; + socklen_t cookie_len = sizeof(cookie); + static __u32 duration; /* for CHECK macro */ + + if (CHECK(getsockopt(fd, SOL_SOCKET, SO_COOKIE, &cookie, &cookie_len) < 0, + "getsockopt(SO_COOKIE)", "%s\n", strerror(errno))) + return 0; + return cookie; +} + +static bool was_seen(int fd, struct sock_count counts[], int counts_len) +{ + __u64 cookie = socket_cookie(fd); + int i = 0; + + for (; cookie && i < counts_len; i++) + if (cookie == counts[i].cookie) + return true; + + return false; +} + +static int get_seen_socket(int *fds, struct sock_count counts[], int n) +{ + int i = 0; + + for (; i < n; i++) + if (was_seen(fds[i], counts, n)) + return i; + return -1; +} + +static int get_seen_count(int fd, struct sock_count counts[], int n) +{ + __u64 cookie = socket_cookie(fd); + int count = 0; + int i = 0; + + for (; cookie && !count && i < n; i++) + if (cookie == counts[i].cookie) + count = counts[i].count; + + return count; +} + +static void check_n_were_seen_once(int *fds, int fds_len, int n, + struct sock_count counts[], int counts_len) +{ + int seen_once = 0; + int seen_cnt; + int i = 0; + + for (; i < fds_len; i++) { + /* Skip any sockets that were closed or that weren't seen + * exactly once. 
+ */ + if (fds[i] < 0) + continue; + seen_cnt = get_seen_count(fds[i], counts, counts_len); + if (seen_cnt && ASSERT_EQ(seen_cnt, 1, "seen_cnt")) + seen_once++; + } + + ASSERT_EQ(seen_once, n, "seen_once"); +} + +static void do_skip_test(int sock_type) +{ + struct sock_count counts[nr_soreuse] = {}; + struct bpf_link *link = NULL; + struct sock_iter_batch *skel; + int err, iter_fd = -1; + int close_idx; + int *fds; + + skel = sock_iter_batch__open(); + if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open")) + return; + + /* Prepare a bucket of sockets in the kernel hashtable */ + int local_port; + + fds = start_reuseport_server(AF_INET, sock_type, "127.0.0.1", 0, 0, + nr_soreuse); + if (!ASSERT_OK_PTR(fds, "start_reuseport_server")) + goto done; + local_port = get_socket_local_port(*fds); + if (!ASSERT_GE(local_port, 0, "get_socket_local_port")) + goto done; + skel->rodata->ports[0] = ntohs(local_port); + skel->rodata->sf = AF_INET; + + err = sock_iter_batch__load(skel); + if (!ASSERT_OK(err, "sock_iter_batch__load")) + goto done; + + link = bpf_program__attach_iter(sock_type == SOCK_STREAM ? + skel->progs.iter_tcp_soreuse : + skel->progs.iter_udp_soreuse, + NULL); + if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter")) + goto done; + + iter_fd = bpf_iter_create(bpf_link__fd(link)); + if (!ASSERT_GE(iter_fd, 0, "bpf_iter_create")) + goto done; + + /* Iterate through the first three sockets. */ + read_n(iter_fd, nr_soreuse - 1, counts, nr_soreuse); + + /* Make sure we saw three sockets from fds exactly once. */ + check_n_were_seen_once(fds, nr_soreuse, nr_soreuse - 1, counts, + nr_soreuse); + + /* Close a socket we've already seen to remove it from the bucket. */ + close_idx = get_seen_socket(fds, counts, nr_soreuse); + if (!ASSERT_GE(close_idx, 0, "close_idx")) + goto done; + close(fds[close_idx]); + fds[close_idx] = -1; + + /* Iterate through the rest of the sockets. */ + read_n(iter_fd, -1, counts, nr_soreuse); + + /* Make sure the last socket wasn't skipped and that there were no + * repeats. + */ + check_n_were_seen_once(fds, nr_soreuse, nr_soreuse - 1, counts, + nr_soreuse); +done: + free_fds(fds, nr_soreuse); + if (iter_fd < 0) + close(iter_fd); + bpf_link__destroy(link); + sock_iter_batch__destroy(skel); +} + +static void do_repeat_test(int sock_type) +{ + struct sock_count counts[nr_soreuse] = {}; + struct bpf_link *link = NULL; + struct sock_iter_batch *skel; + int err, i, iter_fd = -1; + int *fds[2] = {}; + + skel = sock_iter_batch__open(); + if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open")) + return; + + /* Prepare a bucket of sockets in the kernel hashtable */ + int local_port; + + fds[0] = start_reuseport_server(AF_INET, sock_type, "127.0.0.1", + reuse_port, 0, nr_soreuse); + if (!ASSERT_OK_PTR(fds[0], "start_reuseport_server")) + goto done; + local_port = get_socket_local_port(*fds[0]); + if (!ASSERT_GE(local_port, 0, "get_socket_local_port")) + goto done; + skel->rodata->ports[0] = ntohs(local_port); + skel->rodata->sf = AF_INET; + + err = sock_iter_batch__load(skel); + if (!ASSERT_OK(err, "sock_iter_batch__load")) + goto done; + + link = bpf_program__attach_iter(sock_type == SOCK_STREAM ? 
+ skel->progs.iter_tcp_soreuse : + skel->progs.iter_udp_soreuse, + NULL); + if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter")) + goto done; + + iter_fd = bpf_iter_create(bpf_link__fd(link)); + if (!ASSERT_GE(iter_fd, 0, "bpf_iter_create")) + goto done; + + /* Iterate through the first three sockets */ + read_n(iter_fd, nr_soreuse - 1, counts, nr_soreuse); + + /* Make sure we saw three sockets from fds exactly once. */ + check_n_were_seen_once(fds[0], nr_soreuse, nr_soreuse - 1, counts, + nr_soreuse); + + /* Add nr_soreuse more sockets to the bucket. */ + fds[1] = start_reuseport_server(AF_INET, sock_type, "127.0.0.1", + reuse_port, 0, nr_soreuse); + if (!ASSERT_OK_PTR(fds[1], "start_reuseport_server")) + goto done; + + /* Iterate through the rest of the sockets. */ + read_n(iter_fd, -1, counts, nr_soreuse); + + /* Make sure each socket from the first set was seen exactly once. */ + check_n_were_seen_once(fds[0], nr_soreuse, nr_soreuse, counts, + nr_soreuse); +done: + for (i = 0; i < ARRAY_SIZE(fds); i++) + free_fds(fds[i], nr_soreuse); + if (iter_fd < 0) + close(iter_fd); + bpf_link__destroy(link); + sock_iter_batch__destroy(skel); +} static void do_test(int sock_type, bool onebyone) { int err, i, nread, to_read, total_read, iter_fd = -1; - int first_idx, second_idx, indices[nr_soreuse]; + struct iter_out outputs[nr_soreuse]; struct bpf_link *link = NULL; struct sock_iter_batch *skel; + int first_idx, second_idx; int *fds[2] = {}; skel = sock_iter_batch__open(); @@ -34,6 +294,7 @@ static void do_test(int sock_type, bool onebyone) goto done; skel->rodata->ports[i] = ntohs(local_port); } + skel->rodata->sf = AF_INET6; err = sock_iter_batch__load(skel); if (!ASSERT_OK(err, "sock_iter_batch__load")) @@ -55,38 +316,38 @@ static void do_test(int sock_type, bool onebyone) * from a bucket and leave one socket out from * that bucket on purpose. */ - to_read = (nr_soreuse - 1) * sizeof(*indices); + to_read = (nr_soreuse - 1) * sizeof(*outputs); total_read = 0; first_idx = -1; do { - nread = read(iter_fd, indices, onebyone ? sizeof(*indices) : to_read); - if (nread <= 0 || nread % sizeof(*indices)) + nread = read(iter_fd, outputs, onebyone ? sizeof(*outputs) : to_read); + if (nread <= 0 || nread % sizeof(*outputs)) break; total_read += nread; if (first_idx == -1) - first_idx = indices[0]; - for (i = 0; i < nread / sizeof(*indices); i++) - ASSERT_EQ(indices[i], first_idx, "first_idx"); + first_idx = outputs[0].idx; + for (i = 0; i < nread / sizeof(*outputs); i++) + ASSERT_EQ(outputs[i].idx, first_idx, "first_idx"); } while (total_read < to_read); - ASSERT_EQ(nread, onebyone ? sizeof(*indices) : to_read, "nread"); + ASSERT_EQ(nread, onebyone ? sizeof(*outputs) : to_read, "nread"); ASSERT_EQ(total_read, to_read, "total_read"); free_fds(fds[first_idx], nr_soreuse); fds[first_idx] = NULL; /* Read the "whole" second bucket */ - to_read = nr_soreuse * sizeof(*indices); + to_read = nr_soreuse * sizeof(*outputs); total_read = 0; second_idx = !first_idx; do { - nread = read(iter_fd, indices, onebyone ? sizeof(*indices) : to_read); - if (nread <= 0 || nread % sizeof(*indices)) + nread = read(iter_fd, outputs, onebyone ? 
sizeof(*outputs) : to_read); + if (nread <= 0 || nread % sizeof(*outputs)) break; total_read += nread; - for (i = 0; i < nread / sizeof(*indices); i++) - ASSERT_EQ(indices[i], second_idx, "second_idx"); + for (i = 0; i < nread / sizeof(*outputs); i++) + ASSERT_EQ(outputs[i].idx, second_idx, "second_idx"); } while (total_read <= to_read); ASSERT_EQ(nread, 0, "nread"); /* Both so_reuseport ports should be in different buckets, so @@ -123,10 +384,14 @@ void test_sock_iter_batch(void) if (test__start_subtest("tcp")) { do_test(SOCK_STREAM, true); do_test(SOCK_STREAM, false); + do_skip_test(SOCK_STREAM); + do_repeat_test(SOCK_STREAM); } if (test__start_subtest("udp")) { do_test(SOCK_DGRAM, true); do_test(SOCK_DGRAM, false); + do_skip_test(SOCK_DGRAM); + do_repeat_test(SOCK_DGRAM); } close_netns(nstoken); diff --git a/tools/testing/selftests/bpf/progs/bpf_tracing_net.h b/tools/testing/selftests/bpf/progs/bpf_tracing_net.h index 59843b430f76a..82928cc5d87b7 100644 --- a/tools/testing/selftests/bpf/progs/bpf_tracing_net.h +++ b/tools/testing/selftests/bpf/progs/bpf_tracing_net.h @@ -123,6 +123,7 @@ #define sk_refcnt __sk_common.skc_refcnt #define sk_state __sk_common.skc_state #define sk_net __sk_common.skc_net +#define sk_rcv_saddr __sk_common.skc_rcv_saddr #define sk_v6_daddr __sk_common.skc_v6_daddr #define sk_v6_rcv_saddr __sk_common.skc_v6_rcv_saddr #define sk_flags __sk_common.skc_flags diff --git a/tools/testing/selftests/bpf/progs/sock_iter_batch.c b/tools/testing/selftests/bpf/progs/sock_iter_batch.c index 96531b0d9d55b..8f483337e103c 100644 --- a/tools/testing/selftests/bpf/progs/sock_iter_batch.c +++ b/tools/testing/selftests/bpf/progs/sock_iter_batch.c @@ -17,6 +17,12 @@ static bool ipv6_addr_loopback(const struct in6_addr *a) a->s6_addr32[2] | (a->s6_addr32[3] ^ bpf_htonl(1))) == 0; } +static bool ipv4_addr_loopback(__be32 a) +{ + return a == bpf_ntohl(0x7f000001); +} + +volatile const unsigned int sf; volatile const __u16 ports[2]; unsigned int bucket[2]; @@ -26,16 +32,20 @@ int iter_tcp_soreuse(struct bpf_iter__tcp *ctx) struct sock *sk = (struct sock *)ctx->sk_common; struct inet_hashinfo *hinfo; unsigned int hash; + __u64 sock_cookie; struct net *net; int idx; if (!sk) return 0; + sock_cookie = bpf_get_socket_cookie(sk); sk = bpf_core_cast(sk, struct sock); - if (sk->sk_family != AF_INET6 || + if (sk->sk_family != sf || sk->sk_state != TCP_LISTEN || - !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr)) + sk->sk_family == AF_INET6 ? + !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr) : + !ipv4_addr_loopback(sk->sk_rcv_saddr)) return 0; if (sk->sk_num == ports[0]) @@ -52,6 +62,7 @@ int iter_tcp_soreuse(struct bpf_iter__tcp *ctx) hinfo = net->ipv4.tcp_death_row.hashinfo; bucket[idx] = hash & hinfo->lhash2_mask; bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx)); + bpf_seq_write(ctx->meta->seq, &sock_cookie, sizeof(sock_cookie)); return 0; } @@ -63,14 +74,18 @@ int iter_udp_soreuse(struct bpf_iter__udp *ctx) { struct sock *sk = (struct sock *)ctx->udp_sk; struct udp_table *udptable; + __u64 sock_cookie; int idx; if (!sk) return 0; + sock_cookie = bpf_get_socket_cookie(sk); sk = bpf_core_cast(sk, struct sock); - if (sk->sk_family != AF_INET6 || - !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr)) + if (sk->sk_family != sf || + sk->sk_family == AF_INET6 ? 
+ !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr) : + !ipv4_addr_loopback(sk->sk_rcv_saddr)) return 0; if (sk->sk_num == ports[0]) @@ -84,6 +99,7 @@ int iter_udp_soreuse(struct bpf_iter__udp *ctx) udptable = sk->sk_net.net->ipv4.udp_table; bucket[idx] = udp_sk(sk)->udp_portaddr_hash & udptable->mask; bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx)); + bpf_seq_write(ctx->meta->seq, &sock_cookie, sizeof(sock_cookie)); return 0; }