From e5d768dbcfcad509678897ff457eccc7ef82bbdc Mon Sep 17 00:00:00 2001
From: Eric Chiang <eric.chiang.m@gmail.com>
Date: Wed, 31 Aug 2022 20:41:35 -0700
Subject: [PATCH] oidc: don't parse JWT twice

---
 oidc/jwks.go   | 14 +++++++++++---
 oidc/verify.go |  1 +
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/oidc/jwks.go b/oidc/jwks.go
index 6d2ffee9..fdcfba81 100644
--- a/oidc/jwks.go
+++ b/oidc/jwks.go
@@ -113,15 +113,23 @@ func (i *inflight) result() ([]jose.JSONWebKey, error) {
 	return i.keys, i.err
 }
 
+// paresdJWTKey is a context key that allows common setups to avoid parsing the
+// JWT twice. It holds a *jose.JSONWebSignature value.
+var parsedJWTKey contextKey
+
 // VerifySignature validates a payload against a signature from the jwks_uri.
 //
 // Users MUST NOT call this method directly and should use an IDTokenVerifier
 // instead. This method skips critical validations such as 'alg' values and is
 // only exported to implement the KeySet interface.
 func (r *RemoteKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
-	jws, err := jose.ParseSigned(jwt)
-	if err != nil {
-		return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
+	jws, ok := ctx.Value(parsedJWTKey).(*jose.JSONWebSignature)
+	if !ok {
+		var err error
+		jws, err = jose.ParseSigned(jwt)
+		if err != nil {
+			return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
+		}
 	}
 	return r.verify(ctx, jws)
 }
diff --git a/oidc/verify.go b/oidc/verify.go
index a62fe348..464b61e6 100644
--- a/oidc/verify.go
+++ b/oidc/verify.go
@@ -308,6 +308,7 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
 
 	t.sigAlgorithm = sig.Header.Algorithm
 
+	ctx = context.WithValue(ctx, parsedJWTKey, jws)
 	gotPayload, err := v.keySet.VerifySignature(ctx, rawIDToken)
 	if err != nil {
 		return nil, fmt.Errorf("failed to verify signature: %v", err)