From 100492904f5db7c1f80bbfae3e94d52ca52b286c Mon Sep 17 00:00:00 2001 From: maggie44 <64841595+maggie44@users.noreply.github.com> Date: Sat, 16 Nov 2024 12:31:03 +0000 Subject: [PATCH] Add iterators for fetching hostmaps --- control.go | 45 +++++++++++++++++++++++++ control_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+) diff --git a/control.go b/control.go index 26159845c..ad4a806e3 100644 --- a/control.go +++ b/control.go @@ -2,6 +2,7 @@ package nebula import ( "context" + "iter" "net/netip" "os" "os/signal" @@ -120,6 +121,15 @@ func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo { } } +// ListHostmapHostsIter returns an iter with details about the actual or pending (handshaking) hostmap by vpn ip +func (c *Control) ListHostmapHostsIter(pendingMap bool) iter.Seq[*ControlHostInfo] { + if pendingMap { + return listHostMapHostsIter(c.f.handshakeManager) + } else { + return listHostMapHostsIter(c.f.hostMap) + } +} + // ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { if pendingMap { @@ -129,6 +139,15 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { } } +// ListHostmapIndexesIter returns an iter with details about the actual or pending (handshaking) hostmap by local index id +func (c *Control) ListHostmapIndexesIter(pendingMap bool) iter.Seq[*ControlHostInfo] { + if pendingMap { + return listHostMapIndexesIter(c.f.handshakeManager) + } else { + return listHostMapIndexesIter(c.f.hostMap) + } +} + // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { if c.f.myVpnNet.Addr() == vpnIp { @@ -305,6 +324,19 @@ func listHostMapHosts(hl controlHostLister) []ControlHostInfo { return hosts } +func listHostMapHostsIter(hl controlHostLister) iter.Seq[*ControlHostInfo] { + pr := hl.GetPreferredRanges() + + return iter.Seq[*ControlHostInfo](func(yield func(*ControlHostInfo) bool) { + hl.ForEachVpnIp(func(hostinfo *HostInfo) { + host := copyHostInfo(hostinfo, pr) + if !yield(&host) { // Pass a pointer to host here + return // Stop iteration early if yield returns false + } + }) + }) +} + func listHostMapIndexes(hl controlHostLister) []ControlHostInfo { hosts := make([]ControlHostInfo, 0) pr := hl.GetPreferredRanges() @@ -313,3 +345,16 @@ func listHostMapIndexes(hl controlHostLister) []ControlHostInfo { }) return hosts } + +func listHostMapIndexesIter(hl controlHostLister) iter.Seq[*ControlHostInfo] { + pr := hl.GetPreferredRanges() + + return iter.Seq[*ControlHostInfo](func(yield func(*ControlHostInfo) bool) { + hl.ForEachIndex(func(hostinfo *HostInfo) { + host := copyHostInfo(hostinfo, pr) + if !yield(&host) { + return // Stop iteration early if yield returns false + } + }) + }) +} diff --git a/control_test.go b/control_test.go index fdfc0a57e..1cb661f95 100644 --- a/control_test.go +++ b/control_test.go @@ -110,6 +110,94 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }) } +func TestListHostMapHostsIter(t *testing.T) { + l := logrus.New() + hm := newHostMap(l, netip.Prefix{}) + hm.preferredRanges.Store(&[]netip.Prefix{}) + + hosts := []struct { + vpnIp netip.Addr + remoteAddr netip.AddrPort + localIndexId uint32 + remoteIndexId uint32 + }{ + {vpnIp: netip.MustParseAddr("0.0.0.2"), remoteAddr: netip.MustParseAddrPort("0.0.0.101:4445"), localIndexId: 202, remoteIndexId: 201}, + {vpnIp: netip.MustParseAddr("0.0.0.3"), remoteAddr: netip.MustParseAddrPort("0.0.0.102:4446"), localIndexId: 203, remoteIndexId: 202}, + {vpnIp: netip.MustParseAddr("0.0.0.4"), remoteAddr: netip.MustParseAddrPort("0.0.0.103:4447"), localIndexId: 204, remoteIndexId: 203}, + } + + for _, h := range hosts { + hm.unlockedAddHostInfo(&HostInfo{ + remote: h.remoteAddr, + ConnectionState: &ConnectionState{ + peerCert: nil, + }, + localIndexId: h.localIndexId, + remoteIndexId: h.remoteIndexId, + vpnIp: h.vpnIp, + }, &Interface{}) + } + + iter := listHostMapHostsIter(hm) + var results []ControlHostInfo + + for h := range iter { + results = append(results, *h) + } + + assert.Equal(t, len(hosts), len(results), "expected number of hosts in iterator") + for i, h := range hosts { + assert.Equal(t, h.vpnIp, results[i].VpnIp) + assert.Equal(t, h.localIndexId, results[i].LocalIndex) + assert.Equal(t, h.remoteIndexId, results[i].RemoteIndex) + assert.Equal(t, h.remoteAddr, results[i].CurrentRemote) + } +} + +func TestListHostMapIndexesIter(t *testing.T) { + l := logrus.New() + hm := newHostMap(l, netip.Prefix{}) + hm.preferredRanges.Store(&[]netip.Prefix{}) + + hosts := []struct { + vpnIp netip.Addr + remoteAddr netip.AddrPort + localIndexId uint32 + remoteIndexId uint32 + }{ + {vpnIp: netip.MustParseAddr("0.0.0.2"), remoteAddr: netip.MustParseAddrPort("0.0.0.101:4445"), localIndexId: 202, remoteIndexId: 201}, + {vpnIp: netip.MustParseAddr("0.0.0.3"), remoteAddr: netip.MustParseAddrPort("0.0.0.102:4446"), localIndexId: 203, remoteIndexId: 202}, + {vpnIp: netip.MustParseAddr("0.0.0.4"), remoteAddr: netip.MustParseAddrPort("0.0.0.103:4447"), localIndexId: 204, remoteIndexId: 203}, + } + + for _, h := range hosts { + hm.unlockedAddHostInfo(&HostInfo{ + remote: h.remoteAddr, + ConnectionState: &ConnectionState{ + peerCert: nil, + }, + localIndexId: h.localIndexId, + remoteIndexId: h.remoteIndexId, + vpnIp: h.vpnIp, + }, &Interface{}) + } + + iter := listHostMapIndexesIter(hm) + var results []ControlHostInfo + + for h := range iter { + results = append(results, *h) + } + + assert.Equal(t, len(hosts), len(results), "expected number of hosts in iterator") + for i, h := range hosts { + assert.Equal(t, h.vpnIp, results[i].VpnIp) + assert.Equal(t, h.localIndexId, results[i].LocalIndex) + assert.Equal(t, h.remoteIndexId, results[i].RemoteIndex) + assert.Equal(t, h.remoteAddr, results[i].CurrentRemote) + } +} + func assertFields(t *testing.T, expected []string, actualStruct interface{}) { val := reflect.ValueOf(actualStruct).Elem() fields := make([]string, val.NumField())