Skip to content

Commit

Permalink
net/wireguard: merge v1.0.20210219
Browse files Browse the repository at this point in the history
Signed-off-by: engstk <[email protected]>
  • Loading branch information
engstk committed Mar 2, 2021
1 parent 26233c1 commit b38a85f
Show file tree
Hide file tree
Showing 14 changed files with 207 additions and 121 deletions.
13 changes: 12 additions & 1 deletion net/wireguard/compat/compat-asm.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@
#include <linux/kconfig.h>
#include <linux/version.h>

#ifdef RHEL_MAJOR
#if RHEL_MAJOR == 7
#define ISRHEL7
#elif RHEL_MAJOR == 8
#define ISRHEL8
#if RHEL_MINOR == 4
#define ISCENTOS8S
#endif
#endif
#endif

/* PaX compatibility */
#if defined(RAP_PLUGIN)
#undef ENTRY
Expand Down Expand Up @@ -40,7 +51,7 @@
#undef pull
#endif

#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 76)
#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 76) && !defined(ISCENTOS8S)
#define SYM_FUNC_START ENTRY
#define SYM_FUNC_END ENDPROC
#endif
Expand Down
51 changes: 38 additions & 13 deletions net/wireguard/compat/compat.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@
#ifdef RHEL_MAJOR
#if RHEL_MAJOR == 7
#define ISRHEL7
#if RHEL_MINOR == 8
#define ISCENTOS7
#endif
#elif RHEL_MAJOR == 8
#define ISRHEL8
#if RHEL_MINOR == 2
#define ISCENTOS8
#if RHEL_MINOR == 4
#define ISCENTOS8S
#endif
#endif
#endif
Expand Down Expand Up @@ -94,7 +91,7 @@

#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 17, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 16, 83)
#define ipv6_dst_lookup_flow(a, b, c, d) ipv6_dst_lookup_flow(b, c, d)
#elif (LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 5) && LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(5, 3, 18) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 20, 0) && !defined(ISUBUNTU1904)) || (!defined(ISRHEL8) && !defined(ISDEBIAN) && !defined(ISUBUNTU1804) && LINUX_VERSION_CODE < KERNEL_VERSION(4, 19, 119) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 15, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 14, 181) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 10, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 9, 224) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 5, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 4, 224) && !defined(ISUBUNTU1604) && (!defined(ISRHEL7) || defined(ISCENTOS7)))
#elif (LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 5) && LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(5, 3, 18) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 20, 0) && !defined(ISUBUNTU1904)) || (!defined(ISRHEL8) && !defined(ISDEBIAN) && !defined(ISUBUNTU1804) && LINUX_VERSION_CODE < KERNEL_VERSION(4, 19, 119) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 15, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 14, 181) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 10, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 9, 224) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 5, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 4, 224) && !defined(ISUBUNTU1604) && !defined(ISRHEL7))
#define ipv6_dst_lookup_flow(a, b, c, d) ipv6_dst_lookup(a, b, &dst, c) + (void *)0 ?: dst
#endif

