Skip to content

Commit

Permalink
finish test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
JackDoanRivian committed Nov 1, 2024
1 parent ef58b33 commit fc3dfde
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 42 deletions.
34 changes: 22 additions & 12 deletions cert/ca_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,13 @@ func TestCertificateV1_Verify(t *testing.T) {
}

func TestCertificateV1_VerifyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false)
c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false)
testCertificateV1VerifyP256(t, false)
testCertificateV1VerifyP256(t, true)
}

func testCertificateV1VerifyP256(t *testing.T, compressKey bool) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, compressKey)
c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, compressKey)

caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca))
Expand All @@ -177,11 +182,11 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
assert.EqualError(t, err, "root certificate is expired")

assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, false)
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, compressKey)
})

// Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}, false)
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}, compressKey)
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)

Expand All @@ -191,10 +196,10 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
assert.Empty(t, b)

assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}, false)
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}, compressKey)
})

c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}, false)
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}, compressKey)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
Expand Down Expand Up @@ -380,8 +385,13 @@ func TestCertificateV2_Verify(t *testing.T) {
}

func TestCertificateV2_VerifyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false)
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false)
testCertificateV2VerifyP256(t, false)
testCertificateV2VerifyP256(t, true)
}

func testCertificateV2VerifyP256(t *testing.T, compressKey bool) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, compressKey)
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, compressKey)

caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca))
Expand All @@ -401,11 +411,11 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
assert.EqualError(t, err, "root certificate is expired")

assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, false)
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, compressKey)
})

// Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}, false)
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}, compressKey)
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)

Expand All @@ -415,10 +425,10 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
assert.Empty(t, b)

assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}, false)
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}, compressKey)
})

c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}, false)
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}, compressKey)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
Expand Down
21 changes: 17 additions & 4 deletions cert/cert_v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,24 +131,29 @@ func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
}

func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, false)
testCertificateV1VerifyPrivateKeyP256(t, false)
testCertificateV1VerifyPrivateKeyP256(t, true)
}

func testCertificateV1VerifyPrivateKeyP256(t *testing.T, compressKey bool) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, compressKey)
err := ca.VerifyPrivateKey(Curve_P256, caKey)
assert.Nil(t, err)

_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, false)
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, compressKey)
assert.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
assert.NotNil(t, err)

c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, false)
c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, compressKey)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err)
assert.Empty(t, b)
assert.Equal(t, Curve_P256, curve)
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
assert.Nil(t, err)

_, priv2 := P256Keypair()
_, priv2 := P256Keypair(compressKey)
err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.NotNil(t, err)
}
Expand Down Expand Up @@ -195,6 +200,14 @@ func TestCertificateV1_Copy(t *testing.T) {
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false)
cc := c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false)
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false)
cc = c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, true)
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, true)
cc = c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
}

func TestUnmarshalCertificateV1(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion cert/cert_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func testCertificateV2VerifyPrivateKeyP256(t *testing.T, compressKey bool) {
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
assert.Nil(t, err)

_, priv2 := P256Keypair()
_, priv2 := P256Keypair(compressKey)
err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.NotNil(t, err)
}
Expand Down
33 changes: 18 additions & 15 deletions cert/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,26 @@ import (
"crypto/rand"
"io"
"net/netip"
"testing"
"time"

"github.com/slackhq/nebula/noiseutil"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
)

//todo test compress actually is different
func Test_NewTestCaCert(t *testing.T) {
c, _, priv, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(time.Hour), nil, nil, nil, false)
assert.Len(t, c.PublicKey(), 65)
c, _, priv, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(time.Hour), nil, nil, nil, true)
assert.Len(t, c.PublicKey(), 33)

