diff --git a/go.mod b/go.mod index b22e84ff..cb207d0f 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/shadowsocks/go-shadowsocks2 v0.1.4-0.20201002022019-75d43273f5a5 github.com/stretchr/testify v1.6.1 golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 + golang.org/x/net v0.0.0-20211216030914-fe4d6282115f // indirect gopkg.in/yaml.v2 v2.3.0 ) diff --git a/go.sum b/go.sum index e0d24d73..4f597537 100644 --- a/go.sum +++ b/go.sum @@ -104,6 +104,8 @@ golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20211216030914-fe4d6282115f h1:hEYJvxw1lSnWIl8X9ofsYMklzaDs90JI2az5YMd4fPM= +golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -121,7 +123,13 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 h1:AvbQYmiaaaza3cW3QXRyPo5kYgpFIzOAfeAAN7m3qQ4= golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/integration_test/integration_test.go b/integration_test/integration_test.go index de674e7e..9f31ecc1 100644 --- a/integration_test/integration_test.go +++ b/integration_test/integration_test.go @@ -233,6 +233,7 @@ type udpRecord struct { type fakeUDPMetrics struct { metrics.ShadowsocksMetrics fakeLocation string + mu sync.Mutex up, down []udpRecord natAdded int } @@ -241,13 +242,19 @@ func (m *fakeUDPMetrics) GetLocation(addr net.Addr) (string, error) { return m.fakeLocation, nil } func (m *fakeUDPMetrics) AddUDPPacketFromClient(clientLocation, accessKey, status string, clientProxyBytes, proxyTargetBytes int, timeToCipher time.Duration) { + m.mu.Lock() m.up = append(m.up, udpRecord{clientLocation, accessKey, status, clientProxyBytes, proxyTargetBytes}) + m.mu.Unlock() } func (m *fakeUDPMetrics) AddUDPPacketFromTarget(clientLocation, accessKey, status string, targetProxyBytes, proxyClientBytes int) { + m.mu.Lock() m.down = append(m.down, udpRecord{clientLocation, accessKey, status, targetProxyBytes, proxyClientBytes}) + m.mu.Unlock() } func (m *fakeUDPMetrics) AddUDPNatEntry() { + m.mu.Lock() m.natAdded++ + m.mu.Unlock() } func (m *fakeUDPMetrics) RemoveUDPNatEntry() { // Not tested because it requires waiting for a long timeout. @@ -256,9 +263,9 @@ func (m *fakeUDPMetrics) RemoveUDPNatEntry() { func TestUDPEcho(t *testing.T) { echoConn, echoRunning := startUDPEchoServer(t) - proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + proxyConn, err := onet.ListenAnyUDP4(0) if err != nil { - t.Fatalf("ListenTCP failed: %v", err) + t.Fatalf("ListenAnyUDP4 failed: %v", err) } secrets := ss.MakeTestSecrets(1) cipherList, err := service.MakeTestCiphers(secrets) @@ -350,6 +357,118 @@ func TestUDPEcho(t *testing.T) { } } +// Test that UDP packets addressed to different proxy IPs produce replies +// from the corresponding proxy IP. +func TestUDPEchoMultipleIP(t *testing.T) { + echoConn, echoRunning := startUDPEchoServer(t) + + proxyConn, err := onet.ListenAnyUDP4(0) + if err != nil { + t.Fatalf("ListenAnyUDP4 failed: %v", err) + } + secrets := ss.MakeTestSecrets(1) + cipherList, err := service.MakeTestCiphers(secrets) + if err != nil { + t.Fatal(err) + } + testMetrics := &fakeUDPMetrics{fakeLocation: "QQ"} + proxy := service.NewUDPService(time.Hour, cipherList, testMetrics) + proxy.SetTargetIPValidator(allowAll) + go proxy.Serve(proxyConn) + + _, proxyPort, err := net.SplitHostPort(proxyConn.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + portNum, err := strconv.Atoi(proxyPort) + if err != nil { + t.Fatal(err) + } + + client1, err := client.NewClient("127.0.0.1", portNum, secrets[0], ss.TestCipher) + if err != nil { + t.Fatalf("Failed to create ShadowsocksClient: %v", err) + } + client2, err := client.NewClient("127.0.0.2", portNum, secrets[0], ss.TestCipher) + if err != nil { + t.Fatalf("Failed to create ShadowsocksClient: %v", err) + } + conn1, err := client1.ListenUDP(nil) + if err != nil { + t.Fatalf("ShadowsocksClient.ListenUDP failed: %v", err) + } + conn2, err := client2.ListenUDP(nil) + if err != nil { + t.Fatalf("ShadowsocksClient.ListenUDP failed: %v", err) + } + + const N = 1000 + up1 := ss.MakeTestPayload(N) + n, err := conn1.WriteTo(up1, echoConn.LocalAddr()) + if err != nil { + t.Fatal(err) + } + if n != N { + t.Fatalf("Tried to upload %d bytes, but only sent %d", N, n) + } + + up2 := ss.MakeTestPayload(N + 1) + n, err = conn2.WriteTo(up2, echoConn.LocalAddr()) + if err != nil { + t.Fatal(err) + } + if n != N+1 { + t.Fatalf("Tried to upload %d bytes, but only sent %d", N+1, n) + } + + down := make([]byte, N+1) + n, addr, err := conn1.ReadFrom(down) + if err != nil { + t.Fatal(err) + } + if n != N { + t.Errorf("Tried to download %d bytes, but only received %d", N, n) + } + if addr.String() != echoConn.LocalAddr().String() { + t.Errorf("Reported address mismatch: %s != %s", addr.String(), echoConn.LocalAddr().String()) + } + + if !bytes.Equal(up1, down[:n]) { + t.Fatal("Echo mismatch") + } + + n, addr, err = conn2.ReadFrom(down) + if err != nil { + t.Fatal(err) + } + if n != N+1 { + t.Errorf("Tried to download %d bytes, but only received %d", N+1, n) + } + if addr.String() != echoConn.LocalAddr().String() { + t.Errorf("Reported address mismatch: %s != %s", addr.String(), echoConn.LocalAddr().String()) + } + + if !bytes.Equal(up2, down) { + t.Fatal("Echo mismatch") + } + + conn1.Close() + conn2.Close() + echoConn.Close() + echoRunning.Wait() + proxy.GracefulStop() + // Verify that the expected number of metrics were reported. + if testMetrics.natAdded != 2 { + t.Errorf("Wrong NAT add count: %d", testMetrics.natAdded) + } + if len(testMetrics.up) != 2 { + t.Errorf("Wrong number of packets sent: %v", testMetrics.up) + } + if len(testMetrics.down) != 2 { + t.Errorf("Wrong number of packets received: %v", testMetrics.down) + } +} + func BenchmarkTCPThroughput(b *testing.B) { echoListener, echoRunning := startTCPEchoServer(b) @@ -496,7 +615,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) { func BenchmarkUDPEcho(b *testing.B) { echoConn, echoRunning := startUDPEchoServer(b) - proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + proxyConn, err := onet.ListenAnyUDP4(0) if err != nil { b.Fatalf("ListenTCP failed: %v", err) } @@ -544,7 +663,7 @@ func BenchmarkUDPEcho(b *testing.B) { func BenchmarkUDPManyKeys(b *testing.B) { echoConn, echoRunning := startUDPEchoServer(b) - proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + proxyConn, err := onet.ListenAnyUDP4(0) if err != nil { b.Fatalf("ListenTCP failed: %v", err) } diff --git a/net/udp_any.go b/net/udp_any.go new file mode 100644 index 00000000..225f01ea --- /dev/null +++ b/net/udp_any.go @@ -0,0 +1,134 @@ +// Copyright 2022 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build darwin || linux + +package net + +import ( + "errors" + "io" + "net" + "runtime" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// UDPAnyConn extends net.PacketConn to allow reporting the destination IP +// of incoming packets, and setting the source IP of outgoing packets. This +// is relevant for UDP connections that are bound to `0.0.0.0` or `::`. In +// these cases, net.PacketConn is not sufficient to enable sending a reply +// from the expected source IP. +type UDPAnyConn interface { + net.PacketConn + ReadToFrom(p []byte) (n int, src *net.UDPAddr, dst net.IP, err error) + WriteToFrom(p []byte, dst *net.UDPAddr, src net.IP) (int, error) +} + +type udpAnyConnV4 struct { + net.PacketConn + v4 ipv4.PacketConn +} + +// ListenAnyUDP4 returns a UDPAnyConn that is listening on all IPv4 addresses +// at the specified port. If `port` is zero, the kernel will choose an open port. +func ListenAnyUDP4(port int) (UDPAnyConn, error) { + conn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port}) + if err != nil { + return nil, err + } + anyConn := &udpAnyConnV4{conn, *ipv4.NewPacketConn(conn)} + if err = anyConn.v4.SetControlMessage(ipv4.FlagDst, true); err != nil { + return nil, err + } + return anyConn, nil +} + +func (c *udpAnyConnV4) ReadToFrom(p []byte) (n int, src *net.UDPAddr, dst net.IP, err error) { + var cm *ipv4.ControlMessage + var tmpSrc net.Addr + if n, cm, tmpSrc, err = c.v4.ReadFrom(p); err != nil { + return + } + if cm != nil { + dst = cm.Dst + } else if runtime.GOOS != "windows" { + err = errors.New("control data is missing") + return + } + src = tmpSrc.(*net.UDPAddr) + return +} + +func (c *udpAnyConnV4) WriteToFrom(p []byte, dst *net.UDPAddr, src net.IP) (int, error) { + cm := &ipv4.ControlMessage{Src: src} + return c.v4.WriteTo(p, cm, dst) +} + +type udpAnyConnV6 struct { + net.PacketConn + v6 ipv6.PacketConn +} + +// ListenAnyUDP4 returns a UDPAnyConn that is listening on all IPv6 addresses +// at the specified port. If `port` is zero, the kernel will choose an open port. +func ListenAnyUDP6(port int) (UDPAnyConn, error) { + conn, err := net.ListenUDP("udp6", &net.UDPAddr{Port: port}) + if err != nil { + return nil, err + } + anyConn := &udpAnyConnV6{conn, *ipv6.NewPacketConn(conn)} + if err = anyConn.v6.SetControlMessage(ipv6.FlagDst, true); err != nil { + return nil, err + } + return anyConn, nil +} + +func (c *udpAnyConnV6) ReadToFrom(p []byte) (n int, src *net.UDPAddr, dst net.IP, err error) { + var cm *ipv6.ControlMessage + var tmpSrc net.Addr + if n, cm, tmpSrc, err = c.v6.ReadFrom(p); err != nil { + return + } + if cm != nil { + dst = cm.Dst + } else if runtime.GOOS != "windows" { + err = errors.New("control data is missing") + return + } + src = tmpSrc.(*net.UDPAddr) + return +} + +func (c *udpAnyConnV6) WriteToFrom(p []byte, dst *net.UDPAddr, src net.IP) (int, error) { + cm := &ipv6.ControlMessage{Src: src} + return c.v6.WriteTo(p, cm, dst) +} + +type boundWriter struct { + conn UDPAnyConn + dst *net.UDPAddr + src net.IP +} + +func (w boundWriter) Write(p []byte) (int, error) { + return w.conn.WriteToFrom(p, w.dst, w.src) +} + +// MakeBoundWriter returns a Writer that mimics the behavior of Write() on a +// connected UDPConn. +func MakeBoundWriter(conn UDPAnyConn, dst *net.UDPAddr, src net.IP) io.Writer { + return boundWriter{conn, dst, src} +} \ No newline at end of file diff --git a/net/udp_any_test.go b/net/udp_any_test.go new file mode 100644 index 00000000..9d80d8cb --- /dev/null +++ b/net/udp_any_test.go @@ -0,0 +1,260 @@ +// Copyright 2022 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "net" + "testing" +) + +func TestListenAnyUDP4(t *testing.T) { + server, err := ListenAnyUDP4(0) + if err != nil { + t.Fatal(err) + } + serverPort := server.LocalAddr().(*net.UDPAddr).Port + serverAddr1 := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: serverPort, + } + client1, err := net.DialUDP("udp", nil, serverAddr1) + if err != nil { + t.Fatal(err) + } + serverAddr2 := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.2"), + Port: serverPort, + } + client2, err := net.DialUDP("udp", nil, serverAddr2) + if err != nil { + t.Fatal(err) + } + + // Receive a packet on 127.0.0.1 + if _, err := client1.Write([]byte{1}); err != nil { + t.Fatal(err) + } + buf := make([]byte, 2) + n, src, dst, err := server.ReadToFrom(buf) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Errorf("Unexpected length: %d", n) + } + if buf[0] != 1 { + t.Errorf("Unexpected contents: %v", buf[:n]) + } + if src == nil { + t.Error("No source address") + } + if dst.String() != "127.0.0.1" { + t.Errorf("Unexpected destination: %v", dst) + } + + // Receive a packet on 127.0.0.2 + if _, err := client2.Write([]byte{2}); err != nil { + t.Fatal(err) + } + n, src, dst, err = server.ReadToFrom(buf) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Errorf("Unexpected length: %d", n) + } + if buf[0] != 2 { + t.Errorf("Unexpected contents: %v", buf[:n]) + } + if src == nil { + t.Error("No source address") + } + if dst.String() != "127.0.0.2" { + t.Errorf("Unexpected destination: %v", dst) + } +} + +func TestSendAnyUDP4(t *testing.T) { + server, err := ListenAnyUDP4(0) + if err != nil { + t.Fatal(err) + } + serverPort := server.LocalAddr().(*net.UDPAddr).Port + client, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}) + if err != nil { + t.Fatal(err) + } + clientAddr := client.LocalAddr().(*net.UDPAddr) + + serverIP1 := net.ParseIP("127.0.0.1") + serverIP2 := net.ParseIP("127.0.0.2") + + // Send from 127.0.0.1 + if _, err := server.WriteToFrom([]byte{1}, clientAddr, serverIP1); err != nil { + t.Fatal(err) + } + buf := make([]byte, 2) + n, src, err := client.ReadFrom(buf) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Errorf("Unexpected length: %d", n) + } + if buf[0] != 1 { + t.Errorf("Unexpected contents: %v", buf[:n]) + } + udpSrc := src.(*net.UDPAddr) + if !udpSrc.IP.Equal(serverIP1) { + t.Errorf("Wrong source IP: %v", src) + } + if udpSrc.Port != serverPort { + t.Errorf("Wrong source port: %v", src) + } + + // Send from 127.0.0.2 + if _, err := server.WriteToFrom([]byte{2}, clientAddr, serverIP2); err != nil { + t.Fatal(err) + } + n, src, err = client.ReadFrom(buf) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Errorf("Unexpected length: %d", n) + } + if buf[0] != 2 { + t.Errorf("Unexpected contents: %v", buf[:n]) + } + udpSrc = src.(*net.UDPAddr) + if !udpSrc.IP.Equal(serverIP2) { + t.Errorf("Wrong source IP: %v", src) + } + if udpSrc.Port != serverPort { + t.Errorf("Wrong source port: %v", src) + } +} + +func TestListenAnyUDP6(t *testing.T) { + server, err := ListenAnyUDP6(0) + if err != nil { + t.Fatal(err) + } + serverPort := server.LocalAddr().(*net.UDPAddr).Port + interfaces, err := net.Interfaces() + if err != nil { + t.Fatal(err) + } + for _, iface := range interfaces { + addrs, err := iface.Addrs() + if err != nil { + t.Fatal(err) + } + for _, addr := range addrs { + ip, _, err := net.ParseCIDR(addr.String()) + if err != nil { + t.Fatal(err) + } + if ip.To4() != nil { + continue // Ignore IPv4 + } + + // Receive a packet on this IP address. + serverAddr := &net.UDPAddr{IP: ip, Port: serverPort, Zone: iface.Name} + client, err := net.DialUDP("udp6", nil, serverAddr) + if err != nil { + t.Fatal(err) + } + if _, err := client.Write([]byte{1}); err != nil { + t.Fatal(err) + } + buf := make([]byte, 2) + n, src, dst, err := server.ReadToFrom(buf) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Errorf("Unexpected length: %d", n) + } + if buf[0] != 1 { + t.Errorf("Unexpected contents: %v", buf[:n]) + } + if src == nil { + t.Error("No source address") + } + if !ip.Equal(dst) { + t.Errorf("Unexpected destination: %v", dst) + } + } + } +} + +func TestSendAnyUDP6(t *testing.T) { + server, err := ListenAnyUDP6(0) + if err != nil { + t.Fatal(err) + } + serverPort := server.LocalAddr().(*net.UDPAddr).Port + interfaces, err := net.Interfaces() + if err != nil { + t.Fatal(err) + } + for _, iface := range interfaces { + addrs, err := iface.Addrs() + if err != nil { + t.Fatal(err) + } + for _, addr := range addrs { + ip, _, err := net.ParseCIDR(addr.String()) + if err != nil { + t.Fatal(err) + } + if ip.To4() != nil { + continue // Ignore IPv4 + } + + // Start a client listening on this IP. + clientInitAddr := &net.UDPAddr{IP: ip, Zone: iface.Name} + client, err := net.ListenUDP("udp6", clientInitAddr) + if err != nil { + t.Fatal(err) + } + clientAddr := client.LocalAddr().(*net.UDPAddr) + + // Send a packet to the client from the same IP. This should + // avoid any issues with cross-interface routing rules. + if _, err := server.WriteToFrom([]byte{1}, clientAddr, ip); err != nil { + t.Fatal(err) + } + buf := make([]byte, 2) + n, src, err := client.ReadFromUDP(buf) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Errorf("Unexpected length: %d", n) + } + if buf[0] != 1 { + t.Errorf("Unexpected contents: %v", buf[:n]) + } + if !src.IP.Equal(ip) { + t.Errorf("Unexpected source IP (%v)", src.IP) + } + if src.Port != serverPort { + t.Errorf("Unexpected source port: %d", src.Port) + } + } + } +} diff --git a/server.go b/server.go index c3c04121..96324d53 100644 --- a/server.go +++ b/server.go @@ -28,6 +28,7 @@ import ( "syscall" "time" + onet "github.com/Jigsaw-Code/outline-ss-server/net" "github.com/Jigsaw-Code/outline-ss-server/service" "github.com/Jigsaw-Code/outline-ss-server/service/metrics" ss "github.com/Jigsaw-Code/outline-ss-server/shadowsocks" @@ -62,9 +63,10 @@ func init() { } type ssPort struct { - tcpService service.TCPService - udpService service.UDPService - cipherList service.CipherList + tcpService service.TCPService + udp4Service service.UDPService + udp6Service service.UDPService + cipherList service.CipherList } type SSServer struct { @@ -79,18 +81,25 @@ func (s *SSServer) startPort(portNum int) error { if err != nil { return fmt.Errorf("Failed to start TCP on port %v: %v", portNum, err) } - packetConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: portNum}) - if err != nil { - return fmt.Errorf("Failed to start UDP on port %v: %v", portNum, err) + udp4Conn, udp4err := onet.ListenAnyUDP4(portNum) + udp6Conn, udp6err := onet.ListenAnyUDP6(portNum) + if udp4err != nil && udp6err != nil { + return fmt.Errorf("Failed to start UDP on port %v: %v", portNum, udp4err) } logger.Infof("Listening TCP and UDP on port %v", portNum) port := &ssPort{cipherList: service.NewCipherList()} + s.ports[portNum] = port // TODO: Register initial data metrics at zero. port.tcpService = service.NewTCPService(port.cipherList, &s.replayCache, s.m, tcpReadTimeout) - port.udpService = service.NewUDPService(s.natTimeout, port.cipherList, s.m) - s.ports[portNum] = port go port.tcpService.Serve(listener) - go port.udpService.Serve(packetConn) + if udp4err == nil { + port.udp4Service = service.NewUDPService(s.natTimeout, port.cipherList, s.m) + go port.udp4Service.Serve(udp4Conn) + } + if udp6err == nil { + port.udp6Service = service.NewUDPService(s.natTimeout, port.cipherList, s.m) + go port.udp6Service.Serve(udp6Conn) + } return nil } @@ -100,13 +109,20 @@ func (s *SSServer) removePort(portNum int) error { return fmt.Errorf("Port %v doesn't exist", portNum) } tcpErr := port.tcpService.Stop() - udpErr := port.udpService.Stop() + var udp4Err, udp6Err error + if port.udp4Service != nil { + udp4Err = port.udp4Service.Stop() + } + if port.udp6Service != nil { + udp6Err = port.udp6Service.Stop() + } delete(s.ports, portNum) if tcpErr != nil { return fmt.Errorf("Failed to close listener on %v: %v", portNum, tcpErr) - } - if udpErr != nil { - return fmt.Errorf("Failed to close packetConn on %v: %v", portNum, udpErr) + } else if udp4Err != nil { + return fmt.Errorf("Failed to stop IPv4 UDP service on %v: %v", portNum, udp4Err) + } else if udp6Err != nil { + return fmt.Errorf("Failed to stop IPv6 UDP service on %v: %v", portNum, udp6Err) } logger.Infof("Stopped TCP and UDP on port %v", portNum) return nil diff --git a/service/udp.go b/service/udp.go index 42160497..7fb9b190 100644 --- a/service/udp.go +++ b/service/udp.go @@ -17,6 +17,7 @@ package service import ( "errors" "fmt" + "io" "net" "runtime/debug" "sync" @@ -71,7 +72,7 @@ func findAccessKeyUDP(clientIP net.IP, dst, src []byte, cipherList CipherList) ( type udpService struct { mu sync.RWMutex // Protects .clientConn and .stopped - clientConn net.PacketConn + clientConn onet.UDPAnyConn stopped bool natTimeout time.Duration ciphers CipherList @@ -90,7 +91,7 @@ type UDPService interface { // SetTargetIPValidator sets the function to be used to validate the target IP addresses. SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) // Serve adopts the clientConn, and will not return until it is closed by Stop(). - Serve(clientConn net.PacketConn) error + Serve(clientConn onet.UDPAnyConn) error // Stop closes the clientConn and prevents further forwarding of packets. Stop() error // GracefulStop calls Stop(), and then blocks until all resources have been cleaned up. @@ -103,7 +104,7 @@ func (s *udpService) SetTargetIPValidator(targetIPValidator onet.TargetIPValidat // Listen on addr for encrypted packets and basically do UDP NAT. // We take the ciphers as a pointer because it gets replaced on config updates. -func (s *udpService) Serve(clientConn net.PacketConn) error { +func (s *udpService) Serve(clientConn onet.UDPAnyConn) error { s.mu.Lock() if s.clientConn != nil { s.mu.Unlock() @@ -135,7 +136,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { }() // Attempt to read an upstream packet. - clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf) + clientProxyBytes, clientAddr, proxyIP, err := clientConn.ReadToFrom(cipherBuf) if err != nil { s.mu.RLock() stopped = s.stopped @@ -171,7 +172,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { cipherData := cipherBuf[:clientProxyBytes] var payload []byte var tgtUDPAddr *net.UDPAddr - targetConn := nm.Get(clientAddr.String()) + targetConn := nm.Get(clientAddr, proxyIP) if targetConn == nil { var locErr error clientLocation, locErr = s.m.GetLocation(clientAddr) @@ -180,7 +181,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { } debugUDPAddr(clientAddr, "Got location \"%s\"", clientLocation) - ip := clientAddr.(*net.UDPAddr).IP + ip := clientAddr.IP var textData []byte var cipher *ss.Cipher unpackStart := time.Now() @@ -200,7 +201,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { if err != nil { return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) } - targetConn = nm.Add(clientAddr, clientConn, cipher, udpConn, clientLocation, keyID) + targetConn = nm.Add(clientAddr, proxyIP, clientConn, cipher, udpConn, clientLocation, keyID) } else { clientLocation = targetConn.clientLocation @@ -335,10 +336,19 @@ func (c *natconn) ReadFrom(buf []byte) (int, net.Addr, error) { return n, addr, err } +type natkey struct { + clientAddr string // TODO: Use netip.AddrPort + proxyIP string // TODO: Use netip.Addr +} + +func makeNATKey(clientAddr *net.UDPAddr, proxyIP net.IP) natkey { + return natkey{clientAddr.String(), proxyIP.String()} +} + // Packet NAT table type natmap struct { sync.RWMutex - keyConn map[string]*natconn + keyConn map[natkey]*natconn timeout time.Duration metrics metrics.ShadowsocksMetrics running *sync.WaitGroup @@ -346,18 +356,18 @@ type natmap struct { func newNATmap(timeout time.Duration, sm metrics.ShadowsocksMetrics, running *sync.WaitGroup) *natmap { m := &natmap{metrics: sm, running: running} - m.keyConn = make(map[string]*natconn) + m.keyConn = make(map[natkey]*natconn) m.timeout = timeout return m } -func (m *natmap) Get(key string) *natconn { +func (m *natmap) Get(clientAddr *net.UDPAddr, proxyIP net.IP) *natconn { m.RLock() defer m.RUnlock() - return m.keyConn[key] + return m.keyConn[makeNATKey(clientAddr, proxyIP)] } -func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, keyID, clientLocation string) *natconn { +func (m *natmap) set(key natkey, pc net.PacketConn, cipher *ss.Cipher, keyID, clientLocation string) *natconn { entry := &natconn{ PacketConn: pc, cipher: cipher, @@ -373,7 +383,7 @@ func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, keyID, cl return entry } -func (m *natmap) del(key string) net.PacketConn { +func (m *natmap) del(key natkey) net.PacketConn { m.Lock() defer m.Unlock() @@ -385,15 +395,19 @@ func (m *natmap) del(key string) net.PacketConn { return nil } -func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cipher *ss.Cipher, targetConn net.PacketConn, clientLocation, keyID string) *natconn { - entry := m.set(clientAddr.String(), targetConn, cipher, keyID, clientLocation) +func (m *natmap) Add(clientAddr *net.UDPAddr, proxyIP net.IP, clientConn onet.UDPAnyConn, cipher *ss.Cipher, targetConn net.PacketConn, clientLocation, keyID string) *natconn { + key := makeNATKey(clientAddr, proxyIP) + entry := m.set(key, targetConn, cipher, keyID, clientLocation) + + boundWriter := onet.MakeBoundWriter(clientConn, clientAddr, proxyIP) + ssWriter := makeShadowsocksUDPWriter(boundWriter, cipher) m.metrics.AddUDPNatEntry() m.running.Add(1) go func() { - timedCopy(clientAddr, clientConn, entry, keyID, m.metrics) + copyUntilTimeout(ssWriter, entry, keyID, m.metrics, clientAddr) m.metrics.RemoveUDPNatEntry() - if pc := m.del(clientAddr.String()); pc != nil { + if pc := m.del(key); pc != nil { pc.Close() } m.running.Done() @@ -420,75 +434,84 @@ func (m *natmap) Close() error { var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) // copy from target to client until read timeout -func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, - keyID string, sm metrics.ShadowsocksMetrics) { - // pkt is used for in-place encryption of downstream UDP packets, with the layout - // [padding?][salt][address][body][tag][extra] - // Padding is only used if the address is IPv4. - pkt := make([]byte, serverUDPBufferSize) - - saltSize := targetConn.cipher.SaltSize() - // Leave enough room at the beginning of the packet for a max-length header (i.e. IPv6). - bodyStart := saltSize + maxAddrLen - - expired := false +func copyUntilTimeout(ssWriter shadowsocksUDPWriter, targetConn *natconn, + keyID string, sm metrics.ShadowsocksMetrics, clientAddr *net.UDPAddr) { for { - var bodyLen, proxyClientBytes int - connError := func() (connError *onet.ConnectionError) { - var ( - raddr net.Addr - err error - ) - // `readBuf` receives the plaintext body in `pkt`: - // [padding?][salt][address][body][tag][unused] - // |-- bodyStart --|[ readBuf ] - readBuf := pkt[bodyStart:] - bodyLen, raddr, err = targetConn.ReadFrom(readBuf) - if err != nil { - if netErr, ok := err.(net.Error); ok { - if netErr.Timeout() { - expired = true - return nil - } - } - return onet.NewConnectionError("ERR_READ", "Failed to read from target", err) - } - - debugUDPAddr(clientAddr, "Got response from %v", raddr) - srcAddr := socks.ParseAddr(raddr.String()) - addrStart := bodyStart - len(srcAddr) - // `plainTextBuf` concatenates the SOCKS address and body: - // [padding?][salt][address][body][tag][unused] - // |-- addrStart -|[plaintextBuf ] - plaintextBuf := pkt[addrStart : bodyStart+bodyLen] - copy(plaintextBuf, srcAddr) - - // saltStart is 0 if raddr is IPv6. - saltStart := addrStart - saltSize - // `packBuf` adds space for the salt and tag. - // `buf` shows the space that was used. - // [padding?][salt][address][body][tag][unused] - // [ packBuf ] - // [ buf ] - packBuf := pkt[saltStart:] - buf, err := ss.Pack(packBuf, plaintextBuf, targetConn.cipher) // Encrypt in-place - if err != nil { - return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err) - } - proxyClientBytes, err = clientConn.WriteTo(buf, clientAddr) - if err != nil { - return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err) - } - return nil - }() + bodyLen, proxyClientBytes, raddr, connError := ssWriter.ReadPacketFrom(targetConn) status := "OK" if connError != nil { logger.Debugf("UDP Error: %v: %v", connError.Message, connError.Cause) status = connError.Status + if status == "ERR_READ" { + if netErr, ok := connError.Cause.(net.Error); ok && netErr.Timeout() { + break + } + } } - if expired { - break - } + debugUDPAddr(clientAddr, "Got response from (%v)", raddr) sm.AddUDPPacketFromTarget(targetConn.clientLocation, keyID, status, bodyLen, proxyClientBytes) } } + +// Represents net.PacketConn.ReadFrom. +type packetReader interface { + ReadFrom(p []byte) (int, net.Addr, error) +} + +// Parallel to ss.ShadowsocksWriter, but for UDP. +type shadowsocksUDPWriter struct { + clientWriter io.Writer + cipher *ss.Cipher + pkt [serverUDPBufferSize]byte +} + +func makeShadowsocksUDPWriter(clientWriter io.Writer, cipher *ss.Cipher) shadowsocksUDPWriter { + return shadowsocksUDPWriter{ + clientWriter: clientWriter, + cipher: cipher, + } +} + +// ReadPacketFrom reads one plaintext packet from `r`, encodes it for shadowsocks, +// and sends it to the client. It returns the number of bytes read and written and +// the source address of the packet, or an error. +func (w *shadowsocksUDPWriter) ReadPacketFrom(r packetReader) (readLen int, writeLen int, raddr net.Addr, connErr *onet.ConnectionError) { + saltSize := w.cipher.SaltSize() + bodyStart := saltSize + maxAddrLen + // `readBuf` receives the plaintext body in `pkt`: + // [padding?][salt][address][body][tag][unused] + // |-- bodyStart --|[ readBuf ] + readBuf := w.pkt[bodyStart:] + var err error + if readLen, raddr, err = r.ReadFrom(readBuf); err != nil { + connErr = onet.NewConnectionError("ERR_READ", "Failed to read from target", err) + return + } + + srcAddr := socks.ParseAddr(raddr.String()) + addrStart := bodyStart - len(srcAddr) + // `plainTextBuf` concatenates the SOCKS address and body: + // [padding?][salt][address][body][tag][unused] + // |-- addrStart -|[plaintextBuf ] + plaintextBuf := w.pkt[addrStart : bodyStart+readLen] + copy(plaintextBuf, srcAddr) + + // saltStart is 0 if raddr is IPv6. + saltStart := addrStart - saltSize + // `packBuf` adds space for the salt and tag. + // `buf` shows the space that was used. + // [padding?][salt][address][body][tag][unused] + // [ packBuf ] + // [ buf ] + packBuf := w.pkt[saltStart:] + buf, err := ss.Pack(packBuf, plaintextBuf, w.cipher) // Encrypt in-place + if err != nil { + connErr = onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err) + return + } + if writeLen, err = w.clientWriter.Write(buf); err != nil { + connErr = onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err) + return + } + return +} diff --git a/service/udp_test.go b/service/udp_test.go index 12139fdc..53cd495c 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -17,6 +17,7 @@ package service import ( "bytes" "errors" + "io" "net" "sync" "testing" @@ -35,6 +36,7 @@ const timeout = 5 * time.Minute var clientAddr = net.UDPAddr{IP: []byte{192, 0, 2, 1}, Port: 12345} var targetAddr = net.UDPAddr{IP: []byte{192, 0, 2, 2}, Port: 54321} var dnsAddr = net.UDPAddr{IP: []byte{192, 0, 2, 3}, Port: 53} +var proxyIP net.IP = []byte{192, 0, 2, 4} var natCipher *ss.Cipher func init() { @@ -43,7 +45,8 @@ func init() { } type packet struct { - addr net.Addr + remote *net.UDPAddr + local net.IP payload []byte err error } @@ -68,20 +71,29 @@ func (conn *fakePacketConn) SetReadDeadline(deadline time.Time) error { } func (conn *fakePacketConn) WriteTo(payload []byte, addr net.Addr) (int, error) { - conn.send <- packet{addr, payload, nil} + return conn.WriteToFrom(payload, addr.(*net.UDPAddr), nil) +} + +func (conn *fakePacketConn) WriteToFrom(payload []byte, dst *net.UDPAddr, src net.IP) (int, error) { + conn.send <- packet{dst, src, payload, nil} return len(payload), nil } func (conn *fakePacketConn) ReadFrom(buffer []byte) (int, net.Addr, error) { + n, src, _, err := conn.ReadToFrom(buffer) + return n, src, err +} + +func (conn *fakePacketConn) ReadToFrom(buffer []byte) (int, *net.UDPAddr, net.IP, error) { pkt, ok := <-conn.recv if !ok { - return 0, nil, errors.New("Receive closed") + return 0, nil, nil, errors.New("Receive closed") } n := copy(buffer, pkt.payload) if n < len(pkt.payload) { - return n, pkt.addr, errors.New("Buffer was too short") + return n, pkt.remote, pkt.local, io.ErrShortBuffer } - return n, pkt.addr, pkt.err + return n, pkt.remote, pkt.local, pkt.err } func (conn *fakePacketConn) Close() error { @@ -141,10 +153,11 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTest ciphertext := make([]byte, cipher.SaltSize()+len(plaintext)+cipher.TagSize()) ss.Pack(ciphertext, plaintext, cipher) clientConn.recv <- packet{ - addr: &net.UDPAddr{ + remote: &net.UDPAddr{ IP: net.ParseIP("192.0.2.1"), Port: 54321, }, + local: proxyIP, payload: ciphertext, } } @@ -203,7 +216,7 @@ func assertAlmostEqual(t *testing.T, a, b time.Time) { func TestNATEmpty(t *testing.T) { nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) - if nat.Get("foo") != nil { + if nat.Get(&clientAddr, proxyIP) != nil { t.Error("Expected nil value from empty NAT map") } } @@ -212,8 +225,8 @@ func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) { nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) clientConn := makePacketConn() targetConn := makePacketConn() - nat.Add(&clientAddr, clientConn, natCipher, targetConn, "ZZ", "key id") - entry := nat.Get(clientAddr.String()) + nat.Add(&clientAddr, proxyIP, clientConn, natCipher, targetConn, "ZZ", "key id") + entry := nat.Get(&clientAddr, proxyIP) return clientConn, targetConn, entry } @@ -238,8 +251,8 @@ func TestNATWrite(t *testing.T) { if !bytes.Equal(sent.payload, buf) { t.Errorf("Mismatched payload: %v != %v", sent.payload, buf) } - if sent.addr != &targetAddr { - t.Errorf("Mismatched address: %v != %v", sent.addr, &targetAddr) + if sent.remote != &targetAddr { + t.Errorf("Mismatched address: %v != %v", sent.remote, &targetAddr) } } @@ -255,8 +268,8 @@ func TestNATWriteDNS(t *testing.T) { if !bytes.Equal(sent.payload, buf) { t.Errorf("Mismatched payload: %v != %v", sent.payload, buf) } - if sent.addr != &dnsAddr { - t.Errorf("Mismatched address: %v != %v", sent.addr, &targetAddr) + if sent.remote != &dnsAddr { + t.Errorf("Mismatched address: %v != %v", sent.remote, &targetAddr) } } @@ -297,7 +310,7 @@ func TestNATFastClose(t *testing.T) { sent := <-targetConn.send // Send the response. response := []byte{1, 2, 3, 4, 5} - received := packet{addr: &dnsAddr, payload: response} + received := packet{remote: &dnsAddr, payload: response} targetConn.recv <- received sent, ok := <-clientConn.send if !ok { @@ -306,8 +319,11 @@ func TestNATFastClose(t *testing.T) { if len(sent.payload) <= len(response) { t.Error("Packet is too short to be shadowsocks-AEAD") } - if sent.addr != &clientAddr { - t.Errorf("Address mismatch: %v != %v", sent.addr, clientAddr) + if sent.remote != &clientAddr { + t.Errorf("Address mismatch: %v != %v", sent.remote, clientAddr) + } + if !proxyIP.Equal(sent.local) { + t.Errorf("Proxy IP mismatch: %v != %v", sent.local, proxyIP) } // targetConn should be scheduled to close immediately. @@ -323,7 +339,7 @@ func TestNATNoFastClose_NotDNS(t *testing.T) { sent := <-targetConn.send // Send the response. response := []byte{1, 2, 3, 4, 5} - received := packet{addr: &targetAddr, payload: response} + received := packet{remote: &targetAddr, payload: response} targetConn.recv <- received sent, ok := <-clientConn.send if !ok { @@ -332,8 +348,8 @@ func TestNATNoFastClose_NotDNS(t *testing.T) { if len(sent.payload) <= len(response) { t.Error("Packet is too short to be shadowsocks-AEAD") } - if sent.addr != &clientAddr { - t.Errorf("Address mismatch: %v != %v", sent.addr, clientAddr) + if sent.remote != &clientAddr { + t.Errorf("Address mismatch: %v != %v", sent.remote, clientAddr) } // targetConn should be scheduled to close after the full timeout. assertAlmostEqual(t, targetConn.deadline, time.Now().Add(timeout)) @@ -352,7 +368,7 @@ func TestNATNoFastClose_MultipleDNS(t *testing.T) { // Send a response. response := []byte{1, 2, 3, 4, 5} - received := packet{addr: &dnsAddr, payload: response} + received := packet{remote: &dnsAddr, payload: response} targetConn.recv <- received <-clientConn.send @@ -391,6 +407,65 @@ func TestNATTimeout(t *testing.T) { assertAlmostEqual(t, before, time.Now()) } +func TestNATMultipleProxyIPs(t *testing.T) { + nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) + clientConn := makePacketConn() + targetConn1 := makePacketConn() + nat.Add(&clientAddr, proxyIP, clientConn, natCipher, targetConn1, "ZZ", "key id") + entry1 := nat.Get(&clientAddr, proxyIP) + targetConn2 := makePacketConn() + proxyIP2 := net.ParseIP("192.0.2.123") + nat.Add(&clientAddr, proxyIP2, clientConn, natCipher, targetConn2, "ZZ", "key id") + entry2 := nat.Get(&clientAddr, proxyIP2) + + // Send a standard packet on entry 1. + entry1.WriteTo([]byte{1}, &targetAddr) + assertAlmostEqual(t, targetConn1.deadline, time.Now().Add(timeout)) + <-targetConn1.send + + // Send a DNS packet on entry 2. + entry2.WriteTo([]byte{2}, &dnsAddr) + // DNS-only connections have a fixed timeout of 17 seconds. + assertAlmostEqual(t, targetConn2.deadline, time.Now().Add(17*time.Second)) + <-targetConn2.send + + // Send a reply on entry 1 and verify that it is sent from `proxyIP`. + targetConn1.recv <- packet{&targetAddr, nil, []byte{3}, nil} + ss1, ok := <-clientConn.send + if !ok { + t.Error("clientConn was closed") + } + if len(ss1.payload) <= 1 { + t.Error("Packet is too short to be shadowsocks-AEAD") + } + if ss1.remote != &clientAddr { + t.Errorf("Address mismatch: %v != %v", ss1.remote, clientAddr) + } + if !proxyIP.Equal(ss1.local) { + t.Errorf("Mismatched proxy IP: %v != %v", ss1.local, proxyIP) + } + // `targetConn1` is not DNS, so it's still open. + assertAlmostEqual(t, targetConn1.deadline, time.Now().Add(timeout)) + + // Send a reply on entry 2 and verify that it is sent from `proxyIP2`. + targetConn2.recv <- packet{&dnsAddr, nil, []byte{4}, nil} + ss2, ok := <-clientConn.send + if !ok { + t.Error("clientConn was closed") + } + if len(ss2.payload) <= 1 { + t.Error("Packet is too short to be shadowsocks-AEAD") + } + if ss2.remote != &clientAddr { + t.Errorf("Address mismatch: %v != %v", ss2.remote, clientAddr) + } + if !proxyIP2.Equal(ss2.local) { + t.Errorf("Mismatched proxy IP: %v != %v", ss2.local, proxyIP) + } + // `targetConn2`` should be scheduled to close immediately. + assertAlmostEqual(t, targetConn2.deadline, time.Now()) +} + // Simulates receiving invalid UDP packets on a server with 100 ciphers. func BenchmarkUDPUnpackFail(b *testing.B) { cipherList, err := MakeTestCiphers(ss.MakeTestSecrets(100)) @@ -451,6 +526,9 @@ func BenchmarkUDPUnpackSharedKey(b *testing.B) { snapshot := cipherList.SnapshotForClientIP(nil) cipher := snapshot[0].Value.(*CipherEntry).Cipher packet, err := ss.Pack(make([]byte, serverUDPBufferSize), plaintext, cipher) + if err != nil { + b.Fatal(err) + } const numIPs = 100 // Must be <256 ips := [numIPs]net.IP{} @@ -478,12 +556,8 @@ func TestUDPDoubleServe(t *testing.T) { c := make(chan error) for i := 0; i < 2; i++ { - clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) - if err != nil { - t.Fatalf("ListenUDP failed: %v", err) - } go func() { - err := s.Serve(clientConn) + err := s.Serve(makePacketConn()) if err != nil { c <- err close(c) @@ -513,11 +587,7 @@ func TestUDPEarlyStop(t *testing.T) { if err := s.Stop(); err != nil { t.Error(err) } - clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) - if err != nil { - t.Fatalf("ListenUDP failed: %v", err) - } - if err := s.Serve(clientConn); err != nil { + if err := s.Serve(makePacketConn()); err != nil { t.Error(err) } }