Skip to content

Commit

Permalink
PSK support for v2
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus committed Nov 14, 2024
1 parent 5380fef commit 1487548
Show file tree
Hide file tree
Showing 9 changed files with 451 additions and 33 deletions.
15 changes: 7 additions & 8 deletions connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type ConnectionState struct {
writeLock sync.Mutex
}

func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern, psk []byte) (*ConnectionState, error) {
var dhFunc noise.DHFunc
switch crt.Curve() {
case cert.Curve_CURVE25519:
Expand Down Expand Up @@ -56,13 +56,12 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
b.Update(l, 0)

hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: ncs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
//NOTE: These should come from CertState (pki.go) when we finally implement it
PresharedKey: []byte{},
CipherSuite: ncs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
PresharedKey: psk,
PresharedKeyPlacement: 0,
})
if err != nil {
Expand Down
132 changes: 132 additions & 0 deletions e2e/handshakes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,138 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
theirControl.Stop()
}

func TestPSK(t *testing.T) {
tests := []struct {
name string
myPskMode nebula.PskMode
theirPskMode nebula.PskMode
}{
// All accepting
{
name: "both accepting",
myPskMode: nebula.PskAccepting,
theirPskMode: nebula.PskAccepting,
},

// accepting and sending both ways
{
name: "accepting to sending",
myPskMode: nebula.PskAccepting,
theirPskMode: nebula.PskSending,
},
{
name: "sending to accepting",
myPskMode: nebula.PskSending,
theirPskMode: nebula.PskAccepting,
},

// All sending
{
name: "sending to sending",
myPskMode: nebula.PskSending,
theirPskMode: nebula.PskSending,
},

// enforced and sending both ways
{
name: "enforced to sending",
myPskMode: nebula.PskEnforced,
theirPskMode: nebula.PskSending,
},
{
name: "sending to enforced",
myPskMode: nebula.PskSending,
theirPskMode: nebula.PskEnforced,
},

// All enforced
{
name: "both enforced",
myPskMode: nebula.PskEnforced,
theirPskMode: nebula.PskEnforced,
},

// Enforced can technically handshake with an accepting node, but it is bad to be in this state
{
name: "enforced to accepting",
myPskMode: nebula.PskEnforced,
theirPskMode: nebula.PskAccepting,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var myPskSettings, theirPskSettings m

switch test.myPskMode {
case nebula.PskAccepting:
myPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage0", "this is a key"}}}
case nebula.PskSending:
myPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage1"}}}
case nebula.PskEnforced:
myPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage2"}}}
}

switch test.theirPskMode {
case nebula.PskAccepting:
theirPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage3", "this is a key"}}}
case nebula.PskSending:
theirPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage4"}}}
case nebula.PskEnforced:
theirPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage5"}}}
}

ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
myControl, myVpnIp, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.0.0.1/24", myPskSettings)
theirControl, theirVpnIp, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.0.0.2/24", theirPskSettings)

myControl.InjectLightHouseAddr(theirVpnIp[0].Addr(), theirUdpAddr)
r := router.NewR(t, myControl, theirControl)

// Start the servers
myControl.Start()
theirControl.Start()

t.Log("Route until we see our cached packet flow")
myControl.InjectTunUDPPacket(theirVpnIp[0].Addr(), 80, myVpnIp[0].Addr(), 80, []byte("Hi from me"))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
h := &header.H{}
err := h.Parse(p.Data)
if err != nil {
panic(err)
}

// If this is the stage 1 handshake packet and I am configured to send with a psk, my cert name should
// not appear. It would likely be more obvious to unmarshal the payload and check but this works fine for now
if test.myPskMode == nebula.PskEnforced || test.myPskMode == nebula.PskSending {
if h.Type == 0 && h.MessageCounter == 1 {
assert.NotContains(t, string(p.Data), "test me")
}
}

if p.To == theirUdpAddr && h.Type == 1 {
return router.RouteAndExit
}

return router.KeepRouting
})

t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), 80, 80)

t.Log("Test the tunnel with them")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
assertTunnel(t, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), myControl, theirControl, r)

myControl.Stop()
theirControl.Stop()
//TODO: assert hostmaps
})
}

}

