From bdfa347252d6c657c9af2f2331d92c287d3a9db4 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 18 Oct 2024 23:27:47 -0500 Subject: [PATCH] Break apart hostmap.go --- handshake_manager.go | 14 ++ hostinfo.go | 197 ++++++++++++++++++++++++ hostmap.go | 352 ------------------------------------------- relay_state.go | 154 +++++++++++++++++++ remote_list.go | 2 + 5 files changed, 367 insertions(+), 352 deletions(-) create mode 100644 hostinfo.go create mode 100644 relay_state.go diff --git a/handshake_manager.go b/handshake_manager.go index ee1545647..b5d37d74a 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -77,6 +77,20 @@ type HandshakeHostInfo struct { hostinfo *HostInfo } +type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte) + +type cachedPacket struct { + messageType header.MessageType + messageSubType header.MessageSubType + callback packetCallback + packet []byte +} + +type cachedPacketMetrics struct { + sent metrics.Counter + dropped metrics.Counter +} + func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { if len(hh.packetStore) < 100 { tempPacket := make([]byte, len(packet)) diff --git a/hostinfo.go b/hostinfo.go new file mode 100644 index 000000000..4f755a35a --- /dev/null +++ b/hostinfo.go @@ -0,0 +1,197 @@ +package nebula + +import ( + "net/netip" + "sync/atomic" + "time" + + "github.com/gaissmai/bart" + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/header" +) + +const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address +const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse +const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery +const maxRecvError = 4 + +// RoamingSuppressSeconds is how long we should prevent roaming back to the previous IP. +// This helps prevent flapping due to packets already in flight +const RoamingSuppressSeconds = 2 + +type HostInfo struct { + remote netip.AddrPort + remotes *RemoteList + promoteCounter atomic.Uint32 + ConnectionState *ConnectionState + remoteIndexId uint32 + localIndexId uint32 + vpnAddrs []netip.Addr + recvError atomic.Uint32 + + // networks are both all vpn and unsafe networks assigned to this host + networks *bart.Table[struct{}] + relayState RelayState + + // HandshakePacket records the packets used to create this hostinfo + // We need these to avoid replayed handshake packets creating new hostinfos which causes churn + HandshakePacket map[uint8][]byte + + // nextLHQuery is the earliest we can ask the lighthouse for new information. + // This is used to limit lighthouse re-queries in chatty clients + nextLHQuery atomic.Int64 + + // lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH + // for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like + // with a handshake + lastRebindCount int8 + + // lastHandshakeTime records the time the remote side told us about at the stage when the handshake was completed locally + // Stage 1 packet will contain it if I am a responder, stage 2 packet if I am an initiator + // This is used to avoid an attack where a handshake packet is replayed after some time + lastHandshakeTime uint64 + + lastRoam time.Time + lastRoamRemote netip.AddrPort + + // Used to track other hostinfos for this vpn ip since only 1 can be primary + // Synchronised via hostmap lock and not the hostinfo lock. + next, prev *HostInfo +} + +// TryPromoteBest handles re-querying lighthouses and probing for better paths +// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! +func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) { + c := i.promoteCounter.Add(1) + if c%ifce.tryPromoteEvery.Load() == 0 { + remote := i.remote + + // return early if we are already on a preferred remote + if remote.IsValid() { + rIP := remote.Addr() + for _, l := range preferredRanges { + if l.Contains(rIP) { + return + } + } + } + + i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) { + if remote.IsValid() && (!addr.IsValid() || !preferred) { + return + } + + // Try to send a test packet to that host, this should + // cause it to detect a roaming event and switch remotes + ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + }) + } + + // Re query our lighthouses for new remotes occasionally + if c%ifce.reQueryEvery.Load() == 0 && ifce.lightHouse != nil { + now := time.Now().UnixNano() + if now < i.nextLHQuery.Load() { + return + } + + i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) + ifce.lightHouse.QueryServer(i.vpnAddrs[0]) + } +} + +func (i *HostInfo) GetCert() *cert.CachedCertificate { + if i.ConnectionState != nil { + return i.ConnectionState.peerCert + } + return nil +} + +func (i *HostInfo) SetRemote(remote netip.AddrPort) { + // We copy here because we likely got this remote from a source that reuses the object + if i.remote != remote { + i.remote = remote + i.remotes.LearnRemote(i.vpnAddrs[0], remote) + } +} + +// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam +// time on the HostInfo will also be updated. +func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool { + if !newRemote.IsValid() { + // relays have nil udp Addrs + return false + } + currentRemote := i.remote + if !currentRemote.IsValid() { + i.SetRemote(newRemote) + return true + } + + // NOTE: We do this loop here instead of calling `isPreferred` in + // remote_list.go so that we only have to loop over preferredRanges once. + newIsPreferred := false + for _, l := range hm.GetPreferredRanges() { + // return early if we are already on a preferred remote + if l.Contains(currentRemote.Addr()) { + return false + } + + if l.Contains(newRemote.Addr()) { + newIsPreferred = true + } + } + + if newIsPreferred { + // Consider this a roaming event + i.lastRoam = time.Now() + i.lastRoamRemote = currentRemote + + i.SetRemote(newRemote) + + return true + } + + return false +} + +func (i *HostInfo) RecvErrorExceeded() bool { + if i.recvError.Add(1) >= maxRecvError { + return true + } + return true +} + +func (i *HostInfo) buildNetworks(c cert.Certificate) { + if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { + // Simple case, no CIDRTree needed + return + } + + i.networks = new(bart.Table[struct{}]) + for _, network := range c.Networks() { + i.networks.Insert(network, struct{}{}) + } + + for _, network := range c.UnsafeNetworks() { + i.networks.Insert(network, struct{}{}) + } +} + +func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { + if i == nil { + return logrus.NewEntry(l) + } + + li := l.WithField("vpnAddrs", i.vpnAddrs). + WithField("localIndex", i.localIndexId). + WithField("remoteIndex", i.remoteIndexId) + + if connState := i.ConnectionState; connState != nil { + if peerCert := connState.peerCert; peerCert != nil { + li = li.WithField("certName", peerCert.Certificate.Name()) + } + } + + return li +} diff --git a/hostmap.go b/hostmap.go index e3da64b05..6e5e79ab2 100644 --- a/hostmap.go +++ b/hostmap.go @@ -6,51 +6,16 @@ import ( "net/netip" "sync" "sync/atomic" - "time" - "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/header" ) -// const ProbeLen = 100 -const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address -const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse -const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery -const MaxRemotes = 10 -const maxRecvError = 4 - // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip // 5 allows for an initial handshake and each host pair re-handshaking twice const MaxHostInfosPerVpnIp = 5 -// How long we should prevent roaming back to the previous IP. -// This helps prevent flapping due to packets already in flight -const RoamingSuppressSeconds = 2 - -const ( - Requested = iota - PeerRequested - Established -) - -const ( - Unknowntype = iota - ForwardingType - TerminalType -) - -type Relay struct { - Type int - State int - LocalIndex uint32 - RemoteIndex uint32 - PeerAddr netip.Addr -} - type HostMap struct { sync.RWMutex //Because we concurrently read and write to our maps Indexes map[uint32]*HostInfo @@ -61,187 +26,6 @@ type HostMap struct { l *logrus.Logger } -// For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay -// struct, make a copy of an existing value, edit the fileds in the copy, and -// then store a pointer to the new copy in both realyForBy* maps. -type RelayState struct { - sync.RWMutex - - relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer - relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info - relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info -} - -func (rs *RelayState) DeleteRelay(ip netip.Addr) { - rs.Lock() - defer rs.Unlock() - delete(rs.relays, ip) -} - -func (rs *RelayState) CopyAllRelayFor() []*Relay { - rs.RLock() - defer rs.RUnlock() - ret := make([]*Relay, 0, len(rs.relayForByIdx)) - for _, r := range rs.relayForByIdx { - ret = append(ret, r) - } - return ret -} - -func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) { - rs.RLock() - defer rs.RUnlock() - r, ok := rs.relayForByAddr[addr] - return r, ok -} - -func (rs *RelayState) InsertRelayTo(ip netip.Addr) { - rs.Lock() - defer rs.Unlock() - rs.relays[ip] = struct{}{} -} - -func (rs *RelayState) CopyRelayIps() []netip.Addr { - rs.RLock() - defer rs.RUnlock() - ret := make([]netip.Addr, 0, len(rs.relays)) - for ip := range rs.relays { - ret = append(ret, ip) - } - return ret -} - -func (rs *RelayState) CopyRelayForIps() []netip.Addr { - rs.RLock() - defer rs.RUnlock() - currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr)) - for relayIp := range rs.relayForByAddr { - currentRelays = append(currentRelays, relayIp) - } - return currentRelays -} - -func (rs *RelayState) CopyRelayForIdxs() []uint32 { - rs.RLock() - defer rs.RUnlock() - ret := make([]uint32, 0, len(rs.relayForByIdx)) - for i := range rs.relayForByIdx { - ret = append(ret, i) - } - return ret -} - -func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { - rs.Lock() - defer rs.Unlock() - r, ok := rs.relayForByAddr[vpnIp] - if !ok { - return false - } - newRelay := *r - newRelay.State = Established - newRelay.RemoteIndex = remoteIdx - rs.relayForByIdx[r.LocalIndex] = &newRelay - rs.relayForByAddr[r.PeerAddr] = &newRelay - return true -} - -func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Relay, bool) { - rs.Lock() - defer rs.Unlock() - r, ok := rs.relayForByIdx[localIdx] - if !ok { - return nil, false - } - newRelay := *r - newRelay.State = Established - newRelay.RemoteIndex = remoteIdx - rs.relayForByIdx[r.LocalIndex] = &newRelay - rs.relayForByAddr[r.PeerAddr] = &newRelay - return &newRelay, true -} - -func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { - rs.RLock() - defer rs.RUnlock() - r, ok := rs.relayForByAddr[vpnIp] - return r, ok -} - -func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) { - rs.RLock() - defer rs.RUnlock() - r, ok := rs.relayForByIdx[idx] - return r, ok -} - -func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { - rs.Lock() - defer rs.Unlock() - rs.relayForByAddr[ip] = r - rs.relayForByIdx[idx] = r -} - -type HostInfo struct { - remote netip.AddrPort - remotes *RemoteList - promoteCounter atomic.Uint32 - ConnectionState *ConnectionState - remoteIndexId uint32 - localIndexId uint32 - vpnAddrs []netip.Addr - recvError atomic.Uint32 - - // networks are both all vpn and unsafe networks assigned to this host - networks *bart.Table[struct{}] - relayState RelayState - - // HandshakePacket records the packets used to create this hostinfo - // We need these to avoid replayed handshake packets creating new hostinfos which causes churn - HandshakePacket map[uint8][]byte - - // nextLHQuery is the earliest we can ask the lighthouse for new information. - // This is used to limit lighthouse re-queries in chatty clients - nextLHQuery atomic.Int64 - - // lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH - // for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like - // with a handshake - lastRebindCount int8 - - // lastHandshakeTime records the time the remote side told us about at the stage when the handshake was completed locally - // Stage 1 packet will contain it if I am a responder, stage 2 packet if I am an initiator - // This is used to avoid an attack where a handshake packet is replayed after some time - lastHandshakeTime uint64 - - lastRoam time.Time - lastRoamRemote netip.AddrPort - - // Used to track other hostinfos for this vpn ip since only 1 can be primary - // Synchronised via hostmap lock and not the hostinfo lock. - next, prev *HostInfo -} - -type ViaSender struct { - relayHI *HostInfo // relayHI is the host info object of the relay - remoteIdx uint32 // remoteIdx is the index included in the header of the received packet - relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us. -} - -type cachedPacket struct { - messageType header.MessageType - messageSubType header.MessageSubType - callback packetCallback - packet []byte -} - -type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte) - -type cachedPacketMetrics struct { - sent metrics.Counter - dropped metrics.Counter -} - func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { hm := newHostMap(l) @@ -556,142 +340,6 @@ func (hm *HostMap) ForEachIndex(f controlEach) { } } -// TryPromoteBest handles re-querying lighthouses and probing for better paths -// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! -func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) { - c := i.promoteCounter.Add(1) - if c%ifce.tryPromoteEvery.Load() == 0 { - remote := i.remote - - // return early if we are already on a preferred remote - if remote.IsValid() { - rIP := remote.Addr() - for _, l := range preferredRanges { - if l.Contains(rIP) { - return - } - } - } - - i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) { - if remote.IsValid() && (!addr.IsValid() || !preferred) { - return - } - - // Try to send a test packet to that host, this should - // cause it to detect a roaming event and switch remotes - ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - }) - } - - // Re query our lighthouses for new remotes occasionally - if c%ifce.reQueryEvery.Load() == 0 && ifce.lightHouse != nil { - now := time.Now().UnixNano() - if now < i.nextLHQuery.Load() { - return - } - - i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) - ifce.lightHouse.QueryServer(i.vpnAddrs[0]) - } -} - -func (i *HostInfo) GetCert() *cert.CachedCertificate { - if i.ConnectionState != nil { - return i.ConnectionState.peerCert - } - return nil -} - -func (i *HostInfo) SetRemote(remote netip.AddrPort) { - // We copy here because we likely got this remote from a source that reuses the object - if i.remote != remote { - i.remote = remote - i.remotes.LearnRemote(i.vpnAddrs[0], remote) - } -} - -// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam -// time on the HostInfo will also be updated. -func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool { - if !newRemote.IsValid() { - // relays have nil udp Addrs - return false - } - currentRemote := i.remote - if !currentRemote.IsValid() { - i.SetRemote(newRemote) - return true - } - - // NOTE: We do this loop here instead of calling `isPreferred` in - // remote_list.go so that we only have to loop over preferredRanges once. - newIsPreferred := false - for _, l := range hm.GetPreferredRanges() { - // return early if we are already on a preferred remote - if l.Contains(currentRemote.Addr()) { - return false - } - - if l.Contains(newRemote.Addr()) { - newIsPreferred = true - } - } - - if newIsPreferred { - // Consider this a roaming event - i.lastRoam = time.Now() - i.lastRoamRemote = currentRemote - - i.SetRemote(newRemote) - - return true - } - - return false -} - -func (i *HostInfo) RecvErrorExceeded() bool { - if i.recvError.Add(1) >= maxRecvError { - return true - } - return true -} - -func (i *HostInfo) buildNetworks(c cert.Certificate) { - if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { - // Simple case, no CIDRTree needed - return - } - - i.networks = new(bart.Table[struct{}]) - for _, network := range c.Networks() { - i.networks.Insert(network, struct{}{}) - } - - for _, network := range c.UnsafeNetworks() { - i.networks.Insert(network, struct{}{}) - } -} - -func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { - if i == nil { - return logrus.NewEntry(l) - } - - li := l.WithField("vpnAddrs", i.vpnAddrs). - WithField("localIndex", i.localIndexId). - WithField("remoteIndex", i.remoteIndexId) - - if connState := i.ConnectionState; connState != nil { - if peerCert := connState.peerCert; peerCert != nil { - li = li.WithField("certName", peerCert.Certificate.Name()) - } - } - - return li -} - // Utility functions func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { diff --git a/relay_state.go b/relay_state.go new file mode 100644 index 000000000..100c6e707 --- /dev/null +++ b/relay_state.go @@ -0,0 +1,154 @@ +package nebula + +import ( + "net/netip" + "sync" +) + +const ( + Requested = iota + PeerRequested + Established +) + +const ( + Unknowntype = iota + ForwardingType + TerminalType +) + +// RelayState describes an established relay between 3 parties +// for synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay +// struct, make a copy of an existing value, edit the fileds in the copy, and +// then store a pointer to the new copy in both realyForBy* maps. +type RelayState struct { + sync.RWMutex + + relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer + relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info + relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info +} + +type Relay struct { + Type int + State int + LocalIndex uint32 + RemoteIndex uint32 + PeerAddr netip.Addr +} + +type ViaSender struct { + relayHI *HostInfo // relayHI is the host info object of the relay + remoteIdx uint32 // remoteIdx is the index included in the header of the received packet + relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us. +} + +func (rs *RelayState) DeleteRelay(ip netip.Addr) { + rs.Lock() + defer rs.Unlock() + delete(rs.relays, ip) +} + +func (rs *RelayState) CopyAllRelayFor() []*Relay { + rs.RLock() + defer rs.RUnlock() + ret := make([]*Relay, 0, len(rs.relayForByIdx)) + for _, r := range rs.relayForByIdx { + ret = append(ret, r) + } + return ret +} + +func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) { + rs.RLock() + defer rs.RUnlock() + r, ok := rs.relayForByAddr[addr] + return r, ok +} + +func (rs *RelayState) InsertRelayTo(ip netip.Addr) { + rs.Lock() + defer rs.Unlock() + rs.relays[ip] = struct{}{} +} + +func (rs *RelayState) CopyRelayIps() []netip.Addr { + rs.RLock() + defer rs.RUnlock() + ret := make([]netip.Addr, 0, len(rs.relays)) + for ip := range rs.relays { + ret = append(ret, ip) + } + return ret +} + +func (rs *RelayState) CopyRelayForIps() []netip.Addr { + rs.RLock() + defer rs.RUnlock() + currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr)) + for relayIp := range rs.relayForByAddr { + currentRelays = append(currentRelays, relayIp) + } + return currentRelays +} + +func (rs *RelayState) CopyRelayForIdxs() []uint32 { + rs.RLock() + defer rs.RUnlock() + ret := make([]uint32, 0, len(rs.relayForByIdx)) + for i := range rs.relayForByIdx { + ret = append(ret, i) + } + return ret +} + +func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { + rs.Lock() + defer rs.Unlock() + r, ok := rs.relayForByAddr[vpnIp] + if !ok { + return false + } + newRelay := *r + newRelay.State = Established + newRelay.RemoteIndex = remoteIdx + rs.relayForByIdx[r.LocalIndex] = &newRelay + rs.relayForByAddr[r.PeerAddr] = &newRelay + return true +} + +func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Relay, bool) { + rs.Lock() + defer rs.Unlock() + r, ok := rs.relayForByIdx[localIdx] + if !ok { + return nil, false + } + newRelay := *r + newRelay.State = Established + newRelay.RemoteIndex = remoteIdx + rs.relayForByIdx[r.LocalIndex] = &newRelay + rs.relayForByAddr[r.PeerAddr] = &newRelay + return &newRelay, true +} + +func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { + rs.RLock() + defer rs.RUnlock() + r, ok := rs.relayForByAddr[vpnIp] + return r, ok +} + +func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) { + rs.RLock() + defer rs.RUnlock() + r, ok := rs.relayForByIdx[idx] + return r, ok +} + +func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { + rs.Lock() + defer rs.Unlock() + rs.relayForByAddr[ip] = r + rs.relayForByIdx[idx] = r +} diff --git a/remote_list.go b/remote_list.go index 4a9b50f13..dd7659f6b 100644 --- a/remote_list.go +++ b/remote_list.go @@ -13,6 +13,8 @@ import ( "github.com/sirupsen/logrus" ) +const MaxRemotes = 10 + // forEachFunc is used to benefit folks that want to do work inside the lock type forEachFunc func(addr netip.AddrPort, preferred bool)