Skip to content

Commit

Permalink
Limit how often a busy tunnel can requery the lighthouse
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus committed Jul 31, 2023
1 parent 38e56a4 commit 102c56f
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 4 deletions.
10 changes: 10 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io/ioutil"
"math"
"os"
"os/signal"
"path/filepath"
Expand Down Expand Up @@ -236,6 +237,15 @@ func (c *C) GetInt(k string, d int) int {
return v
}

// GetUint32 reads the integer stored under k, passing d to GetInt as the
// fallback, and returns it as a uint32. If the integer that comes back is
// negative or larger than math.MaxUint32 it cannot be represented as a uint32,
// so d is returned instead.
func (c *C) GetUint32(k string, d uint32) uint32 {
	v := c.GetInt(k, int(d))

	// Reject anything outside [0, math.MaxUint32] rather than silently truncating.
	if v < 0 || uint64(v) > math.MaxUint32 {
		return d
	}

	return uint32(v)
}

// GetBool will get the bool for k or return the default d if not found or invalid
func (c *C) GetBool(k string, d bool) bool {
r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
Expand Down
19 changes: 15 additions & 4 deletions hostmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ import (
)

// const ProbeLen = 100
const PromoteEvery = 1000
const ReQueryEvery = 5000
const defaultPromoteEvery = 1000 // Number of packets sent between attempts to move a tunnel to a preferred underlay IP address
const defaultReQueryEvery = 5000 // Number of packets sent between lighthouse re-queries for a hostinfo
const defaultReQueryWait = 1 * 60 // Minimum number of seconds to wait between lighthouse re-queries for a hostinfo. Only evaluated once every reQueryEvery packets
const MaxRemotes = 10

// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
Expand Down Expand Up @@ -215,6 +216,10 @@ type HostInfo struct {
remoteCidr *cidr.Tree4
relayState RelayState

// nextLHQuery is the earliest we can ask the lighthouse for new information.
// This is used to limit lighthouse re-queries in chatty clients
nextLHQuery atomic.Int64

// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
// with a handshake
Expand Down Expand Up @@ -535,7 +540,7 @@ func (hm *HostMap) ForEachIndex(f controlEach) {
// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
c := i.promoteCounter.Add(1)
if c%PromoteEvery == 0 {
if c%ifce.tryPromoteEvery.Load() == 0 {
// The lock here is currently protecting i.remote access
i.RLock()
remote := i.remote
Expand Down Expand Up @@ -563,7 +568,13 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
}

// Re query our lighthouses for new remotes occasionally
if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
if c%ifce.reQueryEvery.Load() == 0 && ifce.lightHouse != nil {
now := time.Now().Unix()
if now < i.nextLHQuery.Load() {
return
}

i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
ifce.lightHouse.QueryServer(i.vpnIp, ifce)
}
}
Expand Down
33 changes: 33 additions & 0 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ type InterfaceConfig struct {
relayManager *relayManager
punchy *Punchy

tryPromoteEvery uint32
reQueryEvery uint32
reQueryWait int64

ConntrackCacheTimeout time.Duration
l *logrus.Logger
}
Expand All @@ -72,6 +76,10 @@ type Interface struct {
closed atomic.Bool
relayManager *relayManager

tryPromoteEvery atomic.Uint32
reQueryEvery atomic.Uint32
reQueryWait atomic.Int64

sendRecvErrorConfig sendRecvErrorConfig

// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
Expand Down Expand Up @@ -186,6 +194,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l,
}

ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(c.reQueryWait)

ifce.certState.Store(c.certState)
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)

Expand Down Expand Up @@ -287,6 +299,7 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
c.RegisterReloadCallback(f.reloadCertKey)
c.RegisterReloadCallback(f.reloadFirewall)
c.RegisterReloadCallback(f.reloadSendRecvError)
c.RegisterReloadCallback(f.reloadMisc)
for _, udpConn := range f.writers {
c.RegisterReloadCallback(udpConn.ReloadConfig)
}
Expand Down Expand Up @@ -389,6 +402,26 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
}
}

// reloadMisc is a config reload callback that applies changed packet counter
// and timer settings to the running interface. Each setting is only touched
// when the config reports that its key has changed.
//
// The two counters are used as modulo divisors when deciding whether to
// promote a tunnel or re-query the lighthouse, so a value of 0 would cause an
// integer divide-by-zero panic on the packet path. A reloaded 0 is therefore
// replaced with the compiled-in default and a warning is logged.
func (f *Interface) reloadMisc(c *config.C) {
	if c.HasChanged("counters.try_promote") {
		n := c.GetUint32("counters.try_promote", defaultPromoteEvery)
		if n == 0 {
			// 0 is not a usable divisor, keep the tunnel promotion logic alive with the default
			f.l.Warn("counters.try_promote must be greater than 0, using the default")
			n = defaultPromoteEvery
		}
		f.tryPromoteEvery.Store(n)
		f.l.Info("counters.try_promote has changed")
	}

	if c.HasChanged("counters.requery_every_packets") {
		n := c.GetUint32("counters.requery_every_packets", defaultReQueryEvery)
		if n == 0 {
			// 0 is not a usable divisor, keep the lighthouse re-query logic alive with the default
			f.l.Warn("counters.requery_every_packets must be greater than 0, using the default")
			n = defaultReQueryEvery
		}
		f.reQueryEvery.Store(n)
		f.l.Info("counters.requery_every_packets has changed")
	}

	if c.HasChanged("timers.requery_wait_seconds") {
		n := c.GetInt("timers.requery_wait_seconds", defaultReQueryWait)
		f.reQueryWait.Store(int64(n))
		f.l.Info("timers.requery_wait_seconds has changed")
	}
}

func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
ticker := time.NewTicker(i)
defer ticker.Stop()
Expand Down
4 changes: 4 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg

checkInterval := c.GetInt("timers.connection_alive_interval", 5)
pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)

ifConfig := &InterfaceConfig{
HostMap: hostMap,
Inside: tun,
Expand All @@ -282,6 +283,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
lightHouse: lightHouse,
checkInterval: time.Second * time.Duration(checkInterval),
pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery),
reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
reQueryWait: int64(c.GetInt("timers.requery_wait_seconds", defaultReQueryWait)),
DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
DropMulticast: c.GetBool("tun.drop_multicast", false),
routines: routines,
Expand Down

0 comments on commit 102c56f

Please sign in to comment.