diff --git a/cert/sign.go b/cert/sign.go index 2f768d4ec..741049d6d 100644 --- a/cert/sign.go +++ b/cert/sign.go @@ -11,8 +11,6 @@ import ( "net/netip" "slices" "time" - - "github.com/slackhq/nebula/pkclient" ) // TBSCertificate represents a certificate intended to be signed. @@ -42,22 +40,45 @@ type beingSignedCertificate interface { setSignature([]byte) error } +type SignerLambda func(certBytes []byte) ([]byte, error) + // Sign will create a sealed certificate using details provided by the TBSCertificate as long as those // details do not violate constraints of the signing certificate. // If the TBSCertificate is a CA then signer must be nil. func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) { - return t.sign(signer, curve, key, nil) -} - -func (t *TBSCertificate) SignPkcs11(signer Certificate, curve Curve, client *pkclient.PKClient) (Certificate, error) { - if curve != Curve_P256 { - return nil, fmt.Errorf("only P256 is supported by PKCS#11") + switch t.Curve { + case Curve_CURVE25519: + pk := ed25519.PrivateKey(key) + sp := func(certBytes []byte) ([]byte, error) { + sig := ed25519.Sign(pk, certBytes) + return sig, nil + } + return t.SignWith(signer, curve, sp) + case Curve_P256: + pk := &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: elliptic.P256(), + }, + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 + D: new(big.Int).SetBytes(key), + } + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 + pk.X, pk.Y = pk.Curve.ScalarBaseMult(key) + sp := func(certBytes []byte) ([]byte, error) { + // We need to hash first for ECDSA + // - https://pkg.go.dev/crypto/ecdsa#SignASN1 + hashed := sha256.Sum256(certBytes) + return ecdsa.SignASN1(rand.Reader, pk, hashed[:]) + } + return t.SignWith(signer, curve, sp) + default: + return nil, fmt.Errorf("invalid curve: %s", t.Curve) } - - return t.sign(signer, curve, nil, client) } -func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, client *pkclient.PKClient) (Certificate, error) { +// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature. +// You should only use SignWith if you do not have direct access to your private key. +func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) { if curve != t.Curve { return nil, fmt.Errorf("curve in cert and private key supplied don't match") } @@ -112,34 +133,7 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien return nil, err } - var sig []byte - switch t.Curve { - case Curve_CURVE25519: - signer := ed25519.PrivateKey(key) - sig = ed25519.Sign(signer, certBytes) - case Curve_P256: - if client != nil { - sig, err = client.SignASN1(certBytes) - } else { - signer := &ecdsa.PrivateKey{ - PublicKey: ecdsa.PublicKey{ - Curve: elliptic.P256(), - }, - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 - D: new(big.Int).SetBytes(key), - } - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 - signer.X, signer.Y = signer.Curve.ScalarBaseMult(key) - - // We need to hash first for ECDSA - // - https://pkg.go.dev/crypto/ecdsa#SignASN1 - hashed := sha256.Sum256(certBytes) - sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) - } - default: - return nil, fmt.Errorf("invalid curve: %s", t.Curve) - } - + sig, err := sp(certBytes) if err != nil { return nil, err } diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index 7548e2124..f83c94fb4 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -272,7 +272,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error var b []byte if isP11 { - c, err = t.SignPkcs11(nil, curve, p11Client) + c, err = t.SignWith(nil, curve, p11Client.SignASN1) if err != nil { return fmt.Errorf("error while signing with PKCS#11: %w", err) } diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 6ac045214..253ef864e 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -316,7 +316,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("error while signing: %w", err) } } else { - nc, err = t.SignPkcs11(caCert, curve, p11Client) + nc, err = t.SignWith(caCert, curve, p11Client.SignASN1) if err != nil { return fmt.Errorf("error while signing with PKCS#11: %w", err) } @@ -346,7 +346,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("error while signing: %w", err) } } else { - nc, err = t.SignPkcs11(caCert, curve, p11Client) + nc, err = t.SignWith(caCert, curve, p11Client.SignASN1) if err != nil { return fmt.Errorf("error while signing with PKCS#11: %w", err) }