Skip to content

Commit

Permalink
Resolve some todos (#1274)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus authored Nov 15, 2024
1 parent 5380fef commit 9d310e7
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 101 deletions.
23 changes: 12 additions & 11 deletions connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,17 +426,17 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out.

//TODO: current.vpnIp should become an array of vpnIps
// Only one side should swap because if both swap then we may never resolve to a single tunnel.
// vpn addr is static across all tunnels for this host pair so lets
// use that to determine if we should consider swapping.
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
// The remotes vpn ip is lower than mine. I will not flip.
// Their primary vpn addr is less than mine. Do not swap.
return false
}

//TODO: we should favor v2 over v1 certificates if configured to send them

crt := n.intf.pki.getCertificate(current.ConnectionState.myCert.Version())
crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
// settle down.
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
}

Expand Down Expand Up @@ -495,13 +495,14 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
}

func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
crt := n.intf.pki.getCertificate(hostinfo.ConnectionState.myCert.Version())
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), crt.Signature()) {
cs := n.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
myCrt := cs.getCertificate(curCrt.Version())
if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
// The current tunnel is using the latest certificate and version, no need to rehandshake.
return
}

//TODO: we should favor v2 over v1 certificates if configured to send them

n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
Expand Down
14 changes: 5 additions & 9 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
if found {
//TODO: we might have 2 certs....
//TODO: this should return our latest version cert
return c.f.pki.getDefaultCertificate().Copy()
// Only returning the default certificate since its impossible
// for any other host but ourselves to have more than 1
return c.f.pki.getCertState().GetDefaultCertificate().Copy()
}
hi := c.f.hostMap.QueryVpnAddr(vpnIp)
if hi == nil {
Expand Down Expand Up @@ -228,13 +228,9 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
// the int returned is a count of tunnels closed
func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
lighthouses := c.f.lightHouse.GetLighthouses()

shutdown := func(h *HostInfo) {
if excludeLighthouses {
if _, ok := lighthouses[h.vpnAddrs[0]]; ok {
return
}
if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
return
}
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h)
Expand Down
1 change: 0 additions & 1 deletion firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
)

type FirewallInterface interface {
//TODO: name these better addr, localAddr. Are they vpnAddrs?
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
}

Expand Down
2 changes: 1 addition & 1 deletion interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
f.firewall.EmitStats()
f.handshakeManager.EmitStats()
udpStats()
certExpirationGauge.Update(int64(f.pki.getDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second))
certExpirationGauge.Update(int64(f.pki.getCertState().GetDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second))
//TODO: we should also report the default certificate version
}
}
Expand Down
5 changes: 3 additions & 2 deletions overlay/tun_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,12 @@ func (t *winTun) Close() error {
luid := winipcfg.LUID(t.tun.LUID())
_ = luid.FlushRoutes(windows.AF_INET)
_ = luid.FlushIPAddresses(windows.AF_INET)
/* We don't support IPV6 yet

_ = luid.FlushRoutes(windows.AF_INET6)
_ = luid.FlushIPAddresses(windows.AF_INET6)
*/

_ = luid.FlushDNS(windows.AF_INET)
_ = luid.FlushDNS(windows.AF_INET6)

return t.tun.Close()
}
Expand Down
11 changes: 0 additions & 11 deletions pki.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,6 @@ func (p *PKI) getCertState() *CertState {
return p.cs.Load()
}

// TODO: We should remove this
func (p *PKI) getDefaultCertificate() cert.Certificate {
return p.cs.Load().GetDefaultCertificate()
}

// TODO: We should remove this
func (p *PKI) getCertificate(v cert.Version) cert.Certificate {
return p.cs.Load().getCertificate(v)
}

func (p *PKI) reload(c *config.C, initial bool) error {
err := p.reloadCerts(c, initial)
if err != nil {
Expand Down Expand Up @@ -300,7 +290,6 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
// Load the certificate
crt, rawCert, err = loadCertificate(rawCert)
if err != nil {
//TODO: check error
return nil, err
}

Expand Down
Loading

0 comments on commit 9d310e7

Please sign in to comment.