Skip to content

Commit

Permalink
use curve info from handshake for v2 cert validation
Browse files Browse the repository at this point in the history
  • Loading branch information
JackDoanRivian committed Sep 20, 2024
1 parent 8bf2f2d commit e59be7f
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 21 deletions.
15 changes: 10 additions & 5 deletions cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ type CachedCertificate struct {
func UnmarshalCertificate(b []byte) (Certificate, error) {
//TODO: you left off here, no one uses this function but it might be beneficial to export _something_ that someone can use, maybe the Versioned unmarshallsers?
var c Certificate
c, err := unmarshalCertificateV2(b, nil)
c, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
if err == nil {
return c, nil
}
Expand All @@ -129,15 +129,15 @@ func UnmarshalCertificate(b []byte) (Certificate, error) {
// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake.
// Handshakes save space by placing the peers public key in a different part of the packet, we have to
// reassemble the actual certificate structure with that in mind.
func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (Certificate, error) {
func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
var c Certificate
var err error

switch v {
case VersionPre1, Version1:
c, err = unmarshalCertificateV1(b, publicKey)
case Version2:
c, err = unmarshalCertificateV2(b, publicKey)
c, err = unmarshalCertificateV2(b, publicKey, curve)
default:
//TODO: make a static var
return nil, fmt.Errorf("unknown certificate version %d", v)
Expand All @@ -146,10 +146,15 @@ func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (C
if err != nil {
return nil, err
}

if c.Curve() != curve {
return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String())
}

return c, nil
}

func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, caPool *CAPool) (*CachedCertificate, error) {
func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
if publicKey == nil {
return nil, ErrNoPeerStaticKey
}
Expand All @@ -158,7 +163,7 @@ func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, caPool *CAP
return nil, ErrNoPayload
}

c, err := UnmarshalCertificateFromHandshake(v, rawCertBytes, publicKey)
c, err := UnmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
}
Expand Down
18 changes: 5 additions & 13 deletions cert/cert_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,16 +231,7 @@ func (c *certificateV2) MarshalForHandshakes() ([]byte, error) {
//TODO: panic on nil rawDetails
b.AddBytes(c.rawDetails)

// Skipping public key since those come across in a different part of the handshake

//todo is curve skippable? I don't think so?

// Add the curve only if its not the default value
if c.curve != Curve_CURVE25519 {
b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) {
b.AddBytes([]byte{byte(c.curve)})
})
}
// Skipping the curve and since those come across in a different part of the handshake

// Add the signature
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
Expand Down Expand Up @@ -464,7 +455,7 @@ func (d *detailsV2) Marshal() ([]byte, error) {
return b.Bytes()
}

func unmarshalCertificateV2(b []byte, publicKey []byte) (*certificateV2, error) {
func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
l := len(b)
if l == 0 || l > MaxCertificateSize {
return nil, ErrBadFormat
Expand All @@ -482,11 +473,12 @@ func unmarshalCertificateV2(b []byte, publicKey []byte) (*certificateV2, error)
return nil, ErrBadFormat
}

//Maybe grab the curve
var rawCurve byte
if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(Curve_CURVE25519)) {
if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
return nil, ErrBadFormat
}
curve := Curve(rawCurve)
curve = Curve(rawCurve)

// Maybe grab the public key
var rawPublicKey cryptobyte.String
Expand Down
2 changes: 1 addition & 1 deletion cert/pem.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
case CertificateBanner:
c, err = unmarshalCertificateV1(p.Bytes, nil)
case CertificateV2Banner:
c, err = unmarshalCertificateV2(p.Bytes, nil)
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
default:
return nil, r, ErrInvalidPEMCertificateBanner
}
Expand Down
4 changes: 4 additions & 0 deletions connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
"message_counter": cs.messageCounter.Load(),
})
}

func (cs *ConnectionState) Curve() cert.Curve {
return cs.myCert.Curve()
}
4 changes: 2 additions & 2 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return
}

remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool())
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil {
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
Expand Down Expand Up @@ -404,7 +404,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
return true
}

remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool())
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil {
e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
Expand Down

0 comments on commit e59be7f

Please sign in to comment.