diff --git a/embedx/config.schema.json b/embedx/config.schema.json index a836c21a7af5..65eacf9cfbd4 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -436,7 +436,8 @@ "dingtalk", "patreon", "linkedin", - "lark" + "lark", + "x" ], "examples": ["google"] }, diff --git a/go.mod b/go.mod index 22e8187de37e..41effb40c1c8 100644 --- a/go.mod +++ b/go.mod @@ -327,6 +327,7 @@ require ( require ( github.com/coreos/go-oidc/v3 v3.9.0 + github.com/dghubble/oauth1 v0.7.2 github.com/lestrrat-go/jwx/v2 v2.0.19 ) diff --git a/go.sum b/go.sum index d559b4b6ef35..9f68475d07fc 100644 --- a/go.sum +++ b/go.sum @@ -148,6 +148,8 @@ github.com/davidrjonas/semver-cli v0.0.0-20190116233701-ee19a9a0dda6/go.mod h1:+ github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= +github.com/dghubble/oauth1 v0.7.2 h1:pwcinOZy8z6XkNxvPmUDY52M7RDPxt0Xw1zgZ6Cl5JA= +github.com/dghubble/oauth1 v0.7.2/go.mod h1:9erQdIhqhOHG/7K9s/tgh9Ks/AfoyrO5mW/43Lu2+kE= github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= diff --git a/identity/credentials_oidc.go b/identity/credentials_oidc.go index 09b5d0aecad0..27462f927024 100644 --- a/identity/credentials_oidc.go +++ b/identity/credentials_oidc.go @@ -32,8 +32,36 @@ type CredentialsOIDCProvider struct { Organization string `json:"organization,omitempty"` } +// swagger:ignore +type CredentialsOIDCEncryptedTokens struct { + RefreshToken string `json:"refresh_token,omitempty"` + IDToken string `json:"id_token,omitempty"` + AccessToken string `json:"access_token,omitempty"` +} + +func (c *CredentialsOIDCEncryptedTokens) GetRefreshToken() string { + if c == nil { + return "" + } + return c.RefreshToken +} + +func (c *CredentialsOIDCEncryptedTokens) GetAccessToken() string { + if c == nil { + return "" + } + return c.AccessToken +} + +func (c *CredentialsOIDCEncryptedTokens) GetIDToken() string { + if c == nil { + return "" + } + return c.IDToken +} + // NewCredentialsOIDC creates a new OIDC credential. -func NewCredentialsOIDC(idToken, accessToken, refreshToken, provider, subject, organization string) (*Credentials, error) { +func NewCredentialsOIDC(tokens *CredentialsOIDCEncryptedTokens, provider, subject, organization string) (*Credentials, error) { if provider == "" { return nil, errors.New("received empty provider in oidc credentials") } @@ -48,9 +76,9 @@ func NewCredentialsOIDC(idToken, accessToken, refreshToken, provider, subject, o { Subject: subject, Provider: provider, - InitialIDToken: idToken, - InitialAccessToken: accessToken, - InitialRefreshToken: refreshToken, + InitialIDToken: tokens.GetIDToken(), + InitialAccessToken: tokens.GetAccessToken(), + InitialRefreshToken: tokens.GetRefreshToken(), Organization: organization, }}, }); err != nil { @@ -65,6 +93,14 @@ func NewCredentialsOIDC(idToken, accessToken, refreshToken, provider, subject, o }, nil } +func (c *CredentialsOIDCProvider) GetTokens() *CredentialsOIDCEncryptedTokens { + return &CredentialsOIDCEncryptedTokens{ + RefreshToken: c.InitialRefreshToken, + IDToken: c.InitialIDToken, + AccessToken: c.InitialAccessToken, + } +} + func OIDCUniqueID(provider, subject string) string { return fmt.Sprintf("%s:%s", provider, subject) } diff --git a/identity/credentials_oidc_test.go b/identity/credentials_oidc_test.go index dda20f73deb7..f52a077d107c 100644 --- a/identity/credentials_oidc_test.go +++ b/identity/credentials_oidc_test.go @@ -10,10 +10,10 @@ import ( ) func TestNewCredentialsOIDC(t *testing.T) { - _, err := NewCredentialsOIDC("", "", "", "", "not-empty", "") + _, err := NewCredentialsOIDC(new(CredentialsOIDCEncryptedTokens), "", "not-empty", "") require.Error(t, err) - _, err = NewCredentialsOIDC("", "", "", "not-empty", "", "") + _, err = NewCredentialsOIDC(new(CredentialsOIDCEncryptedTokens), "not-empty", "", "") require.Error(t, err) - _, err = NewCredentialsOIDC("", "", "", "not-empty", "not-empty", "") + _, err = NewCredentialsOIDC(new(CredentialsOIDCEncryptedTokens), "not-empty", "not-empty", "") require.NoError(t, err) } diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/selfservice/flow/login/hook_test.go b/selfservice/flow/login/hook_test.go index 46f957c1ca76..fe73f22d7eef 100644 --- a/selfservice/flow/login/hook_test.go +++ b/selfservice/flow/login/hook_test.go @@ -300,9 +300,7 @@ func TestLoginExecutor(t *testing.T) { require.NoError(t, reg.Persister().CreateIdentity(context.Background(), useIdentity)) credsOIDC, err := identity.NewCredentialsOIDC( - "id-token", - "access-token", - "refresh-token", + &identity.CredentialsOIDCEncryptedTokens{IDToken: "id-token", AccessToken: "access-token", RefreshToken: "refresh-token"}, "my-provider", email, "", diff --git a/selfservice/strategy/oidc/provider.go b/selfservice/strategy/oidc/provider.go index 4fe9a028c11e..cb22ebb6b847 100644 --- a/selfservice/strategy/oidc/provider.go +++ b/selfservice/strategy/oidc/provider.go @@ -5,8 +5,10 @@ package oidc import ( "context" + "net/http" "net/url" + "github.com/dghubble/oauth1" "github.com/pkg/errors" "github.com/ory/herodot" @@ -18,12 +20,24 @@ import ( type Provider interface { Config() *Configuration +} + +type OAuth2Provider interface { + Provider + AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption OAuth2(ctx context.Context) (*oauth2.Config, error) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) - AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption } -type TokenExchanger interface { +type OAuth1Provider interface { + Provider + OAuth1(ctx context.Context) *oauth1.Config + AuthURL(ctx context.Context, state string) (string, error) + Claims(ctx context.Context, token *oauth1.Token) (*Claims, error) + ExchangeToken(ctx context.Context, req *http.Request) (*oauth1.Token, error) +} + +type OAuth2TokenExchanger interface { Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) } @@ -87,11 +101,11 @@ func (c *Claims) Validate() error { // - `hd` (string): The `hd` parameter limits the login/registration process to a Google Organization, e.g. `mycollege.edu`. // - `prompt` (string): The `prompt` specifies whether the Authorization Server prompts the End-User for reauthentication and consent, e.g. `select_account`. // - `auth_type` (string): The `auth_type` parameter specifies the requested authentication features (as a comma-separated list), e.g. `reauthenticate`. -func UpstreamParameters(provider Provider, upstreamParameters map[string]string) []oauth2.AuthCodeOption { +func UpstreamParameters(upstreamParameters map[string]string) []oauth2.AuthCodeOption { // validation of upstream parameters are already handled in the `oidc/.schema/link.schema.json` and `oidc/.schema/settings.schema.json` file. // `upstreamParameters` will always only contain allowed parameters based on the configuration. - // we double check the parameters here to prevent any potential security issues. + // we double-check the parameters here to prevent any potential security issues. allowedParameters := map[string]struct{}{ "login_hint": {}, "hd": {}, diff --git a/selfservice/strategy/oidc/provider_config.go b/selfservice/strategy/oidc/provider_config.go index 512a5ecc6fa2..ea4d4364691e 100644 --- a/selfservice/strategy/oidc/provider_config.go +++ b/selfservice/strategy/oidc/provider_config.go @@ -160,6 +160,7 @@ var supportedProviders = map[string]func(config *Configuration, reg Dependencies "linkedin": NewProviderLinkedIn, "patreon": NewProviderPatreon, "lark": NewProviderLark, + "x": NewProviderX, } func (c ConfigurationCollection) Provider(id string, reg Dependencies) (Provider, error) { diff --git a/selfservice/strategy/oidc/provider_dingtalk.go b/selfservice/strategy/oidc/provider_dingtalk.go index 36469e6e7c23..12abffe85942 100644 --- a/selfservice/strategy/oidc/provider_dingtalk.go +++ b/selfservice/strategy/oidc/provider_dingtalk.go @@ -65,7 +65,7 @@ func (g *ProviderDingTalk) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx), nil } -func (g *ProviderDingTalk) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { +func (g *ProviderDingTalk) ExchangeOAuth2Token(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { conf, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) diff --git a/selfservice/strategy/oidc/provider_generic_test.go b/selfservice/strategy/oidc/provider_generic_test.go index 221ca7b1827a..7c90da7e3ec5 100644 --- a/selfservice/strategy/oidc/provider_generic_test.go +++ b/selfservice/strategy/oidc/provider_generic_test.go @@ -45,9 +45,9 @@ func makeAuthCodeURL(t *testing.T, r *login.Flow, reg *driver.RegistryDefault) s Mapper: "file://./stub/hydra.schema.json", RequestedClaims: makeOIDCClaims(), }, reg) - c, err := p.OAuth2(context.Background()) + c, err := p.(oidc.OAuth2Provider).OAuth2(context.Background()) require.NoError(t, err) - return c.AuthCodeURL("state", p.AuthCodeURLOptions(r)...) + return c.AuthCodeURL("state", p.(oidc.OAuth2Provider).AuthCodeURLOptions(r)...) } func TestProviderGenericOIDC_AddAuthCodeURLOptions(t *testing.T) { diff --git a/selfservice/strategy/oidc/provider_google_test.go b/selfservice/strategy/oidc/provider_google_test.go index d900b088d358..c1a1f65b9348 100644 --- a/selfservice/strategy/oidc/provider_google_test.go +++ b/selfservice/strategy/oidc/provider_google_test.go @@ -34,7 +34,7 @@ func TestProviderGoogle_Scope(t *testing.T) { Scope: []string{"email", "profile", "offline_access"}, }, reg) - c, _ := p.OAuth2(context.Background()) + c, _ := p.(oidc.OAuth2Provider).OAuth2(context.Background()) assert.NotContains(t, c.Scopes, "offline_access") } @@ -55,7 +55,7 @@ func TestProviderGoogle_AccessType(t *testing.T) { ID: x.NewUUID(), } - options := p.AuthCodeURLOptions(r) + options := p.(oidc.OAuth2Provider).AuthCodeURLOptions(r) assert.Contains(t, options, oauth2.AccessTypeOffline) } diff --git a/selfservice/strategy/oidc/provider_linkedin_test.go b/selfservice/strategy/oidc/provider_linkedin_test.go index 5eb7a629c7cc..d5b9df86d25a 100644 --- a/selfservice/strategy/oidc/provider_linkedin_test.go +++ b/selfservice/strategy/oidc/provider_linkedin_test.go @@ -115,7 +115,7 @@ func TestProviderLinkedin_Claims(t *testing.T) { linkedin := oidc.NewProviderLinkedIn(c, reg) const fakeLinkedinIDToken = "id_token_mock_" - actual, err := linkedin.Claims( + actual, err := linkedin.(oidc.OAuth2Provider).Claims( context.Background(), (&oauth2.Token{AccessToken: "foo", Expiry: time.Now().Add(time.Hour)}).WithExtra(map[string]interface{}{"id_token": fakeLinkedinIDToken}), url.Values{}, @@ -191,7 +191,7 @@ func TestProviderLinkedin_No_Picture(t *testing.T) { linkedin := oidc.NewProviderLinkedIn(c, reg) const fakeLinkedinIDToken = "id_token_mock_" - actual, err := linkedin.Claims( + actual, err := linkedin.(oidc.OAuth2Provider).Claims( context.Background(), (&oauth2.Token{AccessToken: "foo", Expiry: time.Now().Add(time.Hour)}).WithExtra(map[string]interface{}{"id_token": fakeLinkedinIDToken}), url.Values{}, diff --git a/selfservice/strategy/oidc/provider_private_net_test.go b/selfservice/strategy/oidc/provider_private_net_test.go index 878b8622e993..e656ee0462bb 100644 --- a/selfservice/strategy/oidc/provider_private_net_test.go +++ b/selfservice/strategy/oidc/provider_private_net_test.go @@ -84,10 +84,11 @@ func TestProviderPrivateIP(t *testing.T) { // VK uses a fixed token URL and does not use the issuer. // Yandex uses a fixed token URL and does not use the issuer. // NetID uses a fixed token URL and does not use the issuer. + // X uses a fixed token URL and userinfoRL and does not use the issuer value. } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { p := tc.p(tc.c) - _, err := p.Claims(context.Background(), (&oauth2.Token{RefreshToken: "foo", Expiry: time.Now().Add(-time.Hour)}).WithExtra(map[string]interface{}{ + _, err := p.(oidc.OAuth2Provider).Claims(context.Background(), (&oauth2.Token{RefreshToken: "foo", Expiry: time.Now().Add(-time.Hour)}).WithExtra(map[string]interface{}{ "id_token": tc.id, }), url.Values{}) require.Error(t, err) diff --git a/selfservice/strategy/oidc/provider_userinfo_test.go b/selfservice/strategy/oidc/provider_userinfo_test.go index 0b11f2dcae90..97456dfc404d 100644 --- a/selfservice/strategy/oidc/provider_userinfo_test.go +++ b/selfservice/strategy/oidc/provider_userinfo_test.go @@ -343,7 +343,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { return resp, err }) - _, err := tc.provider.Claims(ctx, token, url.Values{}) + _, err := tc.provider.(oidc.OAuth2Provider).Claims(ctx, token, url.Values{}) var he *herodot.DefaultError require.ErrorAs(t, err, &he) assert.Equal(t, "OpenID Connect provider returned a 455 status code but 200 is expected.", he.Reason()) @@ -359,7 +359,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { httpmock.RegisterResponder("GET", tc.userInfoEndpoint, tc.userInfoHandler) - claims, err := tc.provider.Claims(ctx, token, url.Values{}) + claims, err := tc.provider.(oidc.OAuth2Provider).Claims(ctx, token, url.Values{}) require.NoError(t, err) if tc.expectedClaims == nil { assert.Equal(t, expectedClaims, claims) diff --git a/selfservice/strategy/oidc/provider_x.go b/selfservice/strategy/oidc/provider_x.go new file mode 100644 index 000000000000..060ba58a6303 --- /dev/null +++ b/selfservice/strategy/oidc/provider_x.go @@ -0,0 +1,164 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oidc + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/dghubble/oauth1" + "github.com/dghubble/oauth1/twitter" + "github.com/pkg/errors" + + "github.com/ory/herodot" +) + +var _ Provider = (*ProviderX)(nil) +var _ OAuth1Provider = (*ProviderX)(nil) + +const xUserInfoBase = "https://api.twitter.com/1.1/account/verify_credentials.json" +const xUserInfoWithEmail = xUserInfoBase + "?include_email=true" + +type ProviderX struct { + config *Configuration + reg Dependencies +} + +func (p *ProviderX) Config() *Configuration { + return p.config +} + +func NewProviderX( + config *Configuration, + reg Dependencies) Provider { + return &ProviderX{ + config: config, + reg: reg, + } +} + +func (p *ProviderX) ExchangeToken(ctx context.Context, req *http.Request) (*oauth1.Token, error) { + requestToken, verifier, err := oauth1.ParseAuthorizationCallback(req) + if err != nil { + return nil, err + } + + accessToken, accessSecret, err := p.OAuth1(ctx).AccessToken(requestToken, "", verifier) + if err != nil { + return nil, err + } + + return oauth1.NewToken(accessToken, accessSecret), nil +} + +func (p *ProviderX) AuthURL(ctx context.Context, state string) (string, error) { + c := p.OAuth1(ctx) + + // We need to cheat so that callback validates on return + c.CallbackURL = c.CallbackURL + fmt.Sprintf("?state=%s&code=unused", state) + + requestToken, _, err := c.RequestToken() + if err != nil { + return "", errors.WithStack(herodot.ErrInternalServerError.WithReasonf(`Unable to sign in with X because the OAuth1 request token could not be initialized.`)) + } + + authzURL, err := c.AuthorizationURL(requestToken) + if err != nil { + return "", errors.WithStack(herodot.ErrInternalServerError.WithReasonf(`Unable to sign in with X because the OAuth1 authorization URL could not be parsed.`)) + } + + return authzURL.String(), nil +} + +func (p *ProviderX) CheckError(ctx context.Context, r *http.Request) error { + if r.URL.Query().Get("denied") == "" { + return nil + } + + return errors.WithStack(herodot.ErrBadRequest.WithReasonf(`Unable to sign in with X because the user denied the request.`)) +} + +func (p *ProviderX) OAuth1(ctx context.Context) *oauth1.Config { + return &oauth1.Config{ + ConsumerKey: p.config.ClientID, + ConsumerSecret: p.config.ClientSecret, + Endpoint: twitter.AuthorizeEndpoint, + CallbackURL: p.config.Redir(p.reg.Config().OIDCRedirectURIBase(ctx)), + } +} + +func (p *ProviderX) userInfoEndpoint() string { + for _, scope := range p.config.Scope { + if scope == "email" { + return xUserInfoWithEmail + } + } + + return xUserInfoBase +} + +func (p *ProviderX) Claims(ctx context.Context, token *oauth1.Token) (*Claims, error) { + ctx = context.WithValue(ctx, oauth1.HTTPClient, p.reg.HTTPClient(ctx).HTTPClient) + + c := p.OAuth1(ctx) + client := c.Client(ctx, token) + endpoint := p.userInfoEndpoint() + + resp, err := client.Get(endpoint) + if err != nil { + return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) + } + defer resp.Body.Close() + + if err := logUpstreamError(p.reg.Logger(), resp); err != nil { + return nil, err + } + + user := &xUser{} + if err := json.NewDecoder(resp.Body).Decode(user); err != nil { + return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) + } + + website := "" + if user.URL != nil { + website = *user.URL + } + + return &Claims{ + Issuer: endpoint, + Subject: user.IDStr, + Name: user.Name, + Picture: user.ProfileImageURLHTTPS, + Email: user.Email, + PreferredUsername: user.ScreenName, + Website: website, + }, nil +} + +type xUser struct { + ID int `json:"id"` + IDStr string `json:"id_str"` + Name string `json:"name"` + ScreenName string `json:"screen_name"` + Location string `json:"location"` + Description string `json:"description"` + URL *string `json:"url,omitempty"` + Protected bool `json:"protected"` + FollowersCount int `json:"followers_count"` + FriendsCount int `json:"friends_count"` + ListedCount int `json:"listed_count"` + CreatedAt string `json:"created_at"` + FavouritesCount int `json:"favourites_count"` + Verified bool `json:"verified"` + StatusesCount int `json:"statuses_count"` + DefaultProfile bool `json:"default_profile"` + DefaultProfileImage bool `json:"default_profile_image"` + ProfileImageURLHTTPS string `json:"profile_image_url_https"` + WithheldInCountries []string `json:"withheld_in_countries"` + Suspended bool `json:"suspended"` + NeedsPhoneVerification bool `json:"needs_phone_verification"` + Email string `json:"email"` +} diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index 4f5848e8a431..5289b2a9b96b 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -412,16 +412,39 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt return } - token, err := s.ExchangeCode(r.Context(), provider, code) - if err != nil { - s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) - return - } + var claims *Claims + var et *identity.CredentialsOIDCEncryptedTokens + switch p := provider.(type) { + case OAuth2Provider: + token, err := s.ExchangeCode(r.Context(), provider, code) + if err != nil { + s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + return + } - claims, err := provider.Claims(r.Context(), token, r.URL.Query()) - if err != nil { - s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) - return + et, err = s.encryptOAuth2Tokens(r.Context(), token) + if err != nil { + s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + return + } + + claims, err = p.Claims(r.Context(), token, r.URL.Query()) + if err != nil { + s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + return + } + case OAuth1Provider: + token, err := p.ExchangeToken(r.Context(), r) + if err != nil { + s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + return + } + + claims, err = p.Claims(r.Context(), token) + if err != nil { + s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + return + } } if err := claims.Validate(); err != nil { @@ -431,7 +454,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt switch a := req.(type) { case *login.Flow: - if ff, err := s.processLogin(w, r, a, token, claims, provider, cntnr); err != nil { + if ff, err := s.processLogin(w, r, a, et, claims, provider, cntnr); err != nil { if ff != nil { s.forwardError(w, r, ff, err) return @@ -441,7 +464,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt return case *registration.Flow: a.TransientPayload = cntnr.TransientPayload - if ff, err := s.processRegistration(w, r, a, token, claims, provider, cntnr, ""); err != nil { + if ff, err := s.processRegistration(w, r, a, et, claims, provider, cntnr, ""); err != nil { if ff != nil { s.forwardError(w, r, ff, err) return @@ -455,7 +478,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt s.forwardError(w, r, a, s.handleError(w, r, a, pid, nil, err)) return } - if err := s.linkProvider(w, r, &settings.UpdateContext{Session: sess, Flow: a}, token, claims, provider); err != nil { + if err := s.linkProvider(w, r, &settings.UpdateContext{Session: sess, Flow: a}, et, claims, provider); err != nil { s.forwardError(w, r, a, s.handleError(w, r, a, pid, nil, err)) return } @@ -473,18 +496,23 @@ func (s *Strategy) ExchangeCode(ctx context.Context, provider Provider, code str span.SetAttributes(attribute.String("provider_id", provider.Config().ID)) span.SetAttributes(attribute.String("provider_label", provider.Config().Label)) - te, ok := provider.(TokenExchanger) - if !ok { - te, err = provider.OAuth2(ctx) - if err != nil { - return nil, err + switch p := provider.(type) { + case OAuth2Provider: + te, ok := provider.(OAuth2TokenExchanger) + if !ok { + te, err = p.OAuth2(ctx) + if err != nil { + return nil, err + } } - } - client := s.d.HTTPClient(ctx) - ctx = context.WithValue(ctx, oauth2.HTTPClient, client.HTTPClient) - token, err = te.Exchange(ctx, code) - return token, err + client := s.d.HTTPClient(ctx) + ctx = context.WithValue(ctx, oauth2.HTTPClient, client.HTTPClient) + token, err = te.Exchange(ctx, code) + return token, err + default: + return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The chosen provider is not capable of exchanging an OAuth 2.0 code for an access token.")) + } } func (s *Strategy) populateMethod(r *http.Request, f flow.Flow, message func(provider string) *text.Message) error { @@ -706,7 +734,7 @@ func (s *Strategy) processIDToken(w http.ResponseWriter, r *http.Request, provid return claims, nil } -func (s *Strategy) linkCredentials(ctx context.Context, i *identity.Identity, idToken, accessToken, refreshToken, provider, subject, organization string) error { +func (s *Strategy) linkCredentials(ctx context.Context, i *identity.Identity, tokens *identity.CredentialsOIDCEncryptedTokens, provider, subject, organization string) error { if err := s.d.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, i, identity.ExpandCredentials); err != nil { return err } @@ -714,7 +742,7 @@ func (s *Strategy) linkCredentials(ctx context.Context, i *identity.Identity, id creds, err := i.ParseCredentials(s.ID(), &conf) if errors.Is(err, herodot.ErrNotFound) { var err error - if creds, err = identity.NewCredentialsOIDC(idToken, accessToken, refreshToken, provider, subject, organization); err != nil { + if creds, err = identity.NewCredentialsOIDC(tokens, provider, subject, organization); err != nil { return err } } else if err != nil { @@ -723,9 +751,9 @@ func (s *Strategy) linkCredentials(ctx context.Context, i *identity.Identity, id creds.Identifiers = append(creds.Identifiers, identity.OIDCUniqueID(provider, subject)) conf.Providers = append(conf.Providers, identity.CredentialsOIDCProvider{ Subject: subject, Provider: provider, - InitialAccessToken: accessToken, - InitialRefreshToken: refreshToken, - InitialIDToken: idToken, + InitialAccessToken: tokens.GetAccessToken(), + InitialRefreshToken: tokens.GetRefreshToken(), + InitialIDToken: tokens.GetIDToken(), Organization: organization, }) @@ -742,3 +770,45 @@ func (s *Strategy) linkCredentials(ctx context.Context, i *identity.Identity, id return nil } + +func getAuthRedirectURL(ctx context.Context, provider Provider, req ider, state *State, upstreamParameters map[string]string) (codeURL string, err error) { + switch p := provider.(type) { + case OAuth2Provider: + c, err := p.OAuth2(ctx) + if err != nil { + return "", err + } + + return c.AuthCodeURL(state.String(), append(UpstreamParameters(upstreamParameters), p.AuthCodeURLOptions(req)...)...), nil + case OAuth1Provider: + return p.AuthURL(ctx, state.String()) + default: + return "", errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The provider %s does not support the OAuth 2.0 or OAuth 1.0 protocol", provider.Config().Provider)) + } +} + +func (s *Strategy) encryptOAuth2Tokens(ctx context.Context, token *oauth2.Token) (et *identity.CredentialsOIDCEncryptedTokens, err error) { + et = new(identity.CredentialsOIDCEncryptedTokens) + if token == nil { + return et, nil + } + + if idToken, ok := token.Extra("id_token").(string); ok { + et.IDToken, err = s.d.Cipher(ctx).Encrypt(ctx, []byte(idToken)) + if err != nil { + return nil, err + } + } + + et.AccessToken, err = s.d.Cipher(ctx).Encrypt(ctx, []byte(token.AccessToken)) + if err != nil { + return nil, err + } + + et.RefreshToken, err = s.d.Cipher(ctx).Encrypt(ctx, []byte(token.RefreshToken)) + if err != nil { + return nil, err + } + + return et, nil +} diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index fe90eab8ad4e..09bb8ab27e6e 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -11,7 +11,6 @@ import ( "time" "github.com/julienschmidt/httprouter" - "golang.org/x/oauth2" "github.com/ory/kratos/session" @@ -107,7 +106,7 @@ type UpdateLoginFlowWithOidcMethod struct { TransientPayload json.RawMessage `json:"transient_payload,omitempty" form:"transient_payload"` } -func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *oauth2.Token, claims *Claims, provider Provider, container *AuthCodeContainer) (*registration.Flow, error) { +func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider, container *AuthCodeContainer) (*registration.Flow, error) { i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, identity.OIDCUniqueID(provider.Config().ID, claims.Subject)) if err != nil { if errors.Is(err, sqlcon.ErrNoRows) { @@ -227,11 +226,6 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, s.handleError(w, r, f, pid, nil, err) } - c, err := provider.OAuth2(ctx) - if err != nil { - return nil, s.handleError(w, r, f, pid, nil, err) - } - req, err := s.validateFlow(ctx, r, f.ID) if err != nil { return nil, s.handleError(w, r, f, pid, nil, err) @@ -282,7 +276,11 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, err } - codeURL := c.AuthCodeURL(state.String(), append(UpstreamParameters(provider, up), provider.AuthCodeURLOptions(req)...)...) + codeURL, err := getAuthRedirectURL(ctx, provider, f, state, up) + if err != nil { + return nil, s.handleError(w, r, f, pid, nil, err) + } + if x.IsJSONRequest(r) { s.d.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(codeURL)) } else { diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index f16ba63f07b4..124e8539f6ea 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -16,7 +16,6 @@ import ( "github.com/pkg/errors" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "golang.org/x/oauth2" "github.com/ory/herodot" "github.com/ory/kratos/continuity" @@ -185,11 +184,6 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat return s.handleError(w, r, f, pid, nil, err) } - c, err := provider.OAuth2(ctx) - if err != nil { - return s.handleError(w, r, f, pid, nil, err) - } - req, err := s.validateFlow(ctx, r, f.ID) if err != nil { return s.handleError(w, r, f, pid, nil, err) @@ -237,7 +231,10 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat return err } - codeURL := c.AuthCodeURL(state.String(), append(UpstreamParameters(provider, up), provider.AuthCodeURLOptions(req)...)...) + codeURL, err := getAuthRedirectURL(ctx, provider, f, state, up) + if err != nil { + return s.handleError(w, r, f, pid, nil, err) + } if x.IsJSONRequest(r) { s.d.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(codeURL)) } else { @@ -279,7 +276,7 @@ func (s *Strategy) registrationToLogin(w http.ResponseWriter, r *http.Request, r return lf, nil } -func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, rf *registration.Flow, token *oauth2.Token, claims *Claims, provider Provider, container *AuthCodeContainer, idToken string) (*login.Flow, error) { +func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, rf *registration.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider, container *AuthCodeContainer, idToken string) (*login.Flow, error) { if _, _, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, identity.OIDCUniqueID(provider.Config().ID, claims.Subject)); err == nil { // If the identity already exists, we should perform the login flow instead. @@ -334,27 +331,7 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r } } - var it string = idToken - var cat, crt string - if token != nil { - if idToken, ok := token.Extra("id_token").(string); ok { - if it, err = s.d.Cipher(r.Context()).Encrypt(r.Context(), []byte(idToken)); err != nil { - return nil, s.handleError(w, r, rf, provider.Config().ID, i.Traits, err) - } - } - - cat, err = s.d.Cipher(r.Context()).Encrypt(r.Context(), []byte(token.AccessToken)) - if err != nil { - return nil, s.handleError(w, r, rf, provider.Config().ID, i.Traits, err) - } - - crt, err = s.d.Cipher(r.Context()).Encrypt(r.Context(), []byte(token.RefreshToken)) - if err != nil { - return nil, s.handleError(w, r, rf, provider.Config().ID, i.Traits, err) - } - } - - creds, err := identity.NewCredentialsOIDC(it, cat, crt, provider.Config().ID, claims.Subject, provider.Config().OrganizationID) + creds, err := identity.NewCredentialsOIDC(token, provider.Config().ID, claims.Subject, provider.Config().OrganizationID) if err != nil { return nil, s.handleError(w, r, rf, provider.Config().ID, i.Traits, err) } diff --git a/selfservice/strategy/oidc/strategy_settings.go b/selfservice/strategy/oidc/strategy_settings.go index ed94972b1500..4fde3a457548 100644 --- a/selfservice/strategy/oidc/strategy_settings.go +++ b/selfservice/strategy/oidc/strategy_settings.go @@ -15,8 +15,6 @@ import ( "github.com/tidwall/sjson" - "golang.org/x/oauth2" - "github.com/ory/kratos/continuity" "github.com/ory/kratos/selfservice/strategy" "github.com/ory/x/decoderx" @@ -367,20 +365,15 @@ func (s *Strategy) initLinkProvider(w http.ResponseWriter, r *http.Request, ctxU return s.handleSettingsError(w, r, ctxUpdate, p, err) } - c, err := provider.OAuth2(r.Context()) - if err != nil { - return s.handleSettingsError(w, r, ctxUpdate, p, err) - } - req, err := s.validateFlow(r.Context(), r, ctxUpdate.Flow.ID) if err != nil { return s.handleSettingsError(w, r, ctxUpdate, p, err) } - state := generateState(ctxUpdate.Flow.ID.String()).String() + state := generateState(ctxUpdate.Flow.ID.String()) if err := s.d.ContinuityManager().Pause(r.Context(), w, r, sessionName, continuity.WithPayload(&AuthCodeContainer{ - State: state, + State: state.String(), FlowID: ctxUpdate.Flow.ID.String(), Traits: p.Traits, }), @@ -393,7 +386,11 @@ func (s *Strategy) initLinkProvider(w http.ResponseWriter, r *http.Request, ctxU return err } - codeURL := c.AuthCodeURL(state, append(UpstreamParameters(provider, up), provider.AuthCodeURLOptions(req)...)...) + codeURL, err := getAuthRedirectURL(r.Context(), provider, req, state, up) + if err != nil { + return s.handleSettingsError(w, r, ctxUpdate, p, err) + } + if x.IsJSONRequest(r) { s.d.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(codeURL)) } else { @@ -403,7 +400,7 @@ func (s *Strategy) initLinkProvider(w http.ResponseWriter, r *http.Request, ctxU return errors.WithStack(flow.ErrCompletedByStrategy) } -func (s *Strategy) linkProvider(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, token *oauth2.Token, claims *Claims, provider Provider) error { +func (s *Strategy) linkProvider(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider) error { p := &updateSettingsFlowWithOidcMethod{ Link: provider.Config().ID, FlowID: ctxUpdate.Flow.ID.String(), } @@ -416,24 +413,7 @@ func (s *Strategy) linkProvider(w http.ResponseWriter, r *http.Request, ctxUpdat return s.handleSettingsError(w, r, ctxUpdate, p, err) } - var it string - if idToken, ok := token.Extra("id_token").(string); ok { - if it, err = s.d.Cipher(r.Context()).Encrypt(r.Context(), []byte(idToken)); err != nil { - return s.handleSettingsError(w, r, ctxUpdate, p, err) - } - } - - cat, err := s.d.Cipher(r.Context()).Encrypt(r.Context(), []byte(token.AccessToken)) - if err != nil { - return s.handleSettingsError(w, r, ctxUpdate, p, err) - } - - crt, err := s.d.Cipher(r.Context()).Encrypt(r.Context(), []byte(token.RefreshToken)) - if err != nil { - return s.handleSettingsError(w, r, ctxUpdate, p, err) - } - - if err := s.linkCredentials(r.Context(), i, it, cat, crt, provider.Config().ID, claims.Subject, provider.Config().OrganizationID); err != nil { + if err := s.linkCredentials(r.Context(), i, token, provider.Config().ID, claims.Subject, provider.Config().OrganizationID); err != nil { return s.handleSettingsError(w, r, ctxUpdate, p, err) } @@ -546,9 +526,8 @@ func (s *Strategy) Link(ctx context.Context, i *identity.Identity, credentialsCo if err := s.linkCredentials( ctx, i, - credentialsOIDCProvider.InitialIDToken, - credentialsOIDCProvider.InitialAccessToken, - credentialsOIDCProvider.InitialRefreshToken, + // The tokens in this credential are coming from the existing identity. Hence, the values are already encrypted. + credentialsOIDCProvider.GetTokens(), credentialsOIDCProvider.Provider, credentialsOIDCProvider.Subject, credentialsOIDCProvider.Organization, diff --git a/test/e2e/cypress/downloads/downloads.html b/test/e2e/cypress/downloads/downloads.html deleted file mode 100644 index 4e6e883a5f39..000000000000 Binary files a/test/e2e/cypress/downloads/downloads.html and /dev/null differ diff --git a/test/e2e/playwright/playwright.env b/test/e2e/playwright/playwright.env new file mode 100644 index 000000000000..8289c433e8a9 --- /dev/null +++ b/test/e2e/playwright/playwright.env @@ -0,0 +1,23 @@ +KRATOS_BROWSER_URL=http://localhost:4433/ +KRATOS_UI_URL=http://localhost:4456/ +KRATOS_UI_REACT_URL=http://localhost:4458/ +KRATOS_UI_REACT_NATIVE_URL=http://localhost:19006/ +KRATOS_ADMIN_URL=http://localhost:4434/ +KRATOS_PUBLIC_URL=http://localhost:4433/ +TEST_DATABASE_MYSQL=mysql://root:secret@(localhost:3444)/mysql?parseTime=true&multiStatements=true +TEST_DATABASE_COCKROACHDB=cockroach://root@localhost:3446/defaultdb?sslmode=disable +TEST_DATABASE_MEMORY=memory +TEST_DATABASE_POSTGRESQL=postgres://postgres:secret@localhost:3445/postgres?sslmode=disable +TEST_DATABASE_SQLITE=sqlite:////var/folders/7v/lfnm0tm91wb6_ngvr0xk5l0h0000gn/T/ci-XXXXXXXXXX.5Bt7rvQC9L/db.sqlite?_fk=true +OIDC_GITHUB_CLIENT_SECRET=cwV-UvqowlDGrxWvU41DvxbsUy +OIDC_HYDRA_CLIENT_ID=d43ab7f7-90a2-4809-876f-8ed6f897d423 +OIDC_GITHUB_CLIENT_ID=05ef7c9a-814c-4cb9-899b-2010328f670a +OIDC_GOOGLE_CLIENT_ID=dc3915c1-0d0a-4376-b4d7-18c45e339882 +CYPRESS_OIDC_DUMMY_CLIENT_ID=75da61f8-e91b-43cb-8b3e-c9f4925194f9 +OIDC_GOOGLE_CLIENT_SECRET=.yT5lLhbOIHiStbOr5xOa-0X1k +OIDC_HYDRA_CLIENT_SECRET=Y9PZnoQ~1lcEo_WCc4.XjeytjG +CYPRESS_OIDC_DUMMY_CLIENT_SECRET=F-H.e6z5_BM9GVIW5y~MqoAw6g +CYPRESS_OIDC_DUMMY_CLIENT_ID=75da61f8-e91b-43cb-8b3e-c9f4925194f9 +CYPRESS_OIDC_DUMMY_CLIENT_SECRET=F-H.e6z5_BM9GVIW5y~MqoAw6g +LOG_LEAK_SENSITIVE_VALUES=true +DEV_DISABLE_API_FLOW_ENFORCEMENT=true