cc, _, _, _ := NewTestCert(Version2, Curve_P256, c, priv, "uncompressed", time.Now(), time.Now().Add(time.Hour), nil, nil, nil, false)
assert.Len(t, cc.PublicKey(), 65)
cc, _, _, _ = NewTestCert(Version2, Curve_P256, c, priv, "compressed", time.Now(), time.Now().Add(time.Hour), nil, nil, nil, true)
assert.Len(t, cc.PublicKey(), 33)
}

// NewTestCaCert will create a new ca certificate
func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string, compressKey bool) (Certificate, []byte, []byte, []byte) {
Expand Down Expand Up @@ -88,11 +100,7 @@ func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string
case Curve_CURVE25519:
pub, priv = X25519Keypair()
case Curve_P256:
if compressKey {
pub, priv = P256KeypairCompressed()
} else {
pub, priv = P256Keypair()
}
pub, priv = P256Keypair(compressKey)
default:
panic("unknown curve")
}
Expand Down Expand Up @@ -137,19 +145,14 @@ func X25519Keypair() ([]byte, []byte) {
return pubkey, privkey
}

func P256Keypair() ([]byte, []byte) {
func P256Keypair(compressed bool) ([]byte, []byte) {
privkey, err := ecdh.P256().GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
pubkey := privkey.PublicKey()
return pubkey.Bytes(), privkey.Bytes()
}

func P256KeypairCompressed() ([]byte, []byte) {
privkey, err := ecdh.P256().GenerateKey(rand.Reader)
if err != nil {
panic(err)
if !compressed {
pubkey := privkey.PublicKey()
return pubkey.Bytes(), privkey.Bytes()
}
pubkeyBytes := privkey.PublicKey().Bytes()
pubkey, err := noiseutil.LoadP256Pubkey(pubkeyBytes)
Expand Down
29 changes: 21 additions & 8 deletions cert_test/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ import (
"time"

"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/noiseutil"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
)

// NewTestCaCert will create a new ca certificate
func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string, compressKey bool) (cert.Certificate, []byte, []byte, []byte) {
var err error
var pub, priv []byte

Expand All @@ -27,8 +28,11 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti
if err != nil {
panic(err)
}

pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
if compressKey {
pub = elliptic.MarshalCompressed(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
} else {
pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
}
priv = privk.D.FillBytes(make([]byte, 32))
default:
// There is no default to allow the underlying lib to respond with an error
Expand Down Expand Up @@ -69,7 +73,7 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti

// NewTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in
func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string, compressKey bool) (cert.Certificate, []byte, []byte, []byte) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
Expand All @@ -83,7 +87,7 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
case cert.Curve_CURVE25519:
pub, priv = X25519Keypair()
case cert.Curve_P256:
pub, priv = P256Keypair()
pub, priv = P256Keypair(compressKey)
default:
panic("unknown curve")
}
Expand Down Expand Up @@ -128,11 +132,20 @@ func X25519Keypair() ([]byte, []byte) {
return pubkey, privkey
}

func P256Keypair() ([]byte, []byte) {
func P256Keypair(compressed bool) ([]byte, []byte) {
privkey, err := ecdh.P256().GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
pubkey := privkey.PublicKey()
return pubkey.Bytes(), privkey.Bytes()
if !compressed {
pubkey := privkey.PublicKey()
return pubkey.Bytes(), privkey.Bytes()
}
pubkeyBytes := privkey.PublicKey().Bytes()
pubkey, err := noiseutil.LoadP256Pubkey(pubkeyBytes)
if err != nil {
panic(err)
}
out := elliptic.MarshalCompressed(elliptic.P256(), pubkey.X, pubkey.Y)
return out, privkey.Bytes()
}
4 changes: 2 additions & 2 deletions service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
type m map[string]interface{}

func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
_, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
_, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}, false)
caB, err := caCrt.MarshalPEM()
if err != nil {
panic(err)
Expand Down Expand Up @@ -79,7 +79,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
}

func TestService(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, false)
a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{
"static_host_map": m{},
"lighthouse": m{
Expand Down

0 comments on commit fc3dfde

Please sign in to comment.