//TODO: test
// Race winner renews and handshakes
// Race loser renews and handshakes
Expand Down
7 changes: 3 additions & 4 deletions e2e/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,6 @@ type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType
func NewR(t testing.TB, controls ...*nebula.Control) *R {
ctx, cancel := context.WithCancel(context.Background())

if err := os.MkdirAll("mermaid", 0755); err != nil {
panic(err)
}

r := &R{
controls: make(map[netip.AddrPort]*nebula.Control),
vpnControls: make(map[netip.Addr]*nebula.Control),
Expand Down Expand Up @@ -194,6 +190,9 @@ func (r *R) renderFlow() {
return
}

if err := os.MkdirAll(filepath.Dir(r.fn), 0755); err != nil {
panic(err)
}
f, err := os.OpenFile(r.fn, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0644)
if err != nil {
panic(err)
Expand Down
33 changes: 32 additions & 1 deletion examples/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,38 @@ pki:
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
# default_version: 1

# psk can be used to mask the contents of handshakes.
psk:
# `mode` defines how the pre shared keys can be used in a handshake.
# `accepting` (the default) will initiate handshakes using an empty key and will try to use any keys provided when
# receiving handshakes, including an empty key.
# `sending` will initiate handshakes with the first key provided and will try to use any keys provided when
# receiving handshakes, including an empty key.
# `enforced` will initiate handshakes with the first psk key provided and will try to use any keys provided when
# responding to handshakes. An empty key will not be allowed.
#
# To change a mesh from not using a psk to enforcing psk:
# 1. Leave `mode` as `accepting` and configure `psk.keys` to match on all nodes in the mesh and reload.
# 2. Change `mode` to `sending` on all nodes in the mesh and reload.
# 3. Change `mode` to `enforced` on all nodes in the mesh and reload.
#mode: accepting

# The keys provided are sent through hkdf to ensure the shared secret used in the noise protocol is the
# correct byte length.
#
# Only the first key is used for outbound handshakes but all keys provided will be tried in the order specified, on
# incoming handshakes. This is to allow for psk rotation.
#
# To rotate a primary key:
# 1. Put the new key in the 2nd slot on every node in the mesh and reload.
# 2. Move the key from the 2nd slot to the 1st slot, the old primary key is now in the 2nd slot, reload.
# 3. Remove the old primary key once it is no longer in use on every node in the mesh and reload.
#keys:
# - shared secret string, this one is used in all outbound handshakes # This is the primary key used when sending handshakes
# - this is a fallback key, received handshakes can use this
# - another fallback, received handshakes can use this one too
# - "\x68\x65\x6c\x6c\x6f\x20\x66\x72\x69\x65\x6e\x64\x73" # for raw bytes if you desire

# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
# The syntax is:
Expand Down Expand Up @@ -309,7 +341,6 @@ logging:
# after receiving the response for lighthouse queries
#trigger_buffer: 64


# Nebula security group configuration
firewall:
# Action to take when a packet is not allowed by the firewall rules.
Expand Down
59 changes: 39 additions & 20 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
Error("Unable to handshake with host because no certificate handshake bytes is available")
}

ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX, cs.psk.primary)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
Expand Down Expand Up @@ -104,34 +104,53 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Error("Unable to handshake with host because no certificate is available")
}

ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
return
}
var (
err error
ci *ConnectionState
msg []byte
)

// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
hs := &NebulaHandshake{}

msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
return
for _, psk := range cs.psk.keys {
ci, err = NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX, psk)
if err != nil {
//TODO: should be bother logging this, if we have multiple psks and the error is unrelated it will be verbose.
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
continue
}

msg, _, _, err = ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
// Calls to ReadMessage with an incorrect psk should fail, try the next one if we have one
continue
}

// Sometimes ReadMessage returns fine with a nil psk even if the handshake is using a psk, ensure our protobuf
// comes out clean as well
err = hs.Unmarshal(msg)
if err == nil {
// There was no error, we can continue with this handshake
break
}

// The unmarshal failed, try the next psk if we have one
}

hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
// We finished with an error, log it and get out
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr).
// We aren't logging the error here because we can't be sure of the failure when using psk
f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
Error("Was unable to decrypt the handshake")
return
}

// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)

remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil {
e := f.l.WithError(err).WithField("udpAddr", addr).
Expand Down
5 changes: 5 additions & 0 deletions handshake_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_NewHandshakeManagerVpnIp(t *testing.T) {
Expand All @@ -23,11 +24,15 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {

lh := newTestLighthouse()

psk, err := NewPsk(PskAccepting, nil)
require.NoError(t, err)

cs := &CertState{
defaultVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
psk: psk,
}

blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
Expand Down
12 changes: 12 additions & 0 deletions pki.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type CertState struct {
pkcs11Backed bool
cipher string

psk *Psk

myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr
Expand Down Expand Up @@ -181,6 +183,16 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
}
}

psk, err := NewPskFromConfig(c)
if err != nil {
return util.NewContextualError("Failed to load psk from config", nil, err)
}
if len(psk.keys) > 0 {
p.l.WithField("pskMode", psk.mode).WithField("keysLen", len(psk.keys)).
Info("pre shared keys are in use")
}
newState.psk = psk

p.cs.Store(newState)

//TODO: newState needs a stringer that does json
Expand Down
Loading

0 comments on commit 1487548

Please sign in to comment.