Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine ca, cert, and key handling #952

Merged
merged 2 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 0 additions & 163 deletions cert.go

This file was deleted.

9 changes: 2 additions & 7 deletions cmd/nebula-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,8 @@ func main() {
}

ctrl, err := nebula.Main(c, *configTest, Build, l, nil)

switch v := err.(type) {
case util.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
if err != nil {
util.LogWithContextIfNeeded("Failed to start", err, l)
os.Exit(1)
}

Expand Down
9 changes: 2 additions & 7 deletions cmd/nebula/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,8 @@ func main() {
}

ctrl, err := nebula.Main(c, *configTest, Build, l, nil)

switch v := err.(type) {
case util.ContextualError:
v.Log(l)
os.Exit(1)
case error:
l.WithError(err).Error("Failed to start")
if err != nil {
util.LogWithContextIfNeeded("Failed to start", err, l)
os.Exit(1)
}

Expand Down
10 changes: 5 additions & 5 deletions connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,8 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
return false
}

certState := n.intf.certState.Load()
return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature)
certState := n.intf.pki.GetCertState()
return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature)
}

func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
Expand All @@ -427,7 +427,7 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
return false
}

valid, err := remoteCert.VerifyWithCache(now, n.intf.caPool)
valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
if valid {
return false
}
Expand Down Expand Up @@ -464,8 +464,8 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
}

func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.certState.Load()
if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) {
certState := n.intf.pki.GetCertState()
if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) {
return
}

Expand Down
35 changes: 19 additions & 16 deletions connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
// Very incomplete mock objects
hostMap := NewHostMap(l, vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
certificate: &cert.NebulaCertificate{},
rawCertificateNoKey: []byte{},
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
}

lh := newTestLighthouse()
Expand All @@ -57,10 +57,11 @@ func Test_NewConnectionManagerTest(t *testing.T) {
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.certState.Store(cs)
ifce.pki.cs.Store(cs)

// Create manager
ctx, cancel := context.WithCancel(context.Background())
Expand Down Expand Up @@ -123,10 +124,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Very incomplete mock objects
hostMap := NewHostMap(l, vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
certificate: &cert.NebulaCertificate{},
rawCertificateNoKey: []byte{},
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
}

lh := newTestLighthouse()
Expand All @@ -136,10 +137,11 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.certState.Store(cs)
ifce.pki.cs.Store(cs)

// Create manager
ctx, cancel := context.WithCancel(context.Background())
Expand Down Expand Up @@ -242,10 +244,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
peerCert.Sign(cert.Curve_CURVE25519, privCA)

cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
certificate: &cert.NebulaCertificate{},
rawCertificateNoKey: []byte{},
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
}

lh := newTestLighthouse()
Expand All @@ -258,9 +260,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
disconnectInvalid: true,
caPool: ncp,
pki: &PKI{},
}
ifce.certState.Store(cs)
ifce.pki.cs.Store(cs)
ifce.pki.caPool.Store(ncp)

// Create manager
ctx, cancel := context.WithCancel(context.Background())
Expand Down
8 changes: 4 additions & 4 deletions connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,23 @@ type ConnectionState struct {

func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
var dhFunc noise.DHFunc
curCertState := f.certState.Load()
curCertState := f.pki.GetCertState()

switch curCertState.certificate.Details.Curve {
switch curCertState.Certificate.Details.Curve {
case cert.Curve_CURVE25519:
dhFunc = noise.DH25519
case cert.Curve_P256:
dhFunc = noiseutil.DHP256
default:
l.Errorf("invalid curve: %s", curCertState.certificate.Details.Curve)
l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve)
return nil
}
cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
if f.cipher == "chachapoly" {
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
}

static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey}

b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
Expand Down
2 changes: 1 addition & 1 deletion control_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (c *Control) GetHostmap() *HostMap {
}

func (c *Control) GetCert() *cert.NebulaCertificate {
return c.f.certState.Load().certificate
return c.f.pki.GetCertState().Certificate
}

func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
Expand Down
8 changes: 4 additions & 4 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
hsProto := &NebulaHandshakeDetails{
InitiatorIndex: hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: ci.certState.rawCertificateNoKey,
Cert: ci.certState.RawCertificateNoKey,
}

hsBytes := []byte{}
Expand Down Expand Up @@ -91,7 +91,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
return
}

remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Expand Down Expand Up @@ -155,7 +155,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
Info("Handshake message received")

hs.Details.ResponderIndex = myIndex
hs.Details.Cert = ci.certState.rawCertificateNoKey
hs.Details.Cert = ci.certState.RawCertificateNoKey
// Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano())

Expand Down Expand Up @@ -399,7 +399,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
return true
}

remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Expand Down
Loading