Expand Down Expand Up @@ -760,7 +757,7 @@ static inline void crypto_xor_cpy(u8 *dst, const u8 *src1, const u8 *src2,
#define hlist_add_behind(a, b) hlist_add_after(b, a)
#endif

#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 0, 0)
#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 0, 0) && !defined(ISCENTOS8S)
#define totalram_pages() totalram_pages
#endif

Expand Down Expand Up @@ -826,7 +823,7 @@ static __always_inline void old_rcu_barrier(void)
#define COMPAT_CANNOT_DEPRECIATE_BH_RCU
#endif

#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 19, 10) && !defined(ISRHEL8)
#if (LINUX_VERSION_CODE < KERNEL_VERSION(4, 19, 10) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 15, 0) && !defined(ISRHEL8) && !defined(ISUBUNTU1804)) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 14, 217)
static inline void skb_mark_not_on_list(struct sk_buff *skb)
{
skb->next = NULL;
Expand Down Expand Up @@ -936,11 +933,11 @@ static inline int skb_ensure_writable(struct sk_buff *skb, int write_len)
#endif

#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 6, 0)
#include <linux/icmpv6.h>
#include <net/icmp.h>
#if IS_ENABLED(CONFIG_NF_NAT)
#include <linux/ip.h>
#include <linux/icmpv6.h>
#include <net/ipv6.h>
#include <net/icmp.h>
#include <net/netfilter/nf_conntrack.h>
#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 1, 0) && !defined(ISRHEL8)
#include <net/netfilter/nf_nat_core.h>
Expand All @@ -954,6 +951,7 @@ static inline void __compat_icmp_ndo_send(struct sk_buff *skb_in, int type, int

ct = nf_ct_get(skb_in, &ctinfo);
if (!ct || !(ct->status & IPS_SRC_NAT)) {
memset(skb_in->cb, 0, sizeof(skb_in->cb));
icmp_send(skb_in, type, code, info);
return;
}
Expand All @@ -969,6 +967,7 @@ static inline void __compat_icmp_ndo_send(struct sk_buff *skb_in, int type, int

orig_ip = ip_hdr(skb_in)->saddr;
ip_hdr(skb_in)->saddr = ct->tuplehash[0].tuple.src.u3.ip;
memset(skb_in->cb, 0, sizeof(skb_in->cb));
icmp_send(skb_in, type, code, info);
ip_hdr(skb_in)->saddr = orig_ip;
out:
Expand All @@ -983,6 +982,7 @@ static inline void __compat_icmpv6_ndo_send(struct sk_buff *skb_in, u8 type, u8

ct = nf_ct_get(skb_in, &ctinfo);
if (!ct || !(ct->status & IPS_SRC_NAT)) {
memset(skb_in->cb, 0, sizeof(skb_in->cb));
icmpv6_send(skb_in, type, code, info);
return;
}
Expand All @@ -998,14 +998,23 @@ static inline void __compat_icmpv6_ndo_send(struct sk_buff *skb_in, u8 type, u8

orig_ip = ipv6_hdr(skb_in)->saddr;
ipv6_hdr(skb_in)->saddr = ct->tuplehash[0].tuple.src.u3.in6;
memset(skb_in->cb, 0, sizeof(skb_in->cb));
icmpv6_send(skb_in, type, code, info);
ipv6_hdr(skb_in)->saddr = orig_ip;
out:
consume_skb(cloned_skb);
}
#else
#define __compat_icmp_ndo_send icmp_send
#define __compat_icmpv6_ndo_send icmpv6_send
static inline void __compat_icmp_ndo_send(struct sk_buff *skb_in, int type, int code, __be32 info)
{
memset(skb_in->cb, 0, sizeof(skb_in->cb));
icmp_send(skb_in, type, code, info);
}
static inline void __compat_icmpv6_ndo_send(struct sk_buff *skb_in, u8 type, u8 code, __u32 info)
{
memset(skb_in->cb, 0, sizeof(skb_in->cb));
icmpv6_send(skb_in, type, code, info);
}
#endif
#define icmp_ndo_send __compat_icmp_ndo_send
#define icmpv6_ndo_send __compat_icmpv6_ndo_send
Expand All @@ -1015,7 +1024,7 @@ static inline void __compat_icmpv6_ndo_send(struct sk_buff *skb_in, u8 type, u8
#define COMPAT_CANNOT_USE_MAX_MTU
#endif

#if (LINUX_VERSION_CODE < KERNEL_VERSION(5, 5, 14) && LINUX_VERSION_CODE >= KERNEL_VERSION(5, 5, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 29) && !defined(ISUBUNTU1910) && !defined(ISUBUNTU1904) && (!defined(ISRHEL8) || defined(ISCENTOS8)))
#if (LINUX_VERSION_CODE < KERNEL_VERSION(5, 5, 14) && LINUX_VERSION_CODE >= KERNEL_VERSION(5, 5, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 29) && !defined(ISUBUNTU1910) && !defined(ISUBUNTU1904) && !defined(ISRHEL8))
#include <linux/skbuff.h>
#include <net/sch_generic.h>
static inline void skb_reset_redirect(struct sk_buff *skb)
Expand Down Expand Up @@ -1071,6 +1080,22 @@ static const struct header_ops ip_tunnel_header_ops = { .parse_protocol = ip_tun
#define kfree_sensitive(a) kzfree(a)
#endif

#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 3, 0) && !defined(ISRHEL7)
#define xchg_release xchg
#endif

#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 14, 0) && !defined(ISRHEL7)
#include <asm/barrier.h>
#ifndef smp_load_acquire
#define smp_load_acquire(p) \
({ \
typeof(*p) ___p1 = ACCESS_ONCE(*p); \
smp_mb(); \
___p1; \
})
#endif
#endif

#if defined(ISUBUNTU1604) || defined(ISRHEL7)
#include <linux/siphash.h>
#ifndef _WG_LINUX_SIPHASH_H
Expand Down
3 changes: 1 addition & 2 deletions net/wireguard/compat/simd/include/linux/simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <linux/sched.h>
#include <asm/simd.h>
#if defined(CONFIG_X86_64)
#include <linux/version.h>
#include <asm/fpu/api.h>
#elif defined(CONFIG_KERNEL_MODE_NEON)
#include <asm/neon.h>
Expand All @@ -25,7 +24,7 @@ typedef enum {

static inline void simd_get(simd_context_t *ctx)
{
*ctx = !IS_ENABLED(CONFIG_PREEMPT_RT_BASE) && may_use_simd() ? HAVE_FULL_SIMD : HAVE_NO_SIMD;
*ctx = !IS_ENABLED(CONFIG_PREEMPT_RT) && !IS_ENABLED(CONFIG_PREEMPT_RT_BASE) && may_use_simd() ? HAVE_FULL_SIMD : HAVE_NO_SIMD;
}

static inline void simd_put(simd_context_t *ctx)
Expand Down
1 change: 0 additions & 1 deletion net/wireguard/crypto/zinc/curve25519/curve25519.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include "../selftest/run.h"

#include <asm/unaligned.h>
#include <linux/version.h>
#include <linux/string.h>
#include <linux/random.h>
#include <linux/module.h>
Expand Down
21 changes: 11 additions & 10 deletions net/wireguard/device.c
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ static netdev_tx_t wg_xmit(struct sk_buff *skb, struct net_device *dev)
else if (skb->protocol == htons(ETH_P_IPV6))
net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI6\n",
dev->name, &ipv6_hdr(skb)->daddr);
goto err;
goto err_icmp;
}

family = READ_ONCE(peer->endpoint.addr.sa_family);
Expand All @@ -165,7 +165,7 @@ static netdev_tx_t wg_xmit(struct sk_buff *skb, struct net_device *dev)
} else {
struct sk_buff *segs = skb_gso_segment(skb, 0);

if (unlikely(IS_ERR(segs))) {
if (IS_ERR(segs)) {
ret = PTR_ERR(segs);
goto err_peer;
}
Expand Down Expand Up @@ -209,12 +209,13 @@ static netdev_tx_t wg_xmit(struct sk_buff *skb, struct net_device *dev)

err_peer:
wg_peer_put(peer);
err:
++dev->stats.tx_errors;
err_icmp:
if (skb->protocol == htons(ETH_P_IP))
icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
else if (skb->protocol == htons(ETH_P_IPV6))
icmpv6_ndo_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_ADDR_UNREACH, 0);
err:
++dev->stats.tx_errors;
kfree_skb(skb);
return ret;
}
Expand Down Expand Up @@ -242,8 +243,8 @@ static void wg_destruct(struct net_device *dev)
destroy_workqueue(wg->handshake_receive_wq);
destroy_workqueue(wg->handshake_send_wq);
destroy_workqueue(wg->packet_crypt_wq);
wg_packet_queue_free(&wg->decrypt_queue, true);
wg_packet_queue_free(&wg->encrypt_queue, true);
wg_packet_queue_free(&wg->decrypt_queue);
wg_packet_queue_free(&wg->encrypt_queue);
rcu_barrier(); /* Wait for all the peers to be actually freed. */
wg_ratelimiter_uninit();
memzero_explicit(&wg->static_identity, sizeof(wg->static_identity));
Expand Down Expand Up @@ -351,12 +352,12 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
goto err_destroy_handshake_send;

ret = wg_packet_queue_init(&wg->encrypt_queue, wg_packet_encrypt_worker,
true, MAX_QUEUED_PACKETS);
MAX_QUEUED_PACKETS);
if (ret < 0)
goto err_destroy_packet_crypt;

ret = wg_packet_queue_init(&wg->decrypt_queue, wg_packet_decrypt_worker,
true, MAX_QUEUED_PACKETS);
MAX_QUEUED_PACKETS);
if (ret < 0)
goto err_free_encrypt_queue;

Expand All @@ -381,9 +382,9 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
err_uninit_ratelimiter:
wg_ratelimiter_uninit();
err_free_decrypt_queue:
wg_packet_queue_free(&wg->decrypt_queue, true);
wg_packet_queue_free(&wg->decrypt_queue);
err_free_encrypt_queue:
wg_packet_queue_free(&wg->encrypt_queue, true);
wg_packet_queue_free(&wg->encrypt_queue);
err_destroy_packet_crypt:
destroy_workqueue(wg->packet_crypt_wq);
err_destroy_handshake_send:
Expand Down
15 changes: 8 additions & 7 deletions net/wireguard/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@ struct multicore_worker {

struct crypt_queue {
struct ptr_ring ring;
union {
struct {
struct multicore_worker __percpu *worker;
int last_cpu;
};
struct work_struct work;
};
struct multicore_worker __percpu *worker;
int last_cpu;
};

struct prev_queue {
struct sk_buff *head, *tail, *peeked;
struct { struct sk_buff *next, *prev; } empty; // Match first 2 members of struct sk_buff.
atomic_t count;
};

struct wg_device {
Expand Down
28 changes: 9 additions & 19 deletions net/wireguard/peer.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,22 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
peer = kzalloc(sizeof(*peer), GFP_KERNEL);
if (unlikely(!peer))
return ERR_PTR(ret);
peer->device = wg;
if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL))
goto err;

peer->device = wg;
wg_noise_handshake_init(&peer->handshake, &wg->static_identity,
public_key, preshared_key, peer);
if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL))
goto err_1;
if (wg_packet_queue_init(&peer->tx_queue, wg_packet_tx_worker, false,
MAX_QUEUED_PACKETS))
goto err_2;
if (wg_packet_queue_init(&peer->rx_queue, NULL, false,
MAX_QUEUED_PACKETS))
goto err_3;

peer->internal_id = atomic64_inc_return(&peer_counter);
peer->serial_work_cpu = nr_cpumask_bits;
wg_cookie_init(&peer->latest_cookie);
wg_timers_init(peer);
wg_cookie_checker_precompute_peer_keys(peer);
spin_lock_init(&peer->keypairs.keypair_update_lock);
INIT_WORK(&peer->transmit_handshake_work,
wg_packet_handshake_send_worker);
INIT_WORK(&peer->transmit_handshake_work, wg_packet_handshake_send_worker);
INIT_WORK(&peer->transmit_packet_work, wg_packet_tx_worker);
wg_prev_queue_init(&peer->tx_queue);
wg_prev_queue_init(&peer->rx_queue);
rwlock_init(&peer->endpoint_lock);
kref_init(&peer->refcount);
skb_queue_head_init(&peer->staged_packet_queue);
Expand All @@ -68,11 +63,7 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
pr_debug("%s: Peer %llu created\n", wg->dev->name, peer->internal_id);
return peer;

err_3:
wg_packet_queue_free(&peer->tx_queue, false);
err_2:
dst_cache_destroy(&peer->endpoint_cache);
err_1:
err:
kfree(peer);
return ERR_PTR(ret);
}
Expand Down Expand Up @@ -197,8 +188,7 @@ static void rcu_release(struct rcu_head *rcu)
struct wg_peer *peer = container_of(rcu, struct wg_peer, rcu);

