Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(auth): rework gdch support #11263

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions auth/credentials/detect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
Expand All @@ -33,6 +34,14 @@ import (
"cloud.google.com/go/auth/internal/jwt"
)

type tokenRequest struct {
GrantType string `json:"grant_type"`
Audience string `json:"audience"`
SubjectToken string `json:"subject_token"`
SubjectType string `json:"subject_token_type"`
TokenType string `json:"requested_token_type"`
}

type tokResp struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Expand All @@ -54,10 +63,15 @@ func TestDefaultCredentials_GdchServiceAccountKey(t *testing.T) {
if r.Method != "POST" {
t.Errorf("unexpected request method: %v", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Error(err)
b, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
var tokReq tokenRequest
if err := json.Unmarshal(b, &tokReq); err != nil {
t.Fatal(err)
}
parts := strings.Split(r.FormValue("subject_token"), ".")
parts := strings.Split(tokReq.SubjectToken, ".")
var header jwt.Header
var claims jwt.Claims
b, err = base64.RawURLEncoding.DecodeString(parts[0])
Expand All @@ -75,10 +89,10 @@ func TestDefaultCredentials_GdchServiceAccountKey(t *testing.T) {
t.Fatal(err)
}

if got := r.FormValue("audience"); got != aud {
if got := tokReq.Audience; got != aud {
t.Errorf("got audience %v, want %v", got, gdch.GrantType)
}
if want := jwt.HeaderAlgRSA256; header.Algorithm != want {
if want := jwt.HeaderAlgES256; header.Algorithm != want {
t.Errorf("got alg %q, want %q", header.Algorithm, want)
}
if want := jwt.HeaderType; header.Type != want {
Expand Down
36 changes: 24 additions & 12 deletions auth/credentials/internal/gdch/gdch.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package gdch

import (
"bytes"
"context"
"crypto"
"crypto/tls"
Expand All @@ -23,9 +24,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"time"

"cloud.google.com/go/auth"
Expand Down Expand Up @@ -104,6 +103,14 @@ type gdchProvider struct {
client *http.Client
}

type tokenRequest struct {
GrantType string `json:"grant_type"`
Audience string `json:"audience"`
SubjectToken string `json:"subject_token"`
SubjectType string `json:"subject_token_type"`
TokenType string `json:"requested_token_type"`
}

func (g gdchProvider) Token(ctx context.Context) (*auth.Token, error) {
addCertToTransport(g.client, g.certPool)
iat := time.Now()
Expand All @@ -116,26 +123,31 @@ func (g gdchProvider) Token(ctx context.Context) (*auth.Token, error) {
Exp: exp.Unix(),
}
h := jwt.Header{
Algorithm: jwt.HeaderAlgRSA256,
Algorithm: jwt.HeaderAlgES256,
Type: jwt.HeaderType,
KeyID: string(g.pkID),
KeyID: g.pkID,
}
payload, err := jwt.EncodeJWS(&h, &claims, g.signer)
if err != nil {
return nil, err
}
v := url.Values{}
v.Set("grant_type", GrantType)
v.Set("audience", g.aud)
v.Set("requested_token_type", requestTokenType)
v.Set("subject_token", payload)
v.Set("subject_token_type", subjectTokenType)
tokReq := &tokenRequest{
GrantType: GrantType,
Audience: g.aud,
SubjectToken: payload,
SubjectType: subjectTokenType,
TokenType: requestTokenType,
}
b, err := json.Marshal(tokReq)
if err != nil {
return nil, err
}

req, err := http.NewRequestWithContext(ctx, "POST", g.tokenURL, strings.NewReader(v.Encode()))
req, err := http.NewRequestWithContext(ctx, "POST", g.tokenURL, bytes.NewReader(b))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Content-Type", "application/json")
resp, body, err := internal.DoRequest(g.client, req)
if err != nil {
return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
Expand Down
17 changes: 12 additions & 5 deletions auth/credentials/internal/gdch/gdch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
Expand All @@ -44,10 +45,16 @@ func TestTokenProvider(t *testing.T) {
if r.Method != "POST" {
t.Errorf("unexpected request method: %v", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Error(err)
b, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
parts := strings.Split(r.FormValue("subject_token"), ".")
var tokReq tokenRequest
if err := json.Unmarshal(b, &tokReq); err != nil {
t.Fatal(err)
}

parts := strings.Split(tokReq.SubjectToken, ".")
var header jwt.Header
var claims jwt.Claims
b, err = base64.RawURLEncoding.DecodeString(parts[0])
Expand All @@ -65,10 +72,10 @@ func TestTokenProvider(t *testing.T) {
t.Fatal(err)
}

if got := r.FormValue("audience"); got != aud {
if got := tokReq.Audience; got != aud {
t.Errorf("got audience %v, want %v", got, GrantType)
}
if want := jwt.HeaderAlgRSA256; header.Algorithm != want {
if want := jwt.HeaderAlgES256; header.Algorithm != want {
t.Errorf("got alg %q, want %q", header.Algorithm, want)
}
if want := jwt.HeaderType; header.Type != want {
Expand Down
2 changes: 1 addition & 1 deletion auth/internal/credsfile/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestParseGDCHServiceAccount(t *testing.T) {
FormatVersion: "1",
Project: "fake_project",
Name: "sa_name",
PrivateKey: "-----BEGIN PRIVATE KEY-----\nMIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBALX0PQoe1igW12ikv1bN/r9lN749y2ijmbc/mFHPyS3hNTyOCjDvBbXYbDhQJzWVUikh4mvGBA07qTj79Xc3yBDfKP2IeyYQIFe0t0zkd7R9Zdn98Y2rIQC47aAbDfubtkU1U72t4zL11kHvoa0/RuFZjncvlr42X7be7lYh4p3NAgMBAAECgYASk5wDw4Az2ZkmeuN6Fk/y9H+Lcb2pskJIXjrL533vrDWGOC48LrsThMQPv8cxBky8HFSEklPpkfTF95tpD43iVwJRB/GrCtGTw65IfJ4/tI09h6zGc4yqvIo1cHX/LQ+SxKLGyir/dQM925rGt/VojxY5ryJR7GLbCzxPnJm/oQJBANwOCO6D2hy1LQYJhXh7O+RLtA/tSnT1xyMQsGT+uUCMiKS2bSKx2wxo9k7h3OegNJIu1q6nZ6AbxDK8H3+d0dUCQQDTrPSXagBxzp8PecbaCHjzNRSQE2in81qYnrAFNB4o3DpHyMMY6s5ALLeHKscEWnqP8Ur6X4PvzZecCWU9BKAZAkAutLPknAuxSCsUOvUfS1i87ex77Ot+w6POp34pEX+UWb+u5iFn2cQacDTHLV1LtE80L8jVLSbrbrlH43H0DjU5AkEAgidhycxS86dxpEljnOMCw8CKoUBd5I880IUahEiUltk7OLJYS/Ts1wbn3kPOVX3wyJs8WBDtBkFrDHW2ezth2QJADj3e1YhMVdjJW5jqwlD/VNddGjgzyunmiZg0uOXsHXbytYmsA545S8KRQFaJKFXYYFo2kOjqOiC1T2cAzMDjCQ==\n-----END PRIVATE KEY-----\n",
PrivateKey: "-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIIGb2np7v54Hs6++NiLE7CQtQg7rzm4znstHvrOUlcMMoAoGCCqGSM49\nAwEHoUQDQgAECvv0VyZS9nYOa8tdwKCbkNxlWgrAZVClhJXqrvOZHlH4N3d8Rplk\n2DEJvzp04eMxlHw1jm6JCs3iJR6KAokG+w==\n-----END EC PRIVATE KEY-----\n",
PrivateKeyID: "abcdef1234567890",
CertPath: "cert.pem",
TokenURL: "replace_me",
Expand Down
5 changes: 4 additions & 1 deletion auth/internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ func ParseKey(key []byte) (crypto.Signer, error) {
if err != nil {
parsedKey, err = x509.ParsePKCS1PrivateKey(key)
if err != nil {
return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8: %w", err)
parsedKey, err = x509.ParseECPrivateKey(key)
if err != nil {
return nil, fmt.Errorf("private key should be in a PKCS8, plain PKCS1, or EC format: %w", err)
}
}
}
parsed, ok := parsedKey.(crypto.Signer)
Expand Down
19 changes: 19 additions & 0 deletions auth/internal/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
Expand Down Expand Up @@ -169,3 +170,21 @@
h.Write([]byte(signedContent))
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString)
}

func VerifyJWSWithEC(token string, key *ecdsa.PublicKey) error {

Check failure on line 174 in auth/internal/jwt/jwt.go

View workflow job for this annotation

GitHub Actions / vet

exported function VerifyJWSWithEC should have comment or be unexported
parts := strings.Split(token, ".")
if len(parts) != 3 {
return errors.New("jwt: invalid token received, token must have 3 parts")
}
sig, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return err
}
signedContent := parts[0] + "." + parts[1]
h := sha256.New()
h.Write([]byte(signedContent))
if valid := ecdsa.VerifyASN1(key, h.Sum(nil), sig); !valid {
return fmt.Errorf("jwt: the ASN.1 encoded signature is invalid")
}
return nil
}
41 changes: 41 additions & 0 deletions auth/internal/jwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ package jwt
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"os"
"testing"
"time"
)

func TestSignAndVerifyDecode(t *testing.T) {
Expand Down Expand Up @@ -77,3 +81,40 @@ func TestVerifyFailsOnMalformedClaim(t *testing.T) {
t.Error("got no errors; want improperly formed JWT not to be verified")
}
}

func TestSignandVerifyForEC(t *testing.T) {
key, err := os.ReadFile("../testdata/eckey.pem")
if err != nil {
t.Fatal(err)
}
block, _ := pem.Decode(key)
if block != nil {
key = block.Bytes
}
parsedKey, err := x509.ParseECPrivateKey(key)
if err != nil {
t.Fatal(err)
}

iat := time.Now()
exp := iat.Add(time.Hour)
claims := Claims{
Iss: "iss",
Sub: "sub",
Aud: "audience",
Iat: iat.Unix(),
Exp: exp.Unix(),
}
h := Header{
Algorithm: HeaderAlgES256,
Type: HeaderType,
KeyID: "keyid",
}
token, err := EncodeJWS(&h, &claims, parsedKey)
if err != nil {
t.Fatal(err)
}
if err := VerifyJWSWithEC(token, &parsedKey.PublicKey); err != nil {
t.Fatal(err)
}
}
5 changes: 5 additions & 0 deletions auth/internal/testdata/eckey.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIGb2np7v54Hs6++NiLE7CQtQg7rzm4znstHvrOUlcMMoAoGCCqGSM49
AwEHoUQDQgAECvv0VyZS9nYOa8tdwKCbkNxlWgrAZVClhJXqrvOZHlH4N3d8Rplk
2DEJvzp04eMxlHw1jm6JCs3iJR6KAokG+w==
-----END EC PRIVATE KEY-----
2 changes: 1 addition & 1 deletion auth/internal/testdata/gdch.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"format_version": "1",
"project": "fake_project",
"private_key_id": "abcdef1234567890",
"private_key": "-----BEGIN PRIVATE KEY-----\nMIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBALX0PQoe1igW12ikv1bN/r9lN749y2ijmbc/mFHPyS3hNTyOCjDvBbXYbDhQJzWVUikh4mvGBA07qTj79Xc3yBDfKP2IeyYQIFe0t0zkd7R9Zdn98Y2rIQC47aAbDfubtkU1U72t4zL11kHvoa0/RuFZjncvlr42X7be7lYh4p3NAgMBAAECgYASk5wDw4Az2ZkmeuN6Fk/y9H+Lcb2pskJIXjrL533vrDWGOC48LrsThMQPv8cxBky8HFSEklPpkfTF95tpD43iVwJRB/GrCtGTw65IfJ4/tI09h6zGc4yqvIo1cHX/LQ+SxKLGyir/dQM925rGt/VojxY5ryJR7GLbCzxPnJm/oQJBANwOCO6D2hy1LQYJhXh7O+RLtA/tSnT1xyMQsGT+uUCMiKS2bSKx2wxo9k7h3OegNJIu1q6nZ6AbxDK8H3+d0dUCQQDTrPSXagBxzp8PecbaCHjzNRSQE2in81qYnrAFNB4o3DpHyMMY6s5ALLeHKscEWnqP8Ur6X4PvzZecCWU9BKAZAkAutLPknAuxSCsUOvUfS1i87ex77Ot+w6POp34pEX+UWb+u5iFn2cQacDTHLV1LtE80L8jVLSbrbrlH43H0DjU5AkEAgidhycxS86dxpEljnOMCw8CKoUBd5I880IUahEiUltk7OLJYS/Ts1wbn3kPOVX3wyJs8WBDtBkFrDHW2ezth2QJADj3e1YhMVdjJW5jqwlD/VNddGjgzyunmiZg0uOXsHXbytYmsA545S8KRQFaJKFXYYFo2kOjqOiC1T2cAzMDjCQ==\n-----END PRIVATE KEY-----\n",
"private_key": "-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIIGb2np7v54Hs6++NiLE7CQtQg7rzm4znstHvrOUlcMMoAoGCCqGSM49\nAwEHoUQDQgAECvv0VyZS9nYOa8tdwKCbkNxlWgrAZVClhJXqrvOZHlH4N3d8Rplk\n2DEJvzp04eMxlHw1jm6JCs3iJR6KAokG+w==\n-----END EC PRIVATE KEY-----\n",
"name": "sa_name",
"ca_cert_path": "cert.pem",
"token_uri": "replace_me"
Expand Down
Loading