Skip to content

Commit

Permalink
Combine ca, cert, and key handling
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus committed Aug 8, 2023
1 parent 223cc6e commit 7293838
Show file tree
Hide file tree
Showing 16 changed files with 331 additions and 287 deletions.
163 changes: 0 additions & 163 deletions cert.go

This file was deleted.

17 changes: 9 additions & 8 deletions cmd/nebula-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,15 @@ func main() {
}

ctrl, err := nebula.Main(c, *configTest, Build, l, nil)

switch v := err.(type) {
case util.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
if err != nil {
switch v := err.(type) {
case *util.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}
}

if !*configTest {
Expand Down
17 changes: 9 additions & 8 deletions cmd/nebula/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,15 @@ func main() {
}

ctrl, err := nebula.Main(c, *configTest, Build, l, nil)

switch v := err.(type) {
case util.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
if err != nil {
switch v := err.(type) {
case *util.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
os.Exit(1)
}
}

if !*configTest {
Expand Down
10 changes: 5 additions & 5 deletions connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,8 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
return false
}

certState := n.intf.certState.Load()
return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature)
certState := n.intf.pki.GetCertState()
return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature)
}

func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
Expand All @@ -427,7 +427,7 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
return false
}

valid, err := remoteCert.VerifyWithCache(now, n.intf.caPool)
valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
if valid {
return false
}
Expand Down Expand Up @@ -464,8 +464,8 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
}

func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.certState.Load()
if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) {
certState := n.intf.pki.GetCertState()
if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) {
return
}

Expand Down
35 changes: 19 additions & 16 deletions connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
// Very incomplete mock objects
hostMap := NewHostMap(l, vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
certificate: &cert.NebulaCertificate{},
rawCertificateNoKey: []byte{},
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
}

lh := newTestLighthouse()
Expand All @@ -57,10 +57,11 @@ func Test_NewConnectionManagerTest(t *testing.T) {
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.certState.Store(cs)
ifce.pki.cs.Store(cs)

// Create manager
ctx, cancel := context.WithCancel(context.Background())
Expand Down Expand Up @@ -123,10 +124,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Very incomplete mock objects
hostMap := NewHostMap(l, vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
certificate: &cert.NebulaCertificate{},
rawCertificateNoKey: []byte{},
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
}

lh := newTestLighthouse()
Expand All @@ -136,10 +137,11 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.certState.Store(cs)
ifce.pki.cs.Store(cs)

// Create manager
ctx, cancel := context.WithCancel(context.Background())
Expand Down Expand Up @@ -242,10 +244,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
peerCert.Sign(cert.Curve_CURVE25519, privCA)

cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
certificate: &cert.NebulaCertificate{},
rawCertificateNoKey: []byte{},
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
}

lh := newTestLighthouse()
Expand All @@ -258,9 +260,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
disconnectInvalid: true,
caPool: ncp,
pki: &PKI{},
}
ifce.certState.Store(cs)
ifce.pki.cs.Store(cs)
ifce.pki.caPool.Store(ncp)

// Create manager
ctx, cancel := context.WithCancel(context.Background())
Expand Down
8 changes: 4 additions & 4 deletions connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,23 @@ type ConnectionState struct {

func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
var dhFunc noise.DHFunc
curCertState := f.certState.Load()
curCertState := f.pki.GetCertState()

switch curCertState.certificate.Details.Curve {
switch curCertState.Certificate.Details.Curve {
case cert.Curve_CURVE25519:
dhFunc = noise.DH25519
case cert.Curve_P256:
dhFunc = noiseutil.DHP256
default:
l.Errorf("invalid curve: %s", curCertState.certificate.Details.Curve)
l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve)
return nil
}
cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
if f.cipher == "chachapoly" {
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
}

static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey}

b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
Expand Down
2 changes: 1 addition & 1 deletion control_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (c *Control) GetHostmap() *HostMap {
}

func (c *Control) GetCert() *cert.NebulaCertificate {
return c.f.certState.Load().certificate
return c.f.pki.GetCertState().Certificate
}

func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
Expand Down
8 changes: 4 additions & 4 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
hsProto := &NebulaHandshakeDetails{
InitiatorIndex: hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: ci.certState.rawCertificateNoKey,
Cert: ci.certState.RawCertificateNoKey,
}

hsBytes := []byte{}
Expand Down Expand Up @@ -91,7 +91,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
return
}

remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Expand Down Expand Up @@ -155,7 +155,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
Info("Handshake message received")

hs.Details.ResponderIndex = myIndex
hs.Details.Cert = ci.certState.rawCertificateNoKey
hs.Details.Cert = ci.certState.RawCertificateNoKey
// Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano())

Expand Down Expand Up @@ -399,7 +399,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
return true
}

remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Expand Down
Loading

0 comments on commit 7293838

Please sign in to comment.