dst_cache_destroy(&peer->endpoint_cache);
wg_packet_queue_free(&peer->rx_queue, false);
wg_packet_queue_free(&peer->tx_queue, false);
WARN_ON(wg_prev_queue_peek(&peer->tx_queue) || wg_prev_queue_peek(&peer->rx_queue));

/* The final zeroing takes care of clearing any remaining handshake key
* material and other potentially sensitive information.
Expand Down
8 changes: 4 additions & 4 deletions net/wireguard/peer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@ struct endpoint {

struct wg_peer {
struct wg_device *device;
struct crypt_queue tx_queue, rx_queue;
struct prev_queue tx_queue, rx_queue;
struct sk_buff_head staged_packet_queue;
int serial_work_cpu;
bool is_dead;
struct noise_keypairs keypairs;
struct endpoint endpoint;
struct dst_cache endpoint_cache;
rwlock_t endpoint_lock;
struct noise_handshake handshake;
atomic64_t last_sent_handshake;
struct work_struct transmit_handshake_work, clear_peer_work;
struct work_struct transmit_handshake_work, clear_peer_work, transmit_packet_work;
struct cookie latest_cookie;
struct hlist_node pubkey_hash;
u64 rx_bytes, tx_bytes;
Expand All @@ -61,9 +62,8 @@ struct wg_peer {
struct rcu_head rcu;
struct list_head peer_list;
struct list_head allowedips_list;
u64 internal_id;
struct napi_struct napi;
bool is_dead;
u64 internal_id;
};

struct wg_peer *wg_peer_create(struct wg_device *wg,
Expand Down
Loading

0 comments on commit b38a85f

Please sign in to comment.