diff --git a/internal/analyzer/analyzer.go b/internal/analyzer/analyzer.go index ce5a7b8..ad4b95b 100644 --- a/internal/analyzer/analyzer.go +++ b/internal/analyzer/analyzer.go @@ -1,17 +1,16 @@ package analyzer import ( - "context" "encoding/binary" "fmt" "github.com/weiiwang01/wpex/internal/exchange" + "log" "log/slog" "net" "time" ) const ( - cookieSize = 16 macSize = 16 handshakeInitiationSize = 148 handshakeResponseSize = 92 @@ -120,29 +119,27 @@ func (t *WireguardAnalyzer) analyseCookieReply(packet []byte, peer net.UDPAddr) } func (t *WireguardAnalyzer) analyseTransportData(packet []byte, peer net.UDPAddr) ([]net.UDPAddr, []byte) { - logger := slog.With("addr", peer.String()) receiverIdx := t.decodeIndex(packet[4:8]) receiver, err := t.table.GetPeerAddr(receiverIdx) - slog.Log(context.TODO(), slog.LevelDebug-4, "transport data message received", "addr", peer.String(), "receiver", receiverIdx, "forward", receiver.String()) if err != nil { - logger.Warn(fmt.Sprintf("unknown receiver in transport data: %s", err)) + slog.Warn(fmt.Sprintf("unknown receiver in transport data: %s", err), "addr", peer.String()) return nil, nil } sender, err := t.table.GetPeerCounterpart(receiverIdx) if err != nil { - logger.Warn(fmt.Sprintf("unknown sender in transport data: %s", err)) + slog.Warn(fmt.Sprintf("unknown sender in transport data: %s", err), "addr", peer.String()) return nil, nil } addr, err := t.table.GetPeerAddr(sender) if err != nil { - logger.Warn(fmt.Sprintf("no sender address record in transport data: %s", err)) + slog.Warn(fmt.Sprintf("no sender address record in transport data: %s", err), "addr", peer.String()) return nil, nil } - if addr.String() != peer.String() { + if !addrEqual(addr, peer) { slog.Debug("roaming detected in transport data message", "sender", sender, "before", addr.String(), "after", peer.String()) err := t.table.UpdatePeerAddr(sender, peer) if err != nil { - logger.Warn(fmt.Sprintf("failed to update sender address: %s", err)) + slog.Warn(fmt.Sprintf("failed to update sender address: %s", err), "addr", peer.String()) return nil, nil } } @@ -151,7 +148,6 @@ func (t *WireguardAnalyzer) analyseTransportData(packet []byte, peer net.UDPAddr // Analyse updates the exchange table with the source address and returns the forwarding address for this packet. func (t *WireguardAnalyzer) Analyse(packet []byte, peer net.UDPAddr) ([]net.UDPAddr, []byte) { - logger := slog.With("addr", peer.String()) const ( handshakeInitiationType = iota + 1 handshakeResponseType @@ -159,7 +155,7 @@ func (t *WireguardAnalyzer) Analyse(packet []byte, peer net.UDPAddr) ([]net.UDPA transportDataType ) if len(packet) < 16 { - logger.Error("invalid wireguard message: too short") + slog.Error("invalid wireguard message: too short", "addr", peer.String()) return nil, nil } msgType := int(binary.LittleEndian.Uint32(packet[:4])) @@ -173,7 +169,7 @@ func (t *WireguardAnalyzer) Analyse(packet []byte, peer net.UDPAddr) ([]net.UDPA case transportDataType: return t.analyseTransportData(packet, peer) default: - logger.Error("unknown message type") + slog.Error("unknown message type", "addr", peer.String()) return nil, nil } } @@ -181,7 +177,7 @@ func (t *WireguardAnalyzer) Analyse(packet []byte, peer net.UDPAddr) ([]net.UDPA func MakeWireguardAnalyzer(pubkeys [][]byte) WireguardAnalyzer { secret, err := token(32) if err != nil { - panic(fmt.Errorf("failed to generate cookie secret: %w", err)) + log.Fatal(fmt.Errorf("failed to generate cookie secret: %w", err)) } return WireguardAnalyzer{ table: exchange.MakeExchangeTable(), diff --git a/internal/analyzer/helper.go b/internal/analyzer/helper.go index 7de10e1..62f0229 100644 --- a/internal/analyzer/helper.go +++ b/internal/analyzer/helper.go @@ -1,9 +1,11 @@ package analyzer import ( + "bytes" "crypto/rand" "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" + "net" ) func hash(dst []byte, data ...[]byte) [32]byte { @@ -43,3 +45,7 @@ func token(n int) ([]byte, error) { } return b, nil } + +func addrEqual(a1, a2 net.UDPAddr) bool { + return a1.Port == a2.Port && bytes.Equal(a1.IP, a2.IP) && a1.Zone == a2.Zone +} diff --git a/internal/relay/relay.go b/internal/relay/relay.go index cce8b09..d91a697 100644 --- a/internal/relay/relay.go +++ b/internal/relay/relay.go @@ -1,44 +1,34 @@ package relay import ( + "context" + "fmt" "github.com/weiiwang01/wpex/internal/analyzer" "golang.org/x/time/rate" + "log" "log/slog" "net" + "runtime" + "syscall" ) type udpPacket struct { - addr net.UDPAddr + addr net.Addr data []byte - source net.UDPAddr + source net.Addr isBroadcast bool } type Relay struct { send chan udpPacket analyzer analyzer.WireguardAnalyzer - conn *net.UDPConn limit *rate.Limiter } -func (r *Relay) sendUDP() { - for packet := range r.send { - if packet.isBroadcast { - if !r.limit.Allow() { - slog.Warn("broadcast rate limit exceeded", "src", packet.source.String(), "dst", packet.addr.String()) - } - } - _, err := r.conn.WriteToUDP(packet.data, &packet.addr) - if err != nil { - slog.Error("error while sending UDP packet", "error", err.Error(), "addr", packet.addr.String()) - } - } -} - -func (r *Relay) receiveUDP() { +func (r *Relay) relay(conn *net.UDPConn) { + buf := make([]byte, 65536) for { - buf := make([]byte, 1500) - n, remoteAddr, err := r.conn.ReadFromUDP(buf) + n, remoteAddr, err := conn.ReadFromUDP(buf) if err != nil { slog.Error("error while receiving UDP packet", "error", err.Error(), "addr", remoteAddr) continue @@ -46,22 +36,44 @@ func (r *Relay) receiveUDP() { packet := buf[:n] peers, send := r.analyzer.Analyse(packet, *remoteAddr) for _, peer := range peers { - r.send <- udpPacket{addr: peer, data: send, source: *remoteAddr, isBroadcast: len(peers) > 1} + if len(peers) > 1 { + if !r.limit.Allow() { + slog.Warn("broadcast rate limit exceeded", "src", remoteAddr.String(), "dst", peer.String()) + continue + } + } + _, err := conn.WriteToUDP(send, &peer) + if err != nil { + slog.Error("error while sending UDP packet", "error", err.Error(), "addr", peer.String()) + } } } } // Start starts the wireguard packet relay server. -func Start(conn *net.UDPConn, publicKeys [][]byte, broadcastLimit *rate.Limiter) { +func Start(address string, publicKeys [][]byte, broadcastLimit *rate.Limiter) { + slog.Info("server listening", "addr", address) + var lc = net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + var opErr error + if err := c.Control(func(fd uintptr) { opErr = control(fd) }); err != nil { + return err + } + return opErr + }, + } relay := Relay{ send: make(chan udpPacket), analyzer: analyzer.MakeWireguardAnalyzer(publicKeys), - conn: conn, limit: broadcastLimit, } - for i := 0; i < 4; i++ { - go relay.sendUDP() - go relay.receiveUDP() + for i := 0; i < runtime.NumCPU(); i++ { + l, err := lc.ListenPacket(context.Background(), "udp", address) + if err != nil { + log.Fatal(fmt.Sprintf("failed to listen on %s: %s", address, err)) + } + conn := l.(*net.UDPConn) + go relay.relay(conn) } select {} } diff --git a/internal/relay/unix.go b/internal/relay/unix.go new file mode 100644 index 0000000..eed9cfc --- /dev/null +++ b/internal/relay/unix.go @@ -0,0 +1,9 @@ +//go:build linux || darwin + +package relay + +import "golang.org/x/sys/unix" + +func control(fd uintptr) error { + return unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) +} diff --git a/internal/relay/windows.go b/internal/relay/windows.go new file mode 100644 index 0000000..a83c5d9 --- /dev/null +++ b/internal/relay/windows.go @@ -0,0 +1,9 @@ +//go:build windows + +package relay + +import "golang.org/x/sys/windows" + +func control(fd uintptr) error { + return windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_REUSEADDR, 1) +} diff --git a/wpex.go b/wpex.go index 4578a2a..b50819e 100644 --- a/wpex.go +++ b/wpex.go @@ -6,8 +6,8 @@ import ( "fmt" "github.com/weiiwang01/wpex/internal/relay" "golang.org/x/time/rate" + "log" "log/slog" - "net" "os" "strings" ) @@ -29,7 +29,6 @@ func main() { bind := flag.String("bind", "", "address to bind to") port := flag.Uint("port", 40000, "port number to listen") debug := flag.Bool("debug", false, "enable debug messages") - trace := flag.Bool("trace", false, "enable trace level debug messages") broadcastRate := flag.Uint("broadcast-rate", 0, "broadcast rate limit in packet per second") versionFlag := flag.Bool("version", false, "show version number and quit") var allows pubKeys @@ -44,29 +43,17 @@ func main() { if *debug { loggingLevel.Set(slog.LevelDebug) } - if *trace { - loggingLevel.Set(slog.LevelDebug - 4) - } slog.SetDefault(logger) address := fmt.Sprintf("%s:%d", *bind, *port) var allowKeys [][]byte for _, allow := range allows { k, err := base64.StdEncoding.DecodeString(allow) if err != nil || len(k) != 32 { - panic(fmt.Sprintf("invalid wireguard public key: '%s'", allow)) + log.Fatal(fmt.Sprintf("invalid wireguard public key: '%s'", allow)) } logger.Debug("allow wireguard public key", "key", allow) allowKeys = append(allowKeys, k) } - addr, err := net.ResolveUDPAddr("udp", address) - if err != nil { - panic(fmt.Sprintf("failed to resolve UDP address: %s", err)) - } - conn, err := net.ListenUDP("udp", addr) - if err != nil { - panic(fmt.Sprintf("failed to listen on UDP: %s", err)) - } - logger.Info("server listening", "addr", address) limit := rate.Limit(*broadcastRate) if *broadcastRate == 0 { slog.Debug("broadcast rate limit is set to +Inf") @@ -74,5 +61,5 @@ func main() { } else { slog.Debug(fmt.Sprintf("broadcast rate limit is set to %d", *broadcastRate)) } - relay.Start(conn, allowKeys, rate.NewLimiter(limit, int((*broadcastRate)*5))) + relay.Start(address, allowKeys, rate.NewLimiter(limit, int((*broadcastRate)*5))) }