diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index f03b2ba83..27f2767a0 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -111,8 +111,8 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe } func TestCertificateV1_Verify(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) - c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false) + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false) caPool := NewCAPool() assert.NoError(t, caPool.AddCA(ca)) @@ -132,11 +132,11 @@ func TestCertificateV1_Verify(t *testing.T) { assert.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { - NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil, false) }) // Test group assertion - ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}, false) caPem, err := ca.MarshalPEM() assert.Nil(t, err) @@ -146,18 +146,23 @@ func TestCertificateV1_Verify(t *testing.T) { assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { - NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}, false) }) - c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } func TestCertificateV1_VerifyP256(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) - c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + testCertificateV1VerifyP256(t, false) + testCertificateV1VerifyP256(t, true) +} + +func testCertificateV1VerifyP256(t *testing.T, compressKey bool) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, compressKey) + c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, compressKey) caPool := NewCAPool() assert.NoError(t, caPool.AddCA(ca)) @@ -177,11 +182,11 @@ func TestCertificateV1_VerifyP256(t *testing.T) { assert.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { - NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, compressKey) }) // Test group assertion - ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}, compressKey) caPem, err := ca.MarshalPEM() assert.Nil(t, err) @@ -191,10 +196,10 @@ func TestCertificateV1_VerifyP256(t *testing.T) { assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { - NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}, compressKey) }) - c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}, compressKey) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } @@ -202,7 +207,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { func TestCertificateV1_Verify_IPs(t *testing.T) { caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") - ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}, false) caPem, err := ca.MarshalPEM() assert.Nil(t, err) @@ -216,51 +221,51 @@ func TestCertificateV1_Verify_IPs(t *testing.T) { cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { - NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}, false) }) // ip is outside the network reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { - NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}, false) }) // ip is within the network but mask is outside cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { - NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}, false) }) // ip is within the network but mask is outside reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { - NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}, false) }) // ip and mask are within the network cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") - c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}, false) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches - c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed - c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed with just 1 - c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) @@ -269,7 +274,7 @@ func TestCertificateV1_Verify_IPs(t *testing.T) { func TestCertificateV1_Verify_Subnets(t *testing.T) { caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") - ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}, false) caPem, err := ca.MarshalPEM() assert.Nil(t, err) @@ -283,60 +288,60 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) { cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { - NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}, false) }) // ip is outside the network reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { - NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}, false) }) // ip is within the network but mask is outside cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { - NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}, false) }) // ip is within the network but mask is outside reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { - NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}, false) }) // ip and mask are within the network cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") - c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches - c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed - c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed with just 1 - c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } func TestCertificateV2_Verify(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) - c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false) + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false) caPool := NewCAPool() assert.NoError(t, caPool.AddCA(ca)) @@ -356,11 +361,11 @@ func TestCertificateV2_Verify(t *testing.T) { assert.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { - NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil, false) }) // Test group assertion - ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}, false) caPem, err := ca.MarshalPEM() assert.Nil(t, err) @@ -370,18 +375,23 @@ func TestCertificateV2_Verify(t *testing.T) { assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { - NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}, false) }) - c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } func TestCertificateV2_VerifyP256(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) - c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + testCertificateV2VerifyP256(t, false) + testCertificateV2VerifyP256(t, true) +} + +func testCertificateV2VerifyP256(t *testing.T, compressKey bool) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, compressKey) + c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, compressKey) caPool := NewCAPool() assert.NoError(t, caPool.AddCA(ca)) @@ -401,11 +411,11 @@ func TestCertificateV2_VerifyP256(t *testing.T) { assert.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { - NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, compressKey) }) // Test group assertion - ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}, compressKey) caPem, err := ca.MarshalPEM() assert.Nil(t, err) @@ -415,10 +425,10 @@ func TestCertificateV2_VerifyP256(t *testing.T) { assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { - NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}, compressKey) }) - c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}, compressKey) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } @@ -426,7 +436,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { func TestCertificateV2_Verify_IPs(t *testing.T) { caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") - ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}, false) caPem, err := ca.MarshalPEM() assert.Nil(t, err) @@ -440,51 +450,51 @@ func TestCertificateV2_Verify_IPs(t *testing.T) { cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { - NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}, false) }) // ip is outside the network reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { - NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}, false) }) // ip is within the network but mask is outside cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { - NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}, false) }) // ip is within the network but mask is outside reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { - NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}, false) }) // ip and mask are within the network cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") - c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}, false) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches - c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed - c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed with just 1 - c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) @@ -493,7 +503,7 @@ func TestCertificateV2_Verify_IPs(t *testing.T) { func TestCertificateV2_Verify_Subnets(t *testing.T) { caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") - ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}, false) caPem, err := ca.MarshalPEM() assert.Nil(t, err) @@ -507,52 +517,52 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { - NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}, false) }) // ip is outside the network reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { - NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}, false) }) // ip is within the network but mask is outside cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { - NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}, false) }) // ip is within the network but mask is outside reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { - NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}, false) }) // ip and mask are within the network cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") - c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches - c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed - c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed with just 1 - c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}, false) assert.Nil(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) diff --git a/cert/cert_v1.go b/cert/cert_v1.go index b807f8d21..44f23cc51 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -2,10 +2,8 @@ package cert import ( "bytes" - "crypto/ecdh" "crypto/ecdsa" "crypto/ed25519" - "crypto/elliptic" "crypto/sha256" "encoding/binary" "encoding/hex" @@ -16,12 +14,11 @@ import ( "net/netip" "time" + "github.com/slackhq/nebula/noiseutil" "golang.org/x/crypto/curve25519" "google.golang.org/protobuf/proto" ) -const publicKeyLen = 32 - type certificateV1 struct { details detailsV1 signature []byte @@ -110,8 +107,10 @@ func (c *certificateV1) CheckSignature(key []byte) bool { case Curve_CURVE25519: return ed25519.Verify(key, b, c.signature) case Curve_P256: - x, y := elliptic.Unmarshal(elliptic.P256(), key) - pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + pubKey, err := noiseutil.LoadP256Pubkey(key) + if err != nil { + return false + } hashed := sha256.Sum256(b) return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) default: @@ -127,54 +126,32 @@ func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { if curve != c.details.curve { return fmt.Errorf("curve in cert and private key supplied don't match") } + if curve == Curve_P256 { + return verifyP256PrivateKey(key, c.details.publicKey) + } else if curve != Curve_CURVE25519 { + return fmt.Errorf("invalid curve: %s", curve) + } + if c.details.isCA { - switch curve { - case Curve_CURVE25519: - // the call to PublicKey below will panic slice bounds out of range otherwise - if len(key) != ed25519.PrivateKeySize { - return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") - } - - if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) { - return fmt.Errorf("public key in cert and private key supplied don't match") - } - case Curve_P256: - privkey, err := ecdh.P256().NewPrivateKey(key) - if err != nil { - return fmt.Errorf("cannot parse private key as P256: %w", err) - } - pub := privkey.PublicKey().Bytes() - if !bytes.Equal(pub, c.details.publicKey) { - return fmt.Errorf("public key in cert and private key supplied don't match") - } - default: - return fmt.Errorf("invalid curve: %s", curve) + // the call to PublicKey below will panic slice bounds out of range otherwise + if len(key) != ed25519.PrivateKeySize { + return ErrInvalidPrivateKey } - return nil - } - var pub []byte - switch curve { - case Curve_CURVE25519: - var err error - pub, err = curve25519.X25519(key, curve25519.Basepoint) - if err != nil { - return err + if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) { + return ErrPublicPrivateKeyMismatch } - case Curve_P256: - privkey, err := ecdh.P256().NewPrivateKey(key) + return nil + } else { + pub, err := curve25519.X25519(key, curve25519.Basepoint) if err != nil { - return err + return ErrInvalidPrivateKey } - pub = privkey.PublicKey().Bytes() - default: - return fmt.Errorf("invalid curve: %s", curve) - } - if !bytes.Equal(pub, c.details.publicKey) { - return fmt.Errorf("public key in cert and private key supplied don't match") + if !bytes.Equal(pub, c.details.publicKey) { + return ErrPublicPrivateKeyMismatch + } + return nil } - - return nil } // getRawDetails marshals the raw details into protobuf ready struct diff --git a/cert/cert_v1_test.go b/cert/cert_v1_test.go index 8c3fe930b..8301ae070 100644 --- a/cert/cert_v1_test.go +++ b/cert/cert_v1_test.go @@ -108,16 +108,16 @@ func TestCertificateV1_MarshalJSON(t *testing.T) { } func TestCertificateV1_VerifyPrivateKey(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, false) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) assert.Nil(t, err) - _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, false) assert.Nil(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) assert.NotNil(t, err) - c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, false) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) assert.NoError(t, err) assert.Empty(t, b) @@ -131,16 +131,21 @@ func TestCertificateV1_VerifyPrivateKey(t *testing.T) { } func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + testCertificateV1VerifyPrivateKeyP256(t, false) + testCertificateV1VerifyPrivateKeyP256(t, true) +} + +func testCertificateV1VerifyPrivateKeyP256(t *testing.T, compressKey bool) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, compressKey) err := ca.VerifyPrivateKey(Curve_P256, caKey) assert.Nil(t, err) - _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, compressKey) assert.Nil(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) assert.NotNil(t, err) - c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, compressKey) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) assert.NoError(t, err) assert.Empty(t, b) @@ -148,7 +153,7 @@ func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { err = c.VerifyPrivateKey(Curve_P256, rawPriv) assert.Nil(t, err) - _, priv2 := P256Keypair() + _, priv2 := P256Keypair(compressKey) err = c.VerifyPrivateKey(Curve_P256, priv2) assert.NotNil(t, err) } @@ -191,10 +196,18 @@ func TestMarshalingCertificateV1Consistency(t *testing.T) { } func TestCertificateV1_Copy(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) - c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false) + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false) cc := c.Copy() test.AssertDeepCopyEqual(t, c, cc) + ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false) + c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false) + cc = c.Copy() + test.AssertDeepCopyEqual(t, c, cc) + ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, true) + c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, true) + cc = c.Copy() + test.AssertDeepCopyEqual(t, c, cc) } func TestUnmarshalCertificateV1(t *testing.T) { diff --git a/cert/cert_v2.go b/cert/cert_v2.go index dce929684..e179957f0 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -2,10 +2,8 @@ package cert import ( "bytes" - "crypto/ecdh" "crypto/ecdsa" "crypto/ed25519" - "crypto/elliptic" "crypto/sha256" "encoding/hex" "encoding/json" @@ -15,6 +13,7 @@ import ( "slices" "time" + "github.com/slackhq/nebula/noiseutil" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" "golang.org/x/crypto/curve25519" @@ -149,8 +148,10 @@ func (c *certificateV2) CheckSignature(key []byte) bool { case Curve_CURVE25519: return ed25519.Verify(key, b, c.signature) case Curve_P256: - x, y := elliptic.Unmarshal(elliptic.P256(), key) - pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + pubKey, err := noiseutil.LoadP256Pubkey(key) + if err != nil { + return false + } hashed := sha256.Sum256(b) return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) default: @@ -166,54 +167,32 @@ func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error { if curve != c.curve { return ErrPublicPrivateCurveMismatch } + if curve == Curve_P256 { + return verifyP256PrivateKey(key, c.publicKey) + } else if curve != Curve_CURVE25519 { + return fmt.Errorf("invalid curve: %s", curve) + } + if c.details.isCA { - switch curve { - case Curve_CURVE25519: - // the call to PublicKey below will panic slice bounds out of range otherwise - if len(key) != ed25519.PrivateKeySize { - return ErrInvalidPrivateKey - } + // the call to PublicKey below will panic slice bounds out of range otherwise + if len(key) != ed25519.PrivateKeySize { + return ErrInvalidPrivateKey + } - if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) { - return ErrPublicPrivateKeyMismatch - } - case Curve_P256: - privkey, err := ecdh.P256().NewPrivateKey(key) - if err != nil { - return ErrInvalidPrivateKey - } - pub := privkey.PublicKey().Bytes() - if !bytes.Equal(pub, c.publicKey) { - return ErrPublicPrivateKeyMismatch - } - default: - return fmt.Errorf("invalid curve: %s", curve) + if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) { + return ErrPublicPrivateKeyMismatch } return nil - } - - var pub []byte - switch curve { - case Curve_CURVE25519: - var err error - pub, err = curve25519.X25519(key, curve25519.Basepoint) + } else { + pub, err := curve25519.X25519(key, curve25519.Basepoint) if err != nil { return ErrInvalidPrivateKey } - case Curve_P256: - privkey, err := ecdh.P256().NewPrivateKey(key) - if err != nil { - return ErrInvalidPrivateKey + if !bytes.Equal(pub, c.publicKey) { + return ErrPublicPrivateKeyMismatch } - pub = privkey.PublicKey().Bytes() - default: - return fmt.Errorf("invalid curve: %s", curve) - } - if !bytes.Equal(pub, c.publicKey) { - return ErrPublicPrivateKeyMismatch + return nil } - - return nil } func (c *certificateV2) String() string { diff --git a/cert/cert_v2_test.go b/cert/cert_v2_test.go index 3afbcab14..b50cc39e6 100644 --- a/cert/cert_v2_test.go +++ b/cert/cert_v2_test.go @@ -130,7 +130,7 @@ func TestCertificateV2_MarshalJSON(t *testing.T) { } func TestCertificateV2_VerifyPrivateKey(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil, false) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) assert.Nil(t, err) @@ -142,7 +142,7 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) { err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) - c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, false) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) assert.NoError(t, err) assert.Empty(t, b) @@ -166,14 +166,14 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) { err = c.VerifyPrivateKey(Curve(99), priv2) assert.EqualError(t, err, "invalid curve: 99") - ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, false) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) assert.Nil(t, err) err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) assert.ErrorIs(t, err, ErrInvalidPrivateKey) - c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) + c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil, false) rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv) err = c.VerifyPrivateKey(Curve_P256, priv[:16]) @@ -189,18 +189,21 @@ func TestCertificateV2_VerifyPrivateKey(t *testing.T) { assert.EqualError(t, err, "invalid curve: 99") } - func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + testCertificateV2VerifyPrivateKeyP256(t, false) + testCertificateV2VerifyPrivateKeyP256(t, true) +} +func testCertificateV2VerifyPrivateKeyP256(t *testing.T, compressKey bool) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, compressKey) err := ca.VerifyPrivateKey(Curve_P256, caKey) assert.Nil(t, err) - _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil, compressKey) assert.Nil(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) assert.NotNil(t, err) - c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil, compressKey) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) assert.NoError(t, err) assert.Empty(t, b) @@ -208,16 +211,24 @@ func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { err = c.VerifyPrivateKey(Curve_P256, rawPriv) assert.Nil(t, err) - _, priv2 := P256Keypair() + _, priv2 := P256Keypair(compressKey) err = c.VerifyPrivateKey(Curve_P256, priv2) assert.NotNil(t, err) } func TestCertificateV2_Copy(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) - c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false) + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false) cc := c.Copy() test.AssertDeepCopyEqual(t, c, cc) + ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, false) + c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, false) + cc = c.Copy() + test.AssertDeepCopyEqual(t, c, cc) + ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil, true) + c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil, true) + cc = c.Copy() + test.AssertDeepCopyEqual(t, c, cc) } func TestUnmarshalCertificateV2(t *testing.T) { diff --git a/cert/helper_test.go b/cert/helper_test.go index 05142dd54..acfdeffa1 100644 --- a/cert/helper_test.go +++ b/cert/helper_test.go @@ -7,14 +7,29 @@ import ( "crypto/rand" "io" "net/netip" + "testing" "time" + "github.com/slackhq/nebula/noiseutil" + "github.com/stretchr/testify/assert" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/ed25519" ) +func Test_NewTestCaCert(t *testing.T) { + c, _, priv, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(time.Hour), nil, nil, nil, false) + assert.Len(t, c.PublicKey(), 65) + c, _, priv, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(time.Hour), nil, nil, nil, true) + assert.Len(t, c.PublicKey(), 33) + + cc, _, _, _ := NewTestCert(Version2, Curve_P256, c, priv, "uncompressed", time.Now(), time.Now().Add(time.Hour), nil, nil, nil, false) + assert.Len(t, cc.PublicKey(), 65) + cc, _, _, _ = NewTestCert(Version2, Curve_P256, c, priv, "compressed", time.Now(), time.Now().Add(time.Hour), nil, nil, nil, true) + assert.Len(t, cc.PublicKey(), 33) +} + // NewTestCaCert will create a new ca certificate -func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { +func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string, compressKey bool) (Certificate, []byte, []byte, []byte) { var err error var pub, priv []byte @@ -26,8 +41,11 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ if err != nil { panic(err) } - - pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + if compressKey { + pub = elliptic.MarshalCompressed(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + } else { + pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + } priv = privk.D.FillBytes(make([]byte, 32)) default: // There is no default to allow the underlying lib to respond with an error @@ -68,7 +86,7 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ // NewTestCert will generate a signed certificate with the provided details. // Expiry times are defaulted if you do not pass them in -func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { +func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string, compressKey bool) (Certificate, []byte, []byte, []byte) { if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } @@ -82,7 +100,7 @@ func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string case Curve_CURVE25519: pub, priv = X25519Keypair() case Curve_P256: - pub, priv = P256Keypair() + pub, priv = P256Keypair(compressKey) default: panic("unknown curve") } @@ -127,11 +145,20 @@ func X25519Keypair() ([]byte, []byte) { return pubkey, privkey } -func P256Keypair() ([]byte, []byte) { +func P256Keypair(compressed bool) ([]byte, []byte) { privkey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { panic(err) } - pubkey := privkey.PublicKey() - return pubkey.Bytes(), privkey.Bytes() + if !compressed { + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() + } + pubkeyBytes := privkey.PublicKey().Bytes() + pubkey, err := noiseutil.LoadP256Pubkey(pubkeyBytes) + if err != nil { + panic(err) + } + out := elliptic.MarshalCompressed(elliptic.P256(), pubkey.X, pubkey.Y) + return out, privkey.Bytes() } diff --git a/cert/sign.go b/cert/sign.go index a1e09cd2b..b5f880747 100644 --- a/cert/sign.go +++ b/cert/sign.go @@ -1,6 +1,7 @@ package cert import ( + "crypto/ecdh" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" @@ -11,6 +12,8 @@ import ( "net/netip" "slices" "time" + + "github.com/slackhq/nebula/noiseutil" ) // TBSCertificate represents a certificate intended to be signed. @@ -158,3 +161,18 @@ func comparePrefix(a, b netip.Prefix) int { } return addr } + +func verifyP256PrivateKey(privateKey []byte, detailsPublicBytes []byte) error { + privkey, err := ecdh.P256().NewPrivateKey(privateKey) + if err != nil { + return ErrInvalidPrivateKey + } + detailsPubkey, err := noiseutil.LoadECDHPubkey(detailsPublicBytes) + if err != nil { + return fmt.Errorf("cannot parse public key from cert as P256: %w", err) + } + if !detailsPubkey.Equal(privkey.PublicKey()) { + return ErrPublicPrivateKeyMismatch + } + return nil +} diff --git a/cert_test/cert.go b/cert_test/cert.go index ebc6f522d..a6e039eb3 100644 --- a/cert_test/cert.go +++ b/cert_test/cert.go @@ -10,12 +10,13 @@ import ( "time" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/noiseutil" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/ed25519" ) // NewTestCaCert will create a new ca certificate -func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { +func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string, compressKey bool) (cert.Certificate, []byte, []byte, []byte) { var err error var pub, priv []byte @@ -27,8 +28,11 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti if err != nil { panic(err) } - - pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + if compressKey { + pub = elliptic.MarshalCompressed(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + } else { + pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + } priv = privk.D.FillBytes(make([]byte, 32)) default: // There is no default to allow the underlying lib to respond with an error @@ -69,7 +73,7 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti // NewTestCert will generate a signed certificate with the provided details. // Expiry times are defaulted if you do not pass them in -func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { +func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string, compressKey bool) (cert.Certificate, []byte, []byte, []byte) { if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } @@ -83,7 +87,7 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by case cert.Curve_CURVE25519: pub, priv = X25519Keypair() case cert.Curve_P256: - pub, priv = P256Keypair() + pub, priv = P256Keypair(compressKey) default: panic("unknown curve") } @@ -128,11 +132,20 @@ func X25519Keypair() ([]byte, []byte) { return pubkey, privkey } -func P256Keypair() ([]byte, []byte) { +func P256Keypair(compressed bool) ([]byte, []byte) { privkey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { panic(err) } - pubkey := privkey.PublicKey() - return pubkey.Bytes(), privkey.Bytes() + if !compressed { + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() + } + pubkeyBytes := privkey.PublicKey().Bytes() + pubkey, err := noiseutil.LoadP256Pubkey(pubkeyBytes) + if err != nil { + panic(err) + } + out := elliptic.MarshalCompressed(elliptic.P256(), pubkey.X, pubkey.Y) + return out, privkey.Bytes() } diff --git a/cmd/nebula-cert/keygen.go b/cmd/nebula-cert/keygen.go index 496f84c27..aeb3d631f 100644 --- a/cmd/nebula-cert/keygen.go +++ b/cmd/nebula-cert/keygen.go @@ -62,7 +62,7 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { pub, rawPriv = x25519Keypair() curve = cert.Curve_CURVE25519 case "P256": - pub, rawPriv = p256Keypair() + pub, rawPriv = p256Keypair(false) //todo support generating compressed keys curve = cert.Curve_P256 default: return fmt.Errorf("invalid curve: %s", *cf.curve) diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 253ef864e..8ac8c0bba 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -2,6 +2,7 @@ package main import ( "crypto/ecdh" + "crypto/elliptic" "crypto/rand" "errors" "flag" @@ -14,6 +15,7 @@ import ( "github.com/skip2/go-qrcode" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/curve25519" ) @@ -400,7 +402,7 @@ func newKeypair(curve cert.Curve) ([]byte, []byte) { case cert.Curve_CURVE25519: return x25519Keypair() case cert.Curve_P256: - return p256Keypair() + return p256Keypair(false) //todo support generating compressed keys default: return nil, nil } @@ -420,13 +422,22 @@ func x25519Keypair() ([]byte, []byte) { return pubkey, privkey } -func p256Keypair() ([]byte, []byte) { +func p256Keypair(compressed bool) ([]byte, []byte) { privkey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { panic(err) } - pubkey := privkey.PublicKey() - return pubkey.Bytes(), privkey.Bytes() + if !compressed { + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() + } + pubkeyBytes := privkey.PublicKey().Bytes() + pubkey, err := noiseutil.LoadP256Pubkey(pubkeyBytes) + if err != nil { + panic(err) + } + out := elliptic.MarshalCompressed(elliptic.P256(), pubkey.X, pubkey.Y) + return out, privkey.Bytes() } func signSummary() string { diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 29b9d536e..5114f1720 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -20,8 +20,10 @@ import ( "gopkg.in/yaml.v2" ) +const compressKey = false + func BenchmarkHotPath(b *testing.B) { - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -45,7 +47,7 @@ func BenchmarkHotPath(b *testing.B) { } func TestGoodHandshake(t *testing.T) { - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -96,7 +98,7 @@ func TestGoodHandshake(t *testing.T) { } func TestWrongResponderHandshake(t *testing.T) { - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) @@ -175,7 +177,7 @@ func TestWrongResponderHandshake(t *testing.T) { } func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) @@ -263,7 +265,7 @@ func TestStage1Race(t *testing.T) { // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -340,7 +342,7 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -389,7 +391,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -440,7 +442,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) @@ -471,7 +473,7 @@ func TestRelays(t *testing.T) { func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) @@ -520,7 +522,7 @@ func TestStage1RaceRelays(t *testing.T) { func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) @@ -608,7 +610,7 @@ func TestStage1RaceRelays2(t *testing.T) { } func TestRehandshakingRelays(t *testing.T) { - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) @@ -638,7 +640,7 @@ func TestRehandshakingRelays(t *testing.T) { // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}, compressKey) caB, err := ca.MarshalPEM() if err != nil { @@ -712,7 +714,7 @@ func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) @@ -742,7 +744,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}, compressKey) caB, err := ca.MarshalPEM() if err != nil { @@ -815,7 +817,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) @@ -837,7 +839,7 @@ func TestRehandshaking(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew my certificate and spin until their sees it") - _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}, compressKey) caB, err := ca.MarshalPEM() if err != nil { @@ -912,7 +914,7 @@ func TestRehandshaking(t *testing.T) { func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) @@ -934,7 +936,7 @@ func TestRehandshakingLoser(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") - _, _, theirNextPrivKey, theirNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) + _, _, theirNextPrivKey, theirNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}, compressKey) caB, err := ca.MarshalPEM() if err != nil { @@ -1008,7 +1010,7 @@ func TestRaceRegression(t *testing.T) { // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) @@ -1066,7 +1068,7 @@ func TestRaceRegression(t *testing.T) { } func TestV2NonPrimaryWithLighthouse(t *testing.T) { - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, compressKey) lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}}) o := m{ diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index f5df2fb25..1ba57aea4 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -56,7 +56,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name budpIp[3] = 239 udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } - _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) + _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}, compressKey) caB, err := caCrt.MarshalPEM() if err != nil { diff --git a/noiseutil/nist.go b/noiseutil/nist.go index 976a27423..6a0983a9d 100644 --- a/noiseutil/nist.go +++ b/noiseutil/nist.go @@ -2,9 +2,13 @@ package noiseutil import ( "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" + "errors" "fmt" "io" + "math/big" "github.com/flynn/noise" ) @@ -41,8 +45,44 @@ func (c nistCurve) GenerateKeypair(rng io.Reader) (noise.DHKey, error) { return noise.DHKey{Private: privkey.Bytes(), Public: pubkey.Bytes()}, nil } +func LoadP256Pubkey(pubkey []byte) (*ecdsa.PublicKey, error) { + if len(pubkey) == 0 { + return nil, errors.New("empty public key") + } + curve := elliptic.P256() + var x, y *big.Int + switch pubkey[0] { + case 0x4: //uncompressed + x, y = elliptic.Unmarshal(curve, pubkey) + case 0x2, 0x3: //compressed + x, y = elliptic.UnmarshalCompressed(curve, pubkey) + default: + return nil, fmt.Errorf("unknown P256 public key type: 0x%x", pubkey[0]) + } + + if x == nil || y == nil { + return nil, errors.New("invalid compressed P256 public key") + } + out := &ecdsa.PublicKey{Curve: curve, X: x, Y: y} + return out, nil +} + +func LoadECDHPubkey(in []byte) (*ecdh.PublicKey, error) { + if len(in) == 0 { + return nil, errors.New("empty public key") + } + if in[0] == 0x4 { //uncompressed + return ecdh.P256().NewPublicKey(in) + } + out, err := LoadP256Pubkey(in) + if err != nil { + return nil, err + } + return out.ECDH() +} + func (c nistCurve) DH(privkey, pubkey []byte) ([]byte, error) { - ecdhPubKey, err := c.curve.NewPublicKey(pubkey) + ecdhPubKey, err := LoadECDHPubkey(pubkey) if err != nil { return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err) } diff --git a/noiseutil/pkcs11.go b/noiseutil/pkcs11.go index d1c7ba918..f0c64333b 100644 --- a/noiseutil/pkcs11.go +++ b/noiseutil/pkcs11.go @@ -31,7 +31,7 @@ func (c nistP11Curve) DH(privkey, pubkey []byte) ([]byte, error) { if !strings.HasPrefix(pkStr, "pkcs11:") { return DHP256.DH(privkey, pubkey) } - ecdhPubKey, err := c.curve.NewPublicKey(pubkey) + ecdhPubKey, err := LoadECDHPubkey(pubkey) if err != nil { return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err) } diff --git a/service/service_test.go b/service/service_test.go index 613758e19..1ee7819bd 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -19,7 +19,7 @@ import ( type m map[string]interface{} func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { - _, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) + _, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}, false) caB, err := caCrt.MarshalPEM() if err != nil { panic(err) @@ -79,7 +79,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n } func TestService(t *testing.T) { - ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}, false) a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{