From 3a03eb950dda024cf5f1a77f703011f452a92aaa Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Fri, 2 Aug 2024 22:31:39 +0200 Subject: [PATCH] add an automated test for UDP port forwarding --- port-forwarder/config_test.go | 18 +++- port-forwarder/port_forwarder_udp_test.go | 122 ++++++++++++++++++++++ service/service_test.go | 92 +--------------- service/service_testhelpers.go | 100 ++++++++++++++++++ 4 files changed, 240 insertions(+), 92 deletions(-) create mode 100644 port-forwarder/port_forwarder_udp_test.go create mode 100644 service/service_testhelpers.go diff --git a/port-forwarder/config_test.go b/port-forwarder/config_test.go index e518317dd..8004460c9 100644 --- a/port-forwarder/config_test.go +++ b/port-forwarder/config_test.go @@ -46,7 +46,7 @@ port_forwarding: assert.True(t, fwd_list.IsEmpty()) } -func TestConfigWithNoProtocols2(t *testing.T) { +func TestConfigWithNoProtocols_commentedProtos(t *testing.T) { l := logrus.New() c := config.NewC(l) err := c.LoadString(` @@ -70,6 +70,22 @@ port_forwarding: assert.True(t, fwd_list.IsEmpty()) } +func TestConfigWithNoProtocols_missing_in_out(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 0) + assert.True(t, fwd_list.IsEmpty()) +} + func TestConfigWithTcpIn(t *testing.T) { l := logrus.New() c := config.NewC(l) diff --git a/port-forwarder/port_forwarder_udp_test.go b/port-forwarder/port_forwarder_udp_test.go new file mode 100644 index 000000000..1768f8808 --- /dev/null +++ b/port-forwarder/port_forwarder_udp_test.go @@ -0,0 +1,122 @@ +package port_forwarder + +import ( + "net" + "testing" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/service" + "github.com/stretchr/testify/assert" +) + +func createPortForwarderFromConfigString(l *logrus.Logger, srv *service.Service, configStr string) (*PortForwardingService, error) { + c := config.NewC(l) + err := c.LoadString(configStr) + if err != nil { + return nil, err + } + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + if err != nil { + return nil, err + } + + pf, err := ConstructFromInitialFwdList(srv, l, &fwd_list) + if err != nil { + return nil, err + } + + err = pf.Activate() + if err != nil { + return nil, err + } + + return pf, nil +} + +func doTestUdpCommunication( + t *testing.T, + msg string, + senderConn *net.UDPConn, + toAddr net.Addr, + receiverConn *net.UDPConn, +) (senderAddr net.Addr) { + data_sent := []byte(msg) + var n int + var err error + if toAddr != nil { + n, err = senderConn.WriteTo(data_sent, toAddr) + } else { + n, err = senderConn.Write(data_sent) + } + assert.Nil(t, err) + assert.Equal(t, n, len(data_sent)) + + buf := make([]byte, 100) + n, senderAddr, err = receiverConn.ReadFrom(buf) + assert.Nil(t, err) + assert.Equal(t, n, len(data_sent)) + assert.Equal(t, data_sent, buf[:n]) + return +} + +func TestUdpInOut2Clients(t *testing.T) { + l := logrus.New() + server, client := service.CreateTwoConnectedServices() + defer client.Close() + defer server.Close() + + server_pf, err := createPortForwarderFromConfigString(l, server, ` +port_forwarding: + inbound: + - listen_port: 4499 + dial_address: 127.0.0.1:5599 + protocols: [udp] +`) + assert.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + client_pf, err := createPortForwarderFromConfigString(l, client, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 10.0.0.1:4499 + protocols: [udp] +`) + assert.Nil(t, err) + + assert.Len(t, client_pf.portForwardings, 1) + + client_conn_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3399") + assert.Nil(t, err) + server_conn_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5599") + assert.Nil(t, err) + + server_listen_conn, err := net.ListenUDP("udp", server_conn_addr) + assert.Nil(t, err) + client1_conn, err := net.DialUDP("udp", nil, client_conn_addr) + assert.Nil(t, err) + client2_conn, err := net.DialUDP("udp", nil, client_conn_addr) + assert.Nil(t, err) + + client1_addr := doTestUdpCommunication(t, "Hello from client 1 side!", + client1_conn, nil, server_listen_conn) + assert.NotNil(t, client1_addr) + client2_addr := doTestUdpCommunication(t, "Hello from client two side!", + client2_conn, nil, server_listen_conn) + assert.NotNil(t, client2_addr) + + doTestUdpCommunication(t, "Hello from server first side!", + server_listen_conn, client1_addr, client1_conn) + doTestUdpCommunication(t, "Hello from server second side!", + server_listen_conn, client2_addr, client2_conn) + doTestUdpCommunication(t, "Hello from server third side!", + server_listen_conn, client1_addr, client1_conn) + + doTestUdpCommunication(t, "Hello from client two side AGAIN!", + client2_conn, nil, server_listen_conn) + +} diff --git a/service/service_test.go b/service/service_test.go index 29ff7041e..da327af42 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -4,103 +4,13 @@ import ( "bytes" "context" "errors" - "net/netip" "testing" - "time" - "dario.cat/mergo" - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/e2e" "golang.org/x/sync/errgroup" - "gopkg.in/yaml.v2" ) -type m map[string]interface{} - -func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) - caB, err := caCrt.MarshalToPEM() - if err != nil { - panic(err) - } - - mc := m{ - "pki": m{ - "ca": string(caB), - "cert": string(myPEM), - "key": string(myPrivKey), - }, - //"tun": m{"disabled": true}, - "firewall": m{ - "outbound": []m{{ - "proto": "any", - "port": "any", - "host": "any", - }}, - "inbound": []m{{ - "proto": "any", - "port": "any", - "host": "any", - }}, - }, - "timers": m{ - "pending_deletion_interval": 2, - "connection_alive_interval": 2, - }, - "handshakes": m{ - "try_interval": "200ms", - }, - } - - if overrides != nil { - err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) - if err != nil { - panic(err) - } - mc = overrides - } - - cb, err := yaml.Marshal(mc) - if err != nil { - panic(err) - } - - var c config.C - if err := c.LoadString(string(cb)); err != nil { - panic(err) - } - - l := logrus.New() - s, err := New(&c, l) - if err != nil { - panic(err) - } - return s -} - func TestService(t *testing.T) { - ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ - "static_host_map": m{}, - "lighthouse": m{ - "am_lighthouse": true, - }, - "listen": m{ - "host": "0.0.0.0", - "port": 4243, - }, - }) - b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ - "static_host_map": m{ - "10.0.0.1": []string{"localhost:4243"}, - }, - "lighthouse": m{ - "hosts": []string{"10.0.0.1"}, - "interval": 1, - }, - }) + a, b := CreateTwoConnectedServices() ln, err := a.Listen("tcp", ":1234") if err != nil { diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go new file mode 100644 index 000000000..01d226694 --- /dev/null +++ b/service/service_testhelpers.go @@ -0,0 +1,100 @@ +package service + +import ( + "net/netip" + "time" + + "dario.cat/mergo" + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/e2e" + "gopkg.in/yaml.v2" +) + +type m map[string]interface{} + +func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) + caB, err := caCrt.MarshalToPEM() + if err != nil { + panic(err) + } + + mc := m{ + "pki": m{ + "ca": string(caB), + "cert": string(myPEM), + "key": string(myPrivKey), + }, + //"tun": m{"disabled": true}, + "firewall": m{ + "outbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + }, + "timers": m{ + "pending_deletion_interval": 2, + "connection_alive_interval": 2, + }, + "handshakes": m{ + "try_interval": "200ms", + }, + } + + if overrides != nil { + err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = overrides + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + var c config.C + if err := c.LoadString(string(cb)); err != nil { + panic(err) + } + + l := logrus.New() + s, err := New(&c, l) + if err != nil { + panic(err) + } + return s +} + +func CreateTwoConnectedServices() (*Service, *Service) { + ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ + "static_host_map": m{}, + "lighthouse": m{ + "am_lighthouse": true, + }, + "listen": m{ + "host": "0.0.0.0", + "port": 4243, + }, + }) + b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ + "static_host_map": m{ + "10.0.0.1": []string{"localhost:4243"}, + }, + "lighthouse": m{ + "hosts": []string{"10.0.0.1"}, + "interval": 1, + }, + }) + return a, b +}