Test PR To Trigger CI #8667

Closed
160 changes: 160 additions & 0 deletions cover-letter
@@ -0,0 +1,160 @@
I was recently looking into using BPF socket iterators in conjunction
with the bpf_sock_destroy() kfunc as a means to forcefully destroy a
set of UDP sockets connected to a deleted backend [1]. The intent is to
use BPF iterators + kfuncs in lieu of INET_DIAG infrastructure to
destroy sockets in order to simplify Cilium's system requirements. Aditi
describes the scenario in [2], the patch series that introduced
bpf_sock_destroy() for this very purpose:

> This patch set adds the capability to destroy sockets in BPF. We plan
> to use the capability in Cilium to force client sockets to reconnect
> when their remote load-balancing backends are deleted. The other use
> case is on-the-fly policy enforcement where existing socket
> connections prevented by policies need to be terminated.
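
For context, a BPF program driving such an iterator might look roughly
like the sketch below, modeled loosely on the bpf_sock_destroy()
selftests; the program name, the hard-coded backend port, and the
filtering logic are illustrative assumptions, not code from this
series.

/* Sketch only: destroy UDP sockets connected to a deleted backend.
 * The backend is identified here by a hard-coded destination port; a
 * real program would match the backend's address and port, likely
 * supplied via a map.
 */
#include "vmlinux.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>

extern int bpf_sock_destroy(struct sock_common *sk) __ksym;

char _license[] SEC("license") = "GPL";

SEC("iter/udp")
int destroy_backend_udp(struct bpf_iter__udp *ctx)
{
	struct udp_sock *udp_sk = ctx->udp_sk;
	struct sock *sk;

	if (!udp_sk)
		return 0;

	sk = (struct sock *)udp_sk;
	/* Destroy sockets whose peer is the deleted backend. */
	if (sk->__sk_common.skc_dport == bpf_htons(8080))
		bpf_sock_destroy((struct sock_common *)sk);

	return 0;
}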

One would want and expect an iterator to visit every socket that existed
before the iterator was created, if not exactly once, then at least
once; otherwise, we could accidentally skip a socket that we intended to
destroy. With the iterator implementation as it exists today, this is
the behavior you would observe in the vast majority of cases.

However, in the process of reviewing [2] and some follow up fixes to
bpf_iter_udp_batch() ([3] [4]) by Martin, it occurred to me that there
are situations where BPF socket iterators may repeat, or worse, skip
sockets altogether even if they existed prior to iterator creation,
making BPF iterators, as a mechanism to achieve the goal stated above,
slightly buggy.

This RFC highlights some of these scenarios, extending
prog_tests/sock_iter_batch.c to illustrate conditions under which
sockets can be skipped or repeated, and proposes a solution for
achieving exactly-once semantics for socket iterators in all cases
with respect to sockets that existed prior to the start of iteration.

I'm hoping to raise awareness of this issue generally if
it's not already common knowledge and get some feedback on the viability
of the proposed improvement.

THE PROBLEM
===========
Both UDP and TCP socket iterators use iter->offset to track progress
through a bucket, which is a measure of the number of matching sockets
from the current bucket that have been seen or processed by the
iterator. On subsequent iterations, if the current bucket has
unprocessed items, we skip at least iter->offset matching items in the
bucket before adding any remaining items to the next batch. The intent
seems to be to skip any items we've already seen, but iter->offset
isn't always an accurate measure of "things already seen". There are a
variety of scenarios where the underlying bucket changes between reads,
leading to either repeated or skipped sockets. Two such scenarios are
illustrated below and reproduced by the selftests.

Skip A Socket
+------+--------------------+--------------+---------------+
| Time | Event              | Bucket State | Bucket Offset |
+------+--------------------+--------------+---------------+
| 1    | read(iter_fd) -> A | A->B->C->D   | 1             |
| 2    | close(A)           | B->C->D      | 1             |
| 3    | read(iter_fd) -> C | B->C->D      | 2             |
| 4    | read(iter_fd) -> D | B->C->D      | 3             |
| 5    | read(iter_fd) -> 0 | B->C->D      | -             |
+------+--------------------+--------------+---------------+

Iteration sees these sockets: [A, C, D]
B is skipped.

Repeat A Socket
+------+--------------------+---------------+---------------+
| Time | Event              | Bucket State  | Bucket Offset |
+------+--------------------+---------------+---------------+
| 1    | read(iter_fd) -> A | A->B->C->D    | 1             |
| 2    | connect(E)         | E->A->B->C->D | 1             |
| 3    | read(iter_fd) -> A | E->A->B->C->D | 2             |
| 4    | read(iter_fd) -> B | E->A->B->C->D | 3             |
| 5    | read(iter_fd) -> C | E->A->B->C->D | 4             |
| 6    | read(iter_fd) -> D | E->A->B->C->D | 5             |
| 7    | read(iter_fd) -> 0 | E->A->B->C->D | -             |
+------+--------------------+---------------+---------------+

Iteration sees these sockets: [A, A, B, C, D]
A is repeated.
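
Both scenarios fall out of the same resume logic. As a concrete
illustration, here is a minimal user-space simulation of the
offset-based scheme (the types and helpers below are invented for this
sketch and are not the kernel code); it replays the "Skip A Socket"
table and never visits B:

#include <stdio.h>

struct sim_sock {
	char name;
	struct sim_sock *next;
};

/* Resume in the bucket by skipping "offset" sockets from the head,
 * the way iter->offset is used today.
 */
static struct sim_sock *resume(struct sim_sock *head, int offset)
{
	while (head && offset--)
		head = head->next;
	return head;
}

int main(void)
{
	struct sim_sock d = { 'D', NULL }, c = { 'C', &d };
	struct sim_sock b = { 'B', &c }, a = { 'A', &b };
	struct sim_sock *bucket = &a;
	struct sim_sock *sk;
	int offset = 0;

	/* Time 1: the first read() visits A; offset becomes 1. */
	sk = resume(bucket, offset);
	printf("visit %c\n", sk->name);
	offset++;

	/* Time 2: A is closed and unlinked from the bucket. */
	bucket = &b;

	/* Times 3-5: later reads resume at offset 1, 2, ... so B, which
	 * now sits at offset 0, is silently skipped.
	 */
	while ((sk = resume(bucket, offset))) {
		printf("visit %c\n", sk->name);
		offset++;
	}
	return 0;	/* prints A, C, D -- B was never visited */
}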

If we consider corner cases like these, the semantics are neither
at-most-once, nor at-least-once, nor exactly-once. Repeating a socket
during iteration is perhaps less problematic than skipping it
altogether as long as the BPF program is aware that duplicates are
possible; however, in an ideal world, we could process each socket
exactly once. There are some constraints that make this a bit more
difficult:

1) Despite batch resize attempts inside both bpf_iter_udp_batch() and
bpf_iter_tcp_batch(), we have to deal with the possibility that our
batch size cannot contain all items in a bucket at once.
2) We cannot hold a lock on the bucket between iterations, meaning that
the structure can change in lots of interesting ways.

PROPOSAL
========
Can we achieve exactly-once semantics for socket iterators even in the
face of concurrent additions to or removals from the current bucket? If
we ignore the possibility of signed 64-bit rollover, then yes. This
series replaces the current offset-based scheme used for progress
tracking with a scheme based on a monotonically increasing version
number. It works as follows:

* Assign index numbers to sockets in the bucket's linked list such that
they are monotonically increasing as you read from the head to tail.

* Every time a socket is added to a bucket, increment the hash
table's version number, ver.
* If the socket is being added to the head of the bucket's linked
list, set sk->idx to -1*ver.
* If the socket is being added to the tail of the bucket's linked
list, set sk->idx to ver.

Ex: append_head(C), append_head(B), append_tail(D), append_head(A),
append_tail(E) results in the following state.

A -> B -> C -> D -> E
-4   -2   -1   3    5
* As we iterate through a bucket, keep track of the last index number
we've seen for that bucket, iter->prev_idx.
* On subsequent iterations, skip ahead in the bucket until we see a
socket whose index, sk->idx, is greater than iter->prev_idx.

Since we always iterate from head to tail and indexes are always
increasing in that direction, we can be sure that any socket whose index
is greater than iter->prev_idx has not yet been seen. Any socket whose
index is less than or equal to iter->prev_idx has either been seen
before or was added since we last saw that bucket. In either case, it's
safe to skip them (any new sockets did not exist when we created the
iterator).
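
To illustrate, here is the same simulation reworked to use index-based
resume (again, the types, the helpers, and the INT64_MIN sentinel are
inventions of this sketch, not the kernel patch); closing A between
reads no longer causes B to be skipped:

#include <stdio.h>
#include <stdint.h>

struct sim_sock {
	char name;
	int64_t idx;
	struct sim_sock *next;
};

static int64_t ver;	/* per-hash-table version counter */

static void add_head(struct sim_sock **bucket, struct sim_sock *sk)
{
	sk->idx = -(++ver);	/* head inserts get a negative index */
	sk->next = *bucket;
	*bucket = sk;
}

/* Resume after prev_idx: indexes increase from head to tail, so any
 * socket with idx <= prev_idx was either seen already or added after
 * iteration started, and can be skipped safely.
 */
static struct sim_sock *resume(struct sim_sock *head, int64_t prev_idx)
{
	while (head && head->idx <= prev_idx)
		head = head->next;
	return head;
}

int main(void)
{
	struct sim_sock a = { 'A' }, b = { 'B' }, c = { 'C' }, d = { 'D' };
	struct sim_sock *bucket = NULL;
	int64_t prev_idx = INT64_MIN;
	struct sim_sock *sk;

	/* Build A->B->C->D; indexes are -4, -3, -2, -1 from head to tail. */
	add_head(&bucket, &d);
	add_head(&bucket, &c);
	add_head(&bucket, &b);
	add_head(&bucket, &a);

	/* The first read() visits A and remembers its index. */
	sk = resume(bucket, prev_idx);
	printf("visit %c\n", sk->name);
	prev_idx = sk->idx;

	/* A is closed and unlinked; B's index is still > prev_idx. */
	bucket = &b;

	while ((sk = resume(bucket, prev_idx))) {
		printf("visit %c\n", sk->name);
		prev_idx = sk->idx;
	}
	return 0;	/* prints A, B, C, D -- nothing is skipped */
}

The sign trick is what makes a single counter work for both insertion
points: head inserts need indexes smaller than everything already in
the bucket, tail inserts need larger ones, and negating the counter for
head inserts provides both.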

SOME ALTERNATIVES
=================
1. One alternative I considered was simply counting the number of
removals that have occurred per bucket, remembering this between
calls to bpf_iter_(tcp|udp)_batch() as part of the iterator state,
then using it to detect whether the bucket has changed. If any
removals have occurred since the last batch, we would need to walk
back iter->offset by at least that much to avoid skips. This approach
is simpler but may repeat sockets (see the sketch after this list).
2. Don't allow partial batches; always make sure we capture all sockets
in a bucket in a batch. bpf_iter_(tcp|udp)_batch() already have some
logic to try one time to resize the batch, but as far as I know,
this isn't viable since we have to contend with the fact that
bpf_iter_(tcp|udp)_realloc_batch() may not be able to grab more
memory.
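
For comparison, a hypothetical sketch of alternative 1 under the same
simulation assumptions as the earlier examples (the struct and field
names are invented, not taken from this series): walking the offset
back by the number of removals observed since the last batch avoids
skips, but anything still occupying those positions gets visited again,
so sockets may repeat.

#include <stdio.h>
#include <stdint.h>

struct sim_iter {
	int offset;		/* sockets already consumed in this bucket */
	uint64_t removed_seen;	/* bucket removal count at the last batch */
};

static int resume_offset(struct sim_iter *it, uint64_t bucket_removed)
{
	uint64_t removals = bucket_removed - it->removed_seen;
	int off = it->offset - (int)removals;

	it->removed_seen = bucket_removed;
	return off > 0 ? off : 0;
}

int main(void)
{
	struct sim_iter it = { .offset = 1, .removed_seen = 0 };

	/* One socket (A) was consumed, then one removal (close(A))
	 * occurred: resume at offset 0 so B is not skipped, possibly
	 * revisiting a socket if something other than A was removed.
	 */
	printf("resume at offset %d\n", resume_offset(&it, 1));
	return 0;
}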

Anyway, maybe everyone already knows this can happen and isn't
overly concerned, since the possibility of skips or repeats is small,
but I thought I'd highlight the possibility just in case. It certainly
seems like something we'd want to avoid if we can help it, and with a
few adjustments, we can.

-Jordan

[1]: https://github.com/cilium/cilium/issues/37907
[2]: https://lore.kernel.org/bpf/[email protected]/
[3]: https://lore.kernel.org/netdev/[email protected]/
[4]: https://lore.kernel.org/netdev/[email protected]/

2 changes: 2 additions & 0 deletions include/net/inet_hashtables.h
@@ -172,6 +172,8 @@ struct inet_hashinfo {
struct inet_listen_hashbucket *lhash2;

bool pernet;

atomic64_t ver;
} ____cacheline_aligned_in_smp;

static inline struct inet_hashinfo *tcp_or_dccp_get_hashinfo(const struct sock *sk)
2 changes: 2 additions & 0 deletions include/net/sock.h
@@ -228,6 +228,7 @@ struct sock_common {
u32 skc_window_clamp;
u32 skc_tw_snd_nxt; /* struct tcp_timewait_sock */
};
__s64 skc_idx;
/* public: */
};

@@ -378,6 +379,7 @@ struct sock {
#define sk_incoming_cpu __sk_common.skc_incoming_cpu
#define sk_flags __sk_common.skc_flags
#define sk_rxhash __sk_common.skc_rxhash
#define sk_idx __sk_common.skc_idx

__cacheline_group_begin(sock_write_rx);

3 changes: 2 additions & 1 deletion include/net/tcp.h
@@ -2202,7 +2202,8 @@ struct tcp_iter_state {
struct seq_net_private p;
enum tcp_seq_states state;
struct sock *syn_wait_sk;
int bucket, offset, sbucket, num;
int bucket, sbucket, num;
__s64 prev_idx;
loff_t last_pos;
};

1 change: 1 addition & 0 deletions include/net/udp.h
@@ -102,6 +102,7 @@ struct udp_table {
#endif
unsigned int mask;
unsigned int log;
atomic64_t ver;
};
extern struct udp_table udp_table;
void udp_table_init(struct udp_table *, const char *);
18 changes: 15 additions & 3 deletions net/ipv4/inet_hashtables.c
@@ -534,6 +534,12 @@ struct sock *__inet_lookup_established(const struct net *net,
}
EXPORT_SYMBOL_GPL(__inet_lookup_established);

static inline __s64 inet_hashinfo_next_idx(struct inet_hashinfo *hinfo,
bool pos)
{
return (pos ? 1 : -1) * atomic64_inc_return(&hinfo->ver);
}

/* called with local bh disabled */
static int __inet_check_established(struct inet_timewait_death_row *death_row,
struct sock *sk, __u16 lport,
@@ -581,6 +587,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
sk->sk_hash = hash;
WARN_ON(!sk_unhashed(sk));
__sk_nulls_add_node_rcu(sk, &head->chain);
sk->sk_idx = inet_hashinfo_next_idx(hinfo, false);
if (tw) {
sk_nulls_del_node_init_rcu((struct sock *)tw);
__NET_INC_STATS(net, LINUX_MIB_TIMEWAITRECYCLED);
@@ -678,8 +685,10 @@ bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk)
ret = false;
}

if (ret)
if (ret) {
__sk_nulls_add_node_rcu(sk, list);
sk->sk_idx = inet_hashinfo_next_idx(hashinfo, false);
}

spin_unlock(lock);

@@ -729,6 +738,7 @@ int __inet_hash(struct sock *sk, struct sock *osk)
{
struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk);
struct inet_listen_hashbucket *ilb2;
bool add_tail;
int err = 0;

if (sk->sk_state != TCP_LISTEN) {
@@ -747,11 +757,13 @@ int __inet_hash(struct sock *sk, struct sock *osk)
goto unlock;
}
sock_set_flag(sk, SOCK_RCU_FREE);
if (IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport &&
sk->sk_family == AF_INET6)
add_tail = IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport &&
sk->sk_family == AF_INET6;
if (add_tail)
__sk_nulls_add_node_tail_rcu(sk, &ilb2->nulls_head);
else
__sk_nulls_add_node_rcu(sk, &ilb2->nulls_head);
sk->sk_idx = inet_hashinfo_next_idx(hashinfo, add_tail);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
unlock:
spin_unlock(&ilb2->lock);
1 change: 1 addition & 0 deletions net/ipv4/tcp.c
@@ -5147,6 +5147,7 @@ void __init tcp_init(void)

cnt = tcp_hashinfo.ehash_mask + 1;
sysctl_tcp_max_orphans = cnt / 2;
atomic64_set(&tcp_hashinfo.ver, 0);

tcp_init_mem();
/* Set per-socket limits to no more than 1/128 the pressure threshold */
28 changes: 15 additions & 13 deletions net/ipv4/tcp_ipv4.c
@@ -2602,7 +2602,7 @@ static void *listening_get_first(struct seq_file *seq)
struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
struct tcp_iter_state *st = seq->private;

st->offset = 0;
st->prev_idx = 0;
for (; st->bucket <= hinfo->lhash2_mask; st->bucket++) {
struct inet_listen_hashbucket *ilb2;
struct hlist_nulls_node *node;
Expand Down Expand Up @@ -2637,7 +2637,7 @@ static void *listening_get_next(struct seq_file *seq, void *cur)
struct sock *sk = cur;

++st->num;
++st->offset;
st->prev_idx = sk->sk_idx;

sk = sk_nulls_next(sk);
sk_nulls_for_each_from(sk, node) {
@@ -2658,7 +2658,6 @@ static void *listening_get_idx(struct seq_file *seq, loff_t *pos)
void *rc;

st->bucket = 0;
st->offset = 0;
rc = listening_get_first(seq);

while (rc && *pos) {
Expand All @@ -2683,7 +2682,7 @@ static void *established_get_first(struct seq_file *seq)
struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
struct tcp_iter_state *st = seq->private;

st->offset = 0;
st->prev_idx = 0;
for (; st->bucket <= hinfo->ehash_mask; ++st->bucket) {
struct sock *sk;
struct hlist_nulls_node *node;
@@ -2714,7 +2713,6 @@ static void *established_get_next(struct seq_file *seq, void *cur)
struct sock *sk = cur;

++st->num;
++st->offset;

sk = sk_nulls_next(sk);

@@ -2763,8 +2761,8 @@ static void *tcp_seek_last_pos(struct seq_file *seq)
{
struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
struct tcp_iter_state *st = seq->private;
__s64 prev_idx = st->prev_idx;
int bucket = st->bucket;
int offset = st->offset;
int orig_num = st->num;
void *rc = NULL;

@@ -2773,18 +2771,21 @@ static void *tcp_seek_last_pos(struct seq_file *seq)
if (st->bucket > hinfo->lhash2_mask)
break;
rc = listening_get_first(seq);
while (offset-- && rc && bucket == st->bucket)
while (rc && bucket == st->bucket && prev_idx &&
((struct sock *)rc)->sk_idx <= prev_idx)
rc = listening_get_next(seq, rc);
if (rc)
break;
st->bucket = 0;
prev_idx = 0;
st->state = TCP_SEQ_STATE_ESTABLISHED;
fallthrough;
case TCP_SEQ_STATE_ESTABLISHED:
if (st->bucket > hinfo->ehash_mask)
break;
rc = established_get_first(seq);
while (offset-- && rc && bucket == st->bucket)
while (rc && bucket == st->bucket && prev_idx &&
((struct sock *)rc)->sk_idx <= prev_idx)
rc = established_get_next(seq, rc);
}

@@ -2807,7 +2808,7 @@ void *tcp_seq_start(struct seq_file *seq, loff_t *pos)
st->state = TCP_SEQ_STATE_LISTENING;
st->num = 0;
st->bucket = 0;
st->offset = 0;
st->prev_idx = 0;
rc = *pos ? tcp_get_idx(seq, *pos - 1) : SEQ_START_TOKEN;

out:
@@ -2832,7 +2833,7 @@ void *tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos)
if (!rc) {
st->state = TCP_SEQ_STATE_ESTABLISHED;
st->bucket = 0;
st->offset = 0;
st->prev_idx = 0;
rc = established_get_first(seq);
}
break;
@@ -3124,7 +3125,7 @@ static struct sock *bpf_iter_tcp_batch(struct seq_file *seq)
* it has to advance to the next bucket.
*/
if (iter->st_bucket_done) {
st->offset = 0;
st->prev_idx = 0;
st->bucket++;
if (st->state == TCP_SEQ_STATE_LISTENING &&
st->bucket > hinfo->lhash2_mask) {
@@ -3192,8 +3193,9 @@ static void *bpf_iter_tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos)
* the future start() will resume at st->offset in
* st->bucket. See tcp_seek_last_pos().
*/
st->offset++;
sock_gen_put(iter->batch[iter->cur_sk++]);
sk = iter->batch[iter->cur_sk++];
st->prev_idx = sk->sk_idx;
sock_gen_put(sk);
}

if (iter->cur_sk < iter->end_sk)