Skip to content

Commit

Permalink
Improve performance with socket reuse and other optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
weiiwang01 committed Sep 30, 2023
1 parent 9df2c13 commit 145a567
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 55 deletions.
22 changes: 9 additions & 13 deletions internal/analyzer/analyzer.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
package analyzer

import (
"context"
"encoding/binary"
"fmt"
"github.com/weiiwang01/wpex/internal/exchange"
"log"
"log/slog"
"net"
"time"
)

const (
cookieSize = 16
macSize = 16
handshakeInitiationSize = 148
handshakeResponseSize = 92
Expand Down Expand Up @@ -120,29 +119,27 @@ func (t *WireguardAnalyzer) analyseCookieReply(packet []byte, peer net.UDPAddr)
}

func (t *WireguardAnalyzer) analyseTransportData(packet []byte, peer net.UDPAddr) ([]net.UDPAddr, []byte) {
logger := slog.With("addr", peer.String())
receiverIdx := t.decodeIndex(packet[4:8])
receiver, err := t.table.GetPeerAddr(receiverIdx)
slog.Log(context.TODO(), slog.LevelDebug-4, "transport data message received", "addr", peer.String(), "receiver", receiverIdx, "forward", receiver.String())
if err != nil {
logger.Warn(fmt.Sprintf("unknown receiver in transport data: %s", err))
slog.Warn(fmt.Sprintf("unknown receiver in transport data: %s", err), "addr", peer.String())
return nil, nil
}
sender, err := t.table.GetPeerCounterpart(receiverIdx)
if err != nil {
logger.Warn(fmt.Sprintf("unknown sender in transport data: %s", err))
slog.Warn(fmt.Sprintf("unknown sender in transport data: %s", err), "addr", peer.String())
return nil, nil
}
addr, err := t.table.GetPeerAddr(sender)
if err != nil {
logger.Warn(fmt.Sprintf("no sender address record in transport data: %s", err))
slog.Warn(fmt.Sprintf("no sender address record in transport data: %s", err), "addr", peer.String())
return nil, nil
}
if addr.String() != peer.String() {
if !addrEqual(addr, peer) {
slog.Debug("roaming detected in transport data message", "sender", sender, "before", addr.String(), "after", peer.String())
err := t.table.UpdatePeerAddr(sender, peer)
if err != nil {
logger.Warn(fmt.Sprintf("failed to update sender address: %s", err))
slog.Warn(fmt.Sprintf("failed to update sender address: %s", err), "addr", peer.String())
return nil, nil
}
}
Expand All @@ -151,15 +148,14 @@ func (t *WireguardAnalyzer) analyseTransportData(packet []byte, peer net.UDPAddr

// Analyse updates the exchange table with the source address and returns the forwarding address for this packet.
func (t *WireguardAnalyzer) Analyse(packet []byte, peer net.UDPAddr) ([]net.UDPAddr, []byte) {
logger := slog.With("addr", peer.String())
const (
handshakeInitiationType = iota + 1
handshakeResponseType
cookieReplyType
transportDataType
)
if len(packet) < 16 {
logger.Error("invalid wireguard message: too short")
slog.Error("invalid wireguard message: too short", "addr", peer.String())
return nil, nil
}
msgType := int(binary.LittleEndian.Uint32(packet[:4]))
Expand All @@ -173,15 +169,15 @@ func (t *WireguardAnalyzer) Analyse(packet []byte, peer net.UDPAddr) ([]net.UDPA
case transportDataType:
return t.analyseTransportData(packet, peer)
default:
logger.Error("unknown message type")
slog.Error("unknown message type", "addr", peer.String())
return nil, nil
}
}

func MakeWireguardAnalyzer(pubkeys [][]byte) WireguardAnalyzer {
secret, err := token(32)
if err != nil {
panic(fmt.Errorf("failed to generate cookie secret: %w", err))
log.Fatal(fmt.Errorf("failed to generate cookie secret: %w", err))
}
return WireguardAnalyzer{
table: exchange.MakeExchangeTable(),
Expand Down
6 changes: 6 additions & 0 deletions internal/analyzer/helper.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package analyzer

import (
"bytes"
"crypto/rand"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"net"
)

func hash(dst []byte, data ...[]byte) [32]byte {
Expand Down Expand Up @@ -43,3 +45,7 @@ func token(n int) ([]byte, error) {
}
return b, nil
}

func addrEqual(a1, a2 net.UDPAddr) bool {
return a1.Port == a2.Port && bytes.Equal(a1.IP, a2.IP) && a1.Zone == a2.Zone
}
64 changes: 38 additions & 26 deletions internal/relay/relay.go
Original file line number Diff line number Diff line change
@@ -1,67 +1,79 @@
package relay

import (
"context"
"fmt"
"github.com/weiiwang01/wpex/internal/analyzer"
"golang.org/x/time/rate"
"log"
"log/slog"
"net"
"runtime"
"syscall"
)

type udpPacket struct {
addr net.UDPAddr
addr net.Addr
data []byte
source net.UDPAddr
source net.Addr
isBroadcast bool
}

type Relay struct {
send chan udpPacket
analyzer analyzer.WireguardAnalyzer
conn *net.UDPConn
limit *rate.Limiter
}

func (r *Relay) sendUDP() {
for packet := range r.send {
if packet.isBroadcast {
if !r.limit.Allow() {
slog.Warn("broadcast rate limit exceeded", "src", packet.source.String(), "dst", packet.addr.String())
}
}
_, err := r.conn.WriteToUDP(packet.data, &packet.addr)
if err != nil {
slog.Error("error while sending UDP packet", "error", err.Error(), "addr", packet.addr.String())
}
}
}

func (r *Relay) receiveUDP() {
func (r *Relay) relay(conn *net.UDPConn) {
buf := make([]byte, 65536)
for {
buf := make([]byte, 1500)
n, remoteAddr, err := r.conn.ReadFromUDP(buf)
n, remoteAddr, err := conn.ReadFromUDP(buf)
if err != nil {
slog.Error("error while receiving UDP packet", "error", err.Error(), "addr", remoteAddr)
continue
}
packet := buf[:n]
peers, send := r.analyzer.Analyse(packet, *remoteAddr)
for _, peer := range peers {
r.send <- udpPacket{addr: peer, data: send, source: *remoteAddr, isBroadcast: len(peers) > 1}
if len(peers) > 1 {
if !r.limit.Allow() {
slog.Warn("broadcast rate limit exceeded", "src", remoteAddr.String(), "dst", peer.String())
continue
}
}
_, err := conn.WriteToUDP(send, &peer)
if err != nil {
slog.Error("error while sending UDP packet", "error", err.Error(), "addr", peer.String())
}
}
}
}

// Start starts the wireguard packet relay server.
func Start(conn *net.UDPConn, publicKeys [][]byte, broadcastLimit *rate.Limiter) {
func Start(address string, publicKeys [][]byte, broadcastLimit *rate.Limiter) {
slog.Info("server listening", "addr", address)
var lc = net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
var opErr error
if err := c.Control(func(fd uintptr) { opErr = control(fd) }); err != nil {
return err
}
return opErr
},
}
relay := Relay{
send: make(chan udpPacket),
analyzer: analyzer.MakeWireguardAnalyzer(publicKeys),
conn: conn,
limit: broadcastLimit,
}
for i := 0; i < 4; i++ {
go relay.sendUDP()
go relay.receiveUDP()
for i := 0; i < runtime.NumCPU(); i++ {
l, err := lc.ListenPacket(context.Background(), "udp", address)
if err != nil {
log.Fatal(fmt.Sprintf("failed to listen on %s: %s", address, err))
}
conn := l.(*net.UDPConn)
go relay.relay(conn)
}
select {}
}
9 changes: 9 additions & 0 deletions internal/relay/unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//go:build linux || darwin

package relay

import "golang.org/x/sys/unix"

func control(fd uintptr) error {
return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
}
9 changes: 9 additions & 0 deletions internal/relay/windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//go:build windows

package relay

import "golang.org/x/sys/windows"

func control(fd uintptr) error {
return windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_REUSEADDR, 1)
}
19 changes: 3 additions & 16 deletions wpex.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"fmt"
"github.com/weiiwang01/wpex/internal/relay"
"golang.org/x/time/rate"
"log"
"log/slog"
"net"
"os"
"strings"
)
Expand All @@ -29,7 +29,6 @@ func main() {
bind := flag.String("bind", "", "address to bind to")
port := flag.Uint("port", 40000, "port number to listen")
debug := flag.Bool("debug", false, "enable debug messages")
trace := flag.Bool("trace", false, "enable trace level debug messages")
broadcastRate := flag.Uint("broadcast-rate", 0, "broadcast rate limit in packet per second")
versionFlag := flag.Bool("version", false, "show version number and quit")
var allows pubKeys
Expand All @@ -44,35 +43,23 @@ func main() {
if *debug {
loggingLevel.Set(slog.LevelDebug)
}
if *trace {
loggingLevel.Set(slog.LevelDebug - 4)
}
slog.SetDefault(logger)
address := fmt.Sprintf("%s:%d", *bind, *port)
var allowKeys [][]byte
for _, allow := range allows {
k, err := base64.StdEncoding.DecodeString(allow)
if err != nil || len(k) != 32 {
panic(fmt.Sprintf("invalid wireguard public key: '%s'", allow))
log.Fatal(fmt.Sprintf("invalid wireguard public key: '%s'", allow))
}
logger.Debug("allow wireguard public key", "key", allow)
allowKeys = append(allowKeys, k)
}
addr, err := net.ResolveUDPAddr("udp", address)
if err != nil {
panic(fmt.Sprintf("failed to resolve UDP address: %s", err))
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
panic(fmt.Sprintf("failed to listen on UDP: %s", err))
}
logger.Info("server listening", "addr", address)
limit := rate.Limit(*broadcastRate)
if *broadcastRate == 0 {
slog.Debug("broadcast rate limit is set to +Inf")
limit = rate.Inf
} else {
slog.Debug(fmt.Sprintf("broadcast rate limit is set to %d", *broadcastRate))
}
relay.Start(conn, allowKeys, rate.NewLimiter(limit, int((*broadcastRate)*5)))
relay.Start(address, allowKeys, rate.NewLimiter(limit, int((*broadcastRate)*5)))
}

0 comments on commit 145a567

Please sign in to comment.