Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cert-v2] support compressed P256 public keys #1262

Draft
wants to merge 4 commits into
base: cert-v2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 70 additions & 60 deletions cert/ca_pool_test.go

Large diffs are not rendered by default.

71 changes: 24 additions & 47 deletions cert/cert_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ package cert

import (
"bytes"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
Expand All @@ -16,12 +14,11 @@ import (
"net/netip"
"time"

"github.com/slackhq/nebula/noiseutil"
"golang.org/x/crypto/curve25519"
"google.golang.org/protobuf/proto"
)

const publicKeyLen = 32

type certificateV1 struct {
details detailsV1
signature []byte
Expand Down Expand Up @@ -110,8 +107,10 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
pubKey, err := noiseutil.LoadP256Pubkey(key)
if err != nil {
return false
}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:
Expand All @@ -127,54 +126,32 @@ func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
if curve != c.details.curve {
return fmt.Errorf("curve in cert and private key supplied don't match")
}
if curve == Curve_P256 {
return verifyP256PrivateKey(key, c.details.publicKey)
} else if curve != Curve_CURVE25519 {
return fmt.Errorf("invalid curve: %s", curve)
}

if c.details.isCA {
switch curve {
case Curve_CURVE25519:
// the call to PublicKey below will panic slice bounds out of range otherwise
if len(key) != ed25519.PrivateKeySize {
return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
}

if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return fmt.Errorf("cannot parse private key as P256: %w", err)
}
pub := privkey.PublicKey().Bytes()
if !bytes.Equal(pub, c.details.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
default:
return fmt.Errorf("invalid curve: %s", curve)
// the call to PublicKey below will panic slice bounds out of range otherwise
if len(key) != ed25519.PrivateKeySize {
return ErrInvalidPrivateKey
}
return nil
}

var pub []byte
switch curve {
case Curve_CURVE25519:
var err error
pub, err = curve25519.X25519(key, curve25519.Basepoint)
if err != nil {
return err
if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
return ErrPublicPrivateKeyMismatch
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
return nil
} else {
pub, err := curve25519.X25519(key, curve25519.Basepoint)
if err != nil {
return err
return ErrInvalidPrivateKey
}
pub = privkey.PublicKey().Bytes()
default:
return fmt.Errorf("invalid curve: %s", curve)
}
if !bytes.Equal(pub, c.details.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match")
if !bytes.Equal(pub, c.details.publicKey) {
return ErrPublicPrivateKeyMismatch
}
return nil
}

return nil
}

// getRawDetails marshals the raw details into protobuf ready struct
Expand Down
31 changes: 22 additions & 9 deletions cert/cert_v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,16 @@ func TestCertificateV1_MarshalJSON(t *testing.T) {
}

func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, false)
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.Nil(t, err)

_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, false)
assert.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
assert.NotNil(t, err)

c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, false)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err)
assert.Empty(t, b)
Expand All @@ -131,24 +131,29 @@ func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
}

func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
testCertificateV1VerifyPrivateKeyP256(t, false)
testCertificateV1VerifyPrivateKeyP256(t, true)
}

func testCertificateV1VerifyPrivateKeyP256(t *testing.T, compressKey bool) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, compressKey)
err := ca.VerifyPrivateKey(Curve_P256, caKey)
assert.Nil(t, err)

_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, compressKey)
assert.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
assert.NotNil(t, err)

c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, compressKey)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err)
assert.Empty(t, b)
assert.Equal(t, Curve_P256, curve)
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
assert.Nil(t, err)

_, priv2 := P256Keypair()
_, priv2 := P256Keypair(compressKey)
err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.NotNil(t, err)
}
Expand Down Expand Up @@ -191,10 +196,18 @@ func TestMarshalingCertificateV1Consistency(t *testing.T) {
}

func TestCertificateV1_Copy(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false)
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false)
cc := c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false)
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false)
cc = c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, true)
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, true)
cc = c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
}

func TestUnmarshalCertificateV1(t *testing.T) {
Expand Down
65 changes: 22 additions & 43 deletions cert/cert_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ package cert

import (
"bytes"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/sha256"
"encoding/hex"
"encoding/json"
Expand All @@ -15,6 +13,7 @@ import (
"slices"
"time"

"github.com/slackhq/nebula/noiseutil"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
"golang.org/x/crypto/curve25519"
Expand Down Expand Up @@ -149,8 +148,10 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
pubKey, err := noiseutil.LoadP256Pubkey(key)
if err != nil {
return false
}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:
Expand All @@ -166,54 +167,32 @@ func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
if curve != c.curve {
return ErrPublicPrivateCurveMismatch
}
if curve == Curve_P256 {
return verifyP256PrivateKey(key, c.publicKey)
} else if curve != Curve_CURVE25519 {
return fmt.Errorf("invalid curve: %s", curve)
}

if c.details.isCA {
switch curve {
case Curve_CURVE25519:
// the call to PublicKey below will panic slice bounds out of range otherwise
if len(key) != ed25519.PrivateKeySize {
return ErrInvalidPrivateKey
}
// the call to PublicKey below will panic slice bounds out of range otherwise
if len(key) != ed25519.PrivateKeySize {
return ErrInvalidPrivateKey
}

if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
return ErrPublicPrivateKeyMismatch
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return ErrInvalidPrivateKey
}
pub := privkey.PublicKey().Bytes()
if !bytes.Equal(pub, c.publicKey) {
return ErrPublicPrivateKeyMismatch
}
default:
return fmt.Errorf("invalid curve: %s", curve)
if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
return ErrPublicPrivateKeyMismatch
}
return nil
}

var pub []byte
switch curve {
case Curve_CURVE25519:
var err error
pub, err = curve25519.X25519(key, curve25519.Basepoint)
} else {
pub, err := curve25519.X25519(key, curve25519.Basepoint)
if err != nil {
return ErrInvalidPrivateKey
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return ErrInvalidPrivateKey
if !bytes.Equal(pub, c.publicKey) {
return ErrPublicPrivateKeyMismatch
}
pub = privkey.PublicKey().Bytes()
default:
return fmt.Errorf("invalid curve: %s", curve)
}
if !bytes.Equal(pub, c.publicKey) {
return ErrPublicPrivateKeyMismatch
return nil
}

return nil
}

func (c *certificateV2) String() string {
Expand Down
33 changes: 22 additions & 11 deletions cert/cert_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
}

func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, false)
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.Nil(t, err)

Expand All @@ -142,7 +142,7 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)

c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, false)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err)
assert.Empty(t, b)
Expand All @@ -166,14 +166,14 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
err = c.VerifyPrivateKey(Curve(99), priv2)
assert.EqualError(t, err, "invalid curve: 99")

ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, false)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.Nil(t, err)

err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey)

c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil, false)
rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)

err = c.VerifyPrivateKey(Curve_P256, priv[:16])
Expand All @@ -189,35 +189,46 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
assert.EqualError(t, err, "invalid curve: 99")

}

func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
testCertificateV2VerifyPrivateKeyP256(t, false)
testCertificateV2VerifyPrivateKeyP256(t, true)
}
func testCertificateV2VerifyPrivateKeyP256(t *testing.T, compressKey bool) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, compressKey)
err := ca.VerifyPrivateKey(Curve_P256, caKey)
assert.Nil(t, err)

_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, compressKey)
assert.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
assert.NotNil(t, err)

c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, compressKey)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err)
assert.Empty(t, b)
assert.Equal(t, Curve_P256, curve)
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
assert.Nil(t, err)

_, priv2 := P256Keypair()
_, priv2 := P256Keypair(compressKey)
err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.NotNil(t, err)
}

func TestCertificateV2_Copy(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false)
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false)
cc := c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false)
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false)
cc = c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, true)
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, true)
cc = c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
}

func TestUnmarshalCertificateV2(t *testing.T) {
Expand Down
Loading
Loading