Skip to content

Commit

Permalink
Add iterators for fetching hostmaps
Browse files Browse the repository at this point in the history
  • Loading branch information
maggie44 committed Nov 16, 2024
1 parent 8bbcd07 commit 7f6affe
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 0 deletions.
45 changes: 45 additions & 0 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package nebula

import (
"context"
"iter"
"net/netip"
"os"
"os/signal"
Expand Down Expand Up @@ -120,6 +121,15 @@ func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
}
}

// ListHostmapHostsIter returns an iter with details about the actual or pending (handshaking) hostmap by vpn ip
func (c *Control) ListHostmapHostsIter(pendingMap bool) iter.Seq[*ControlHostInfo] {
if pendingMap {
return listHostMapHostsIter(c.f.handshakeManager)
} else {
return listHostMapHostsIter(c.f.hostMap)
}
}

// ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id
func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
if pendingMap {
Expand All @@ -129,6 +139,15 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
}
}

// ListHostmapIndexesIter returns an iter with details about the actual or pending (handshaking) hostmap by local index id
func (c *Control) ListHostmapIndexesIter(pendingMap bool) iter.Seq[*ControlHostInfo] {
if pendingMap {
return listHostMapIndexesIter(c.f.handshakeManager)
} else {
return listHostMapIndexesIter(c.f.hostMap)
}
}

// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
if c.f.myVpnNet.Addr() == vpnIp {
Expand Down Expand Up @@ -305,6 +324,19 @@ func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
return hosts
}

func listHostMapHostsIter(hl controlHostLister) iter.Seq[*ControlHostInfo] {
pr := hl.GetPreferredRanges()

return iter.Seq[*ControlHostInfo](func(yield func(*ControlHostInfo) bool) {
hl.ForEachVpnIp(func(hostinfo *HostInfo) {
host := copyHostInfo(hostinfo, pr)
if !yield(&host) { // Pass a pointer to host here
return // Stop iteration early if yield returns false
}
})
})
}

func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges()
Expand All @@ -313,3 +345,16 @@ func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
})
return hosts
}

func listHostMapIndexesIter(hl controlHostLister) iter.Seq[*ControlHostInfo] {
pr := hl.GetPreferredRanges()

return iter.Seq[*ControlHostInfo](func(yield func(*ControlHostInfo) bool) {
hl.ForEachIndex(func(hostinfo *HostInfo) {
host := copyHostInfo(hostinfo, pr)
if !yield(&host) {
return // Stop iteration early if yield returns false
}
})
})
}
88 changes: 88 additions & 0 deletions control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,94 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
})
}

func TestListHostMapHostsIter(t *testing.T) {
l := logrus.New()
hm := newHostMap(l, netip.Prefix{})
hm.preferredRanges.Store(&[]netip.Prefix{})

hosts := []struct {
vpnIp netip.Addr
remoteAddr netip.AddrPort
localIndexId uint32
remoteIndexId uint32
}{
{vpnIp: netip.MustParseAddr("0.0.0.2"), remoteAddr: netip.MustParseAddrPort("0.0.0.101:4445"), localIndexId: 202, remoteIndexId: 201},
{vpnIp: netip.MustParseAddr("0.0.0.3"), remoteAddr: netip.MustParseAddrPort("0.0.0.102:4446"), localIndexId: 203, remoteIndexId: 202},
{vpnIp: netip.MustParseAddr("0.0.0.4"), remoteAddr: netip.MustParseAddrPort("0.0.0.103:4447"), localIndexId: 204, remoteIndexId: 203},
}

for _, h := range hosts {
hm.unlockedAddHostInfo(&HostInfo{
remote: h.remoteAddr,
ConnectionState: &ConnectionState{
peerCert: nil,
},
localIndexId: h.localIndexId,
remoteIndexId: h.remoteIndexId,
vpnIp: h.vpnIp,
}, &Interface{})
}

iter := listHostMapHostsIter(hm)
var results []ControlHostInfo

for h := range iter {
results = append(results, *h)
}

assert.Equal(t, len(hosts), len(results), "expected number of hosts in iterator")
for i, h := range hosts {
assert.Equal(t, h.vpnIp, results[i].VpnIp)
assert.Equal(t, h.localIndexId, results[i].LocalIndex)
assert.Equal(t, h.remoteIndexId, results[i].RemoteIndex)
assert.Equal(t, h.remoteAddr, results[i].CurrentRemote)
}
}

func TestListHostMapIndexesIter(t *testing.T) {
l := logrus.New()
hm := newHostMap(l, netip.Prefix{})
hm.preferredRanges.Store(&[]netip.Prefix{})

hosts := []struct {
vpnIp netip.Addr
remoteAddr netip.AddrPort
localIndexId uint32
remoteIndexId uint32
}{
{vpnIp: netip.MustParseAddr("0.0.0.2"), remoteAddr: netip.MustParseAddrPort("0.0.0.101:4445"), localIndexId: 202, remoteIndexId: 201},
{vpnIp: netip.MustParseAddr("0.0.0.3"), remoteAddr: netip.MustParseAddrPort("0.0.0.102:4446"), localIndexId: 203, remoteIndexId: 202},
{vpnIp: netip.MustParseAddr("0.0.0.4"), remoteAddr: netip.MustParseAddrPort("0.0.0.103:4447"), localIndexId: 204, remoteIndexId: 203},
}

for _, h := range hosts {
hm.unlockedAddHostInfo(&HostInfo{
remote: h.remoteAddr,
ConnectionState: &ConnectionState{
peerCert: nil,
},
localIndexId: h.localIndexId,
remoteIndexId: h.remoteIndexId,
vpnIp: h.vpnIp,
}, &Interface{})
}

iter := listHostMapIndexesIter(hm)
var results []ControlHostInfo

for h := range iter {
results = append(results, *h)
}

assert.Equal(t, len(hosts), len(results), "expected number of hosts in iterator")
for i, h := range hosts {
assert.Equal(t, h.vpnIp, results[i].VpnIp)
assert.Equal(t, h.localIndexId, results[i].LocalIndex)
assert.Equal(t, h.remoteIndexId, results[i].RemoteIndex)
assert.Equal(t, h.remoteAddr, results[i].CurrentRemote)
}
}

func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
val := reflect.ValueOf(actualStruct).Elem()
fields := make([]string, val.NumField())
Expand Down

0 comments on commit 7f6affe

Please sign in to comment.