From 70dfb71b6c5bdf1688ea9dd5a30d25af2a423f17 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Mon, 9 Sep 2024 10:53:46 -0400 Subject: [PATCH] allow "lighthouse DNS" to be run on non-lighthouses, so hosts can see their own hostmap --- dns_server.go | 34 +++++++++++++++++++++-- dns_server_test.go | 69 ++++++++++++++++++++++++++++++++++++++++++---- main.go | 12 +++----- 3 files changed, 100 insertions(+), 15 deletions(-) diff --git a/dns_server.go b/dns_server.go index 5fea65c47..6f64e737c 100644 --- a/dns_server.go +++ b/dns_server.go @@ -144,11 +144,41 @@ func getDnsServerAddr(c *config.C) string { if dnsHost == "[::]" { dnsHost = "::" } + return dnsHost +} + +func getDnsServerAddrPort(c *config.C) string { + dnsHost := getDnsServerAddr(c) return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) } +func shouldServeDns(c *config.C) (bool, error) { + if !c.GetBool("lighthouse.serve_dns", false) { + return false, nil + } + + dnsHostStr := getDnsServerAddr(c) + if dnsHostStr == "" { //setting an ip address is required + return false, fmt.Errorf("no DNS server IP address set") + } + + if c.GetBool("lighthouse.am_lighthouse", false) { + return true, nil + } + + dnsHost, err := netip.ParseAddr(dnsHostStr) + if err != nil { + return false, fmt.Errorf("failed to parse lighthouse.dns.host(%s) %v", dnsHostStr, err) + } + if !dnsHost.IsLoopback() { + return false, fmt.Errorf("lighthouse.dns.host(%s) must be loopback on non-lighthouses", dnsHostStr) + } + + return true, nil +} + func startDns(l *logrus.Logger, c *config.C) { - dnsAddr = getDnsServerAddr(c) + dnsAddr = getDnsServerAddrPort(c) dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder") err := dnsServer.ListenAndServe() @@ -159,7 +189,7 @@ func startDns(l *logrus.Logger, c *config.C) { } func reloadDns(l *logrus.Logger, c *config.C) { - if dnsAddr == getDnsServerAddr(c) { + if dnsAddr == getDnsServerAddrPort(c) { l.Debug("No DNS server config change detected") return } diff --git a/dns_server_test.go b/dns_server_test.go index 69f6ae84f..70f4d55f0 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -20,7 +20,7 @@ func TestParsequery(t *testing.T) { //parseQuery(m) } -func Test_getDnsServerAddr(t *testing.T) { +func Test_getDnsServerAddrPort(t *testing.T) { c := config.NewC(nil) c.Settings["lighthouse"] = map[interface{}]interface{}{ @@ -29,7 +29,7 @@ func Test_getDnsServerAddr(t *testing.T) { "port": "1", }, } - assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c)) + assert.Equal(t, "0.0.0.0:1", getDnsServerAddrPort(c)) c.Settings["lighthouse"] = map[interface{}]interface{}{ "dns": map[interface{}]interface{}{ @@ -37,7 +37,7 @@ func Test_getDnsServerAddr(t *testing.T) { "port": "1", }, } - assert.Equal(t, "[::]:1", getDnsServerAddr(c)) + assert.Equal(t, "[::]:1", getDnsServerAddrPort(c)) c.Settings["lighthouse"] = map[interface{}]interface{}{ "dns": map[interface{}]interface{}{ @@ -45,7 +45,7 @@ func Test_getDnsServerAddr(t *testing.T) { "port": "1", }, } - assert.Equal(t, "[::]:1", getDnsServerAddr(c)) + assert.Equal(t, "[::]:1", getDnsServerAddrPort(c)) // Make sure whitespace doesn't mess us up c.Settings["lighthouse"] = map[interface{}]interface{}{ @@ -54,5 +54,64 @@ func Test_getDnsServerAddr(t *testing.T) { "port": "1", }, } - assert.Equal(t, "[::]:1", getDnsServerAddr(c)) + assert.Equal(t, "[::]:1", getDnsServerAddrPort(c)) +} + +func Test_shouldServeDns(t *testing.T) { + c := config.NewC(nil) + notLoopback := map[interface{}]interface{}{"host": "0.0.0.0", "port": "1"} + yesLoopbackv4 := map[interface{}]interface{}{"host": "127.0.0.2", "port": "1"} + yesLoopbackv6 := map[interface{}]interface{}{"host": "::1", "port": "1"} + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "serve_dns": false, + } + serveDns, err := shouldServeDns(c) + assert.NoError(t, err) + assert.False(t, serveDns) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "am_lighthouse": true, + "serve_dns": true, + } + serveDns, err = shouldServeDns(c) + assert.Error(t, err) + assert.False(t, serveDns) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "am_lighthouse": true, + "serve_dns": true, + "dns": notLoopback, + } + serveDns, err = shouldServeDns(c) + assert.NoError(t, err) + assert.True(t, serveDns) + + //non-lighthouses must do DNS on loopback + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "am_lighthouse": false, + "serve_dns": true, + "dns": notLoopback, + } + serveDns, err = shouldServeDns(c) + assert.Error(t, err) + assert.False(t, serveDns) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "am_lighthouse": false, + "serve_dns": true, + "dns": yesLoopbackv4, + } + serveDns, err = shouldServeDns(c) + assert.NoError(t, err) + assert.True(t, serveDns) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "am_lighthouse": false, + "serve_dns": true, + "dns": yesLoopbackv6, + } + serveDns, err = shouldServeDns(c) + assert.NoError(t, err) + assert.True(t, serveDns) } diff --git a/main.go b/main.go index c6edc9133..67855c25c 100644 --- a/main.go +++ b/main.go @@ -225,13 +225,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger - serveDns := false - if c.GetBool("lighthouse.serve_dns", false) { - if c.GetBool("lighthouse.am_lighthouse", false) { - serveDns = true - } else { - l.Warn("DNS server refusing to run because this host is not a lighthouse.") - } + serveDns, dnsErr := shouldServeDns(c) + if dnsErr != nil { + l.Warnf("failed to configure DNS server: %v", dnsErr) } checkInterval := c.GetInt("timers.connection_alive_interval", 5) @@ -311,7 +307,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg // Start DNS server last to allow using the nebula IP as lighthouse.dns.host var dnsStart func() - if lightHouse.amLighthouse && serveDns { + if serveDns { l.Debugln("Starting dns server") dnsStart = dnsMain(l, hostMap, c) }