Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add iter function for fetching hostmaps #1275

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions cert/cert_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,14 +334,8 @@ func (nc *certificateV1) Copy() Certificate {
copy(c.signature, nc.signature)
copy(c.details.Groups, nc.details.Groups)
copy(c.details.PublicKey, nc.details.PublicKey)

for i, p := range nc.details.Ips {
c.details.Ips[i] = p
}

for i, p := range nc.details.Subnets {
c.details.Subnets[i] = p
}
copy(c.details.Ips, nc.details.Ips)
copy(c.details.Subnets, nc.details.Subnets)

return c
}
Expand Down
45 changes: 45 additions & 0 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package nebula

import (
"context"
"iter"
"net/netip"
"os"
"os/signal"
Expand Down Expand Up @@ -120,6 +121,15 @@ func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
}
}

// ListHostmapHostsIter returns an iter with details about the actual or pending (handshaking) hostmap by vpn ip
func (c *Control) ListHostmapHostsIter(pendingMap bool) iter.Seq[*ControlHostInfo] {
if pendingMap {
return listHostMapHostsIter(c.f.handshakeManager)
} else {
return listHostMapHostsIter(c.f.hostMap)
}
}

// ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id
func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
if pendingMap {
Expand All @@ -129,6 +139,15 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
}
}

// ListHostmapIndexesIter returns an iter with details about the actual or pending (handshaking) hostmap by local index id
func (c *Control) ListHostmapIndexesIter(pendingMap bool) iter.Seq[*ControlHostInfo] {
if pendingMap {
return listHostMapIndexesIter(c.f.handshakeManager)
} else {
return listHostMapIndexesIter(c.f.hostMap)
}
}

// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
if c.f.myVpnNet.Addr() == vpnIp {
Expand Down Expand Up @@ -305,6 +324,19 @@ func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
return hosts
}

func listHostMapHostsIter(hl controlHostLister) iter.Seq[*ControlHostInfo] {
pr := hl.GetPreferredRanges()

return iter.Seq[*ControlHostInfo](func(yield func(*ControlHostInfo) bool) {
hl.ForEachVpnIp(func(hostinfo *HostInfo) {
host := copyHostInfo(hostinfo, pr)
if !yield(&host) {
return // Stop iteration early if yield returns false
}
})
})
}

func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges()
Expand All @@ -313,3 +345,16 @@ func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
})
return hosts
}

func listHostMapIndexesIter(hl controlHostLister) iter.Seq[*ControlHostInfo] {
pr := hl.GetPreferredRanges()

return iter.Seq[*ControlHostInfo](func(yield func(*ControlHostInfo) bool) {
hl.ForEachIndex(func(hostinfo *HostInfo) {
host := copyHostInfo(hostinfo, pr)
if !yield(&host) {
return // Stop iteration early if yield returns false
}
})
})
}
88 changes: 88 additions & 0 deletions control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,94 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
})
}

func TestListHostMapHostsIter(t *testing.T) {
l := logrus.New()
hm := newHostMap(l, netip.Prefix{})
hm.preferredRanges.Store(&[]netip.Prefix{})

hosts := []struct {
vpnIp netip.Addr
remoteAddr netip.AddrPort
localIndexId uint32
remoteIndexId uint32
}{
{vpnIp: netip.MustParseAddr("0.0.0.2"), remoteAddr: netip.MustParseAddrPort("0.0.0.101:4445"), localIndexId: 202, remoteIndexId: 201},
{vpnIp: netip.MustParseAddr("0.0.0.3"), remoteAddr: netip.MustParseAddrPort("0.0.0.102:4446"), localIndexId: 203, remoteIndexId: 202},
{vpnIp: netip.MustParseAddr("0.0.0.4"), remoteAddr: netip.MustParseAddrPort("0.0.0.103:4447"), localIndexId: 204, remoteIndexId: 203},
}

for _, h := range hosts {
hm.unlockedAddHostInfo(&HostInfo{
remote: h.remoteAddr,
ConnectionState: &ConnectionState{
peerCert: nil,
},
localIndexId: h.localIndexId,
remoteIndexId: h.remoteIndexId,
vpnIp: h.vpnIp,
}, &Interface{})
}

iter := listHostMapHostsIter(hm)
var results []ControlHostInfo

for h := range iter {
results = append(results, *h)
}

assert.Equal(t, len(hosts), len(results), "expected number of hosts in iterator")
for i, h := range hosts {
assert.Equal(t, h.vpnIp, results[i].VpnIp)
assert.Equal(t, h.localIndexId, results[i].LocalIndex)
assert.Equal(t, h.remoteIndexId, results[i].RemoteIndex)
assert.Equal(t, h.remoteAddr, results[i].CurrentRemote)
}
}

func TestListHostMapIndexesIter(t *testing.T) {
l := logrus.New()
hm := newHostMap(l, netip.Prefix{})
hm.preferredRanges.Store(&[]netip.Prefix{})

hosts := []struct {
vpnIp netip.Addr
remoteAddr netip.AddrPort
localIndexId uint32
remoteIndexId uint32
}{
{vpnIp: netip.MustParseAddr("0.0.0.2"), remoteAddr: netip.MustParseAddrPort("0.0.0.101:4445"), localIndexId: 202, remoteIndexId: 201},
{vpnIp: netip.MustParseAddr("0.0.0.3"), remoteAddr: netip.MustParseAddrPort("0.0.0.102:4446"), localIndexId: 203, remoteIndexId: 202},
{vpnIp: netip.MustParseAddr("0.0.0.4"), remoteAddr: netip.MustParseAddrPort("0.0.0.103:4447"), localIndexId: 204, remoteIndexId: 203},
}

for _, h := range hosts {
hm.unlockedAddHostInfo(&HostInfo{
remote: h.remoteAddr,
ConnectionState: &ConnectionState{
peerCert: nil,
},
localIndexId: h.localIndexId,
remoteIndexId: h.remoteIndexId,
vpnIp: h.vpnIp,
}, &Interface{})
}

iter := listHostMapIndexesIter(hm)
var results []ControlHostInfo

for h := range iter {
results = append(results, *h)
}

assert.Equal(t, len(hosts), len(results), "expected number of hosts in iterator")
for i, h := range hosts {
assert.Equal(t, h.vpnIp, results[i].VpnIp)
assert.Equal(t, h.localIndexId, results[i].LocalIndex)
assert.Equal(t, h.remoteIndexId, results[i].RemoteIndex)
assert.Equal(t, h.remoteAddr, results[i].CurrentRemote)
}
}

func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
val := reflect.ValueOf(actualStruct).Elem()
fields := make([]string, val.NumField())
Expand Down