diff --git a/lib/globalflags/globalflags.go b/lib/globalflags/globalflags.go index 2b4157a5..8a34fb95 100644 --- a/lib/globalflags/globalflags.go +++ b/lib/globalflags/globalflags.go @@ -17,6 +17,8 @@ package globalflags import ( "os" + + glog "github.com/golang/glog" /* copybara-comment */ ) var ( @@ -39,4 +41,34 @@ var ( // EnableAWSAdapter is a global flag determining if you want to use enable management of AWS resources. // Set from env var: `export ENABLE_AWS_ADAPTER=true` EnableAWSAdapter = os.Getenv("ENABLE_AWS_ADAPTER") == "true" + + // LocalSignerAlgorithm is a global flag determining if you want to sign the JWT with specific algorithm, only supported in persona service and using local signer. + // It will cause err if given invalid value. + // Set from env var: `export LOCAL_SIGNER_ALGORITHM=RS384` + LocalSignerAlgorithm = parseLocalSignerAlgorithm() +) + +// SignerAlgorithm of JWT. +type SignerAlgorithm string + +const ( + // RS256 used to sign JWT. + RS256 SignerAlgorithm = "RS256" + // RS384 used to sign JWT. + RS384 SignerAlgorithm = "RS384" ) + +func parseLocalSignerAlgorithm() SignerAlgorithm { + s := os.Getenv("LOCAL_SIGNER_ALGORITHM") + switch s { + case "": + return RS256 + case "RS256": + return RS256 + case "RS384": + return RS384 + default: + glog.Fatalf("invalid value of LOCAL_SIGNER_ALGORITHM: %s", s) + } + return SignerAlgorithm("") +} diff --git a/lib/kms/localsign/signer.go b/lib/kms/localsign/signer.go index 75a2faf1..385975d2 100644 --- a/lib/kms/localsign/signer.go +++ b/lib/kms/localsign/signer.go @@ -27,20 +27,36 @@ import ( // Signer can sign jwt. type Signer struct { - key jose.JSONWebKey - pri *rsa.PrivateKey + key jose.JSONWebKey + pri *rsa.PrivateKey + algo jose.SignatureAlgorithm } -// New Signer with given key. +// New RS256 Signer with given key. func New(k *testkeys.Key) *Signer { return &Signer{ key: jose.JSONWebKey{ Key: k.Public, - Algorithm: "RS256", + Algorithm: string(jose.RS256), Use: "sig", KeyID: k.ID, }, - pri: k.Private, + pri: k.Private, + algo: jose.RS256, + } +} + +// NewRS384Signer use RS384 to sign jwt +func NewRS384Signer(k *testkeys.Key) *Signer { + return &Signer{ + key: jose.JSONWebKey{ + Key: k.Public, + Algorithm: string(jose.RS384), + Use: "sig", + KeyID: k.ID, + }, + pri: k.Private, + algo: jose.RS384, } } @@ -54,7 +70,7 @@ func (s *Signer) PublicKeys() *jose.JSONWebKeySet { // SignJWT signs the given claims return the jwt string. func (s *Signer) SignJWT(ctx context.Context, claims interface{}, header map[string]string) (string, error) { key := jose.SigningKey{ - Algorithm: jose.RS256, + Algorithm: s.algo, Key: s.pri, } diff --git a/lib/persona/broker.go b/lib/persona/broker.go index 6e08ac21..76edf888 100644 --- a/lib/persona/broker.go +++ b/lib/persona/broker.go @@ -30,6 +30,7 @@ import ( "gopkg.in/square/go-jose.v2" /* copybara-comment */ "github.com/pborman/uuid" /* copybara-comment */ "github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/ga4gh" /* copybara-comment: ga4gh */ + "github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/globalflags" /* copybara-comment: globalflags */ "github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/httputils" /* copybara-comment: httputils */ "github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/kms/localsign" /* copybara-comment: localsign */ "github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/srcutil" /* copybara-comment: srcutil */ @@ -95,7 +96,8 @@ func NewBroker(issuerURL string, key *testkeys.Key, service, path string, useOID // Sign the jwt with the private key in Server. func (s *Server) Sign(header map[string]string, claim interface{}) (string, error) { - signer := localsign.New(s.key) + signer := signer(s.key) + return signer.SignJWT(context.Background(), claim, header) } @@ -388,3 +390,10 @@ func registerHandlers(r *mux.Router, s *Server, useOIDCPrefix bool) { sfs := http.StripPrefix(staticFilePath, http.FileServer(http.Dir(srcutil.Path(staticDirectory)))) r.PathPrefix(staticFilePath).Handler(sfs) } + +func signer(key *testkeys.Key) *localsign.Signer { + if globalflags.LocalSignerAlgorithm == globalflags.RS384 { + return localsign.NewRS384Signer(key) + } + return localsign.New(key) +} diff --git a/lib/persona/persona.go b/lib/persona/persona.go index ceef5f2a..4938c5cf 100644 --- a/lib/persona/persona.go +++ b/lib/persona/persona.go @@ -22,7 +22,6 @@ import ( "time" "github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/ga4gh" /* copybara-comment: ga4gh */ - "github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/kms/localsign" /* copybara-comment: localsign */ "github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/testkeys" /* copybara-comment: testkeys */ "github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/timeutil" /* copybara-comment: timeutil */ @@ -110,7 +109,7 @@ func NewAccessToken(name, issuer, clientID, scope string, persona *cpb.TestPerso } ctx := context.Background() - signer := localsign.New(&personaKey) + signer := signer(&personaKey) access, err := ga4gh.NewAccessFromData(ctx, d, signer) if err != nil { @@ -265,7 +264,7 @@ func populatePersonaVisas(ctx context.Context, pname, visaIssuer string, asserti id.GA4GH = make(map[string][]ga4gh.OldClaim) id.VisaJWTs = make([]string, len(assertions)) now := float64(time.Now().Unix()) - signer := localsign.New(&personaKey) + signer := signer(&personaKey) for i, assert := range assertions { typ := ga4gh.Type(assert.Type) diff --git a/lib/verifier/oidc_jwt_sig_verifier.go b/lib/verifier/oidc_jwt_sig_verifier.go index 2012d24b..2eaa3474 100644 --- a/lib/verifier/oidc_jwt_sig_verifier.go +++ b/lib/verifier/oidc_jwt_sig_verifier.go @@ -42,8 +42,9 @@ func newOIDCSigVerifier(ctx context.Context, issuer string) (*oidcJwtSigVerifier // Skip client claims check if no client claims passed in. SkipClientIDCheck: true, // Expire check and issuer check will do explicitly. - SkipExpiryCheck: true, - SkipIssuerCheck: true, + SkipExpiryCheck: true, + SkipIssuerCheck: true, + SupportedSigningAlgs: []string{oidc.RS256, oidc.RS384, oidc.ES384}, }) return &oidcJwtSigVerifier{ diff --git a/lib/verifier/oidc_jwt_sig_verifier_test.go b/lib/verifier/oidc_jwt_sig_verifier_test.go index ec897471..dbf6bf70 100644 --- a/lib/verifier/oidc_jwt_sig_verifier_test.go +++ b/lib/verifier/oidc_jwt_sig_verifier_test.go @@ -40,31 +40,49 @@ func TestOIDCVerifier_Verify(t *testing.T) { ExpiresAt: time.Now().Add(time.Hour).Unix(), }, } - signer := localsign.New(&key) - visa, err := ga4gh.NewVisaFromData(context.Background(), d, ga4gh.JWTEmptyJKU, signer) - if err != nil { - t.Fatalf("ga4gh.NewVisaFromData() failed: %v", err) + + tests := []struct { + name string + signer *localsign.Signer + }{ + { + name: "RS256", + signer: localsign.New(&key), + }, + { + name: "RS384", + signer: localsign.NewRS384Signer(&key), + }, } - // Make calls by oidc package use the fake HTTP client. - ctx := oidc.ClientContext(context.Background(), f.HTTP.Client) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + visa, err := ga4gh.NewVisaFromData(context.Background(), d, ga4gh.JWTEmptyJKU, tc.signer) + if err != nil { + t.Fatalf("ga4gh.NewVisaFromData() failed: %v", err) + } - pv, err := NewPassportVerifier(ctx, f.Issuer0.URL, client) - if err != nil { - t.Fatalf("NewPassportVerifier() failed: %v", err) - } + // Make calls by oidc package use the fake HTTP client. + ctx := oidc.ClientContext(context.Background(), f.HTTP.Client) - vv, err := NewVisaVerifier(ctx, f.Issuer0.URL, "", client) - if err != nil { - t.Fatalf("NewVisaVerifier() failed: %v", err) - } + pv, err := NewPassportVerifier(ctx, f.Issuer0.URL, client) + if err != nil { + t.Fatalf("NewPassportVerifier() failed: %v", err) + } - if err := pv.Verify(ctx, string(visa.JWT())); err != nil { - t.Errorf("VerifyPassportToken() failed: %v", err) - } + vv, err := NewVisaVerifier(ctx, f.Issuer0.URL, "", client) + if err != nil { + t.Fatalf("NewVisaVerifier() failed: %v", err) + } - if err := vv.Verify(ctx, string(visa.JWT()), ""); err != nil { - t.Errorf("VerifyPassportToken() failed: %v", err) + if err := pv.Verify(ctx, string(visa.JWT())); err != nil { + t.Errorf("VerifyPassportToken() failed: %v", err) + } + + if err := vv.Verify(ctx, string(visa.JWT()), ""); err != nil { + t.Errorf("VerifyPassportToken() failed: %v", err) + } + }) } }