diff --git a/handshake_manager_test.go b/handshake_manager_test.go index c1898384a..7172a6384 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -96,3 +96,7 @@ func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {} func (mw *mockEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo { return nil } + +func (mw *mockEncWriter) GetCertState() *CertState { + return &CertState{defaultVersion: cert.Version2} +} diff --git a/hostmap.go b/hostmap.go index e3817e7c6..824d72251 100644 --- a/hostmap.go +++ b/hostmap.go @@ -352,31 +352,31 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { h := hm.Hosts[addr] for h != nil { if h == hostinfo { - hm.unlockedInnerDeleteHostInfo(h) + hm.unlockedInnerDeleteHostInfo(h, addr) } h = h.next } } } -func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo) { - primary, ok := hm.Hosts[hostinfo.vpnAddrs[0]] +func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) { + primary, ok := hm.Hosts[addr] if ok && primary == hostinfo { - // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it - delete(hm.Hosts, hostinfo.vpnAddrs[0]) + // The vpn addr pointer points to the same hostinfo as the local index id, we can remove it + delete(hm.Hosts, addr) if len(hm.Hosts) == 0 { hm.Hosts = map[netip.Addr]*HostInfo{} } if hostinfo.next != nil { - // We had more than 1 hostinfo at this vpnip, promote the next in the list to primary - hm.Hosts[hostinfo.vpnAddrs[0]] = hostinfo.next + // We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary + hm.Hosts[addr] = hostinfo.next // It is primary, there is no previous hostinfo now hostinfo.next.prev = nil } } else { - // Relink if we were in the middle of multiple hostinfos for this vpn ip + // Relink if we were in the middle of multiple hostinfos for this vpn addr if hostinfo.prev != nil { hostinfo.prev.next = hostinfo.next } diff --git a/interface.go b/interface.go index 9ec7d1bf7..378c35450 100644 --- a/interface.go +++ b/interface.go @@ -107,6 +107,7 @@ type EncWriter interface { SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) Handshake(vpnIp netip.Addr) GetHostInfo(vpnIp netip.Addr) *HostInfo + GetCertState() *CertState } type sendRecvErrorConfig uint8 @@ -428,6 +429,10 @@ func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo { return f.hostMap.QueryVpnAddr(vpnIp) } +func (f *Interface) GetCertState() *CertState { + return f.pki.getCertState() +} + func (f *Interface) Close() error { f.closed.Store(true) diff --git a/lighthouse.go b/lighthouse.go index 32e280b0e..af12d0cd3 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -61,11 +61,10 @@ type LightHouse struct { staticList atomic.Pointer[map[netip.Addr]struct{}] lighthouses atomic.Pointer[map[netip.Addr]struct{}] - interval atomic.Int64 - updateCancel context.CancelFunc - ifce EncWriter - nebulaPort uint32 // 32 bits because protobuf does not have a uint16 - protocolVersion atomic.Uint32 // The default protocol version to use if we can't determine which to use from the tunnel + interval atomic.Int64 + updateCancel context.CancelFunc + ifce EncWriter + nebulaPort uint32 // 32 bits because protobuf does not have a uint16 advertiseAddrs atomic.Pointer[[]netip.AddrPort] @@ -352,16 +351,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } } - v := c.GetUint32("pki.default_version", 1) - switch v { - case 1: - lh.protocolVersion.Store(1) - case 2: - lh.protocolVersion.Store(2) - default: - return fmt.Errorf("invalid version for lighthouse: %v", v) - } - return nil } @@ -750,15 +739,12 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { } // Send a query to the lighthouses and hope for the best next time - //TODO: this is not sufficient since the version depends on the certs loaded into memory as well - v := lh.protocolVersion.Load() + v := lh.ifce.GetCertState().defaultVersion msg := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{}, } - //TODO: remove this - v = 2 if v == 1 { if !addr.Is4() { lh.l.WithField("vpnAddr", addr).Error("Can't query lighthouse for v6 address using a v1 protocol") @@ -843,7 +829,7 @@ func (lh *LightHouse) SendUpdate() { } } - v := lh.protocolVersion.Load() + v := lh.ifce.GetCertState().defaultVersion msg := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ @@ -852,8 +838,6 @@ func (lh *LightHouse) SendUpdate() { }, } - //TODO: remove this - v = 2 if v == 1 { var relays []uint32 for _, r := range lh.GetRelaysForMe() { @@ -1042,11 +1026,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti found, ln, err = lhh.lh.queryAndPrepMessage(fromVpnAddrs[0], func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification - //TODO: unsure which version to use. If we had access to the hostmap we could see if there is already a tunnel - // and use that version then fallback to our default configuration targetHI := lhh.lh.ifce.GetHostInfo(queryVpnIp) - useVersion = cert.Version(lhh.lh.protocolVersion.Load()) - if targetHI != nil { + if targetHI == nil { + useVersion = lhh.lh.ifce.GetCertState().defaultVersion + } else { useVersion = targetHI.GetCert().Certificate.Version() } diff --git a/lighthouse_test.go b/lighthouse_test.go index 116099773..1c1489b4f 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/gaissmai/bart" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" @@ -427,8 +428,9 @@ type testLhReply struct { } type testEncWriter struct { - lastReply testLhReply - metaFilter *NebulaMeta_MessageType + lastReply testLhReply + metaFilter *NebulaMeta_MessageType + protocolVersion cert.Version } func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { @@ -474,6 +476,10 @@ func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo { return nil } +func (tw *testEncWriter) GetCertState() *CertState { + return &CertState{defaultVersion: tw.protocolVersion} +} + // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) { if !assert.Len(t, have, len(want)) { diff --git a/pki.go b/pki.go index c4160d5a8..779a1598a 100644 --- a/pki.go +++ b/pki.go @@ -324,10 +324,24 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { } } - rawDefaultVersion := c.GetUint32("pki.default_version", 1) + if v1 == nil && v2 == nil { + return nil, errors.New("no certificates found in pki.cert") + } + + useDefaultVersion := uint32(1) + if v1 == nil { + // The only condition that requires v2 as the default is if only a v2 certificate is present + // We do this to avoid having to configure it specifically in the config file + useDefaultVersion = 2 + } + + rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion) var defaultVersion cert.Version switch rawDefaultVersion { case 1: + if v1 == nil { + return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert") + } defaultVersion = cert.Version1 case 2: defaultVersion = cert.Version2