diff --git a/internal/api/api.go b/internal/api/api.go index 6e04b38e4a..ea963e194c 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -153,7 +153,11 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati r.Get("/", api.UserGet) r.With(sharedLimiter).Put("/", api.UserUpdate) - r.Delete("/identities/{identity_id}", api.DeleteIdentity) + r.Route("/identities", func(r *router) { + r.Use(api.requireManualLinkingEnabled) + r.Get("/authorize", api.LinkIdentity) + r.Delete("/{identity_id}", api.DeleteIdentity) + }) }) r.With(api.requireAuthentication).Route("/factors", func(r *router) { diff --git a/internal/api/context.go b/internal/api/context.go index 107377b50e..2cd73309fc 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -21,6 +21,7 @@ const ( signatureKey = contextKey("signature") externalProviderTypeKey = contextKey("external_provider_type") userKey = contextKey("user") + targetUserKey = contextKey("target_user") factorKey = contextKey("factor") sessionKey = contextKey("session") externalReferrerKey = contextKey("external_referrer") @@ -76,6 +77,11 @@ func withUser(ctx context.Context, u *models.User) context.Context { return context.WithValue(ctx, userKey, u) } +// withTargetUser adds the target user for linking to the context. +func withTargetUser(ctx context.Context, u *models.User) context.Context { + return context.WithValue(ctx, targetUserKey, u) +} + // with Factor adds the factor id to the context. func withFactor(ctx context.Context, f *models.Factor) context.Context { return context.WithValue(ctx, factorKey, f) @@ -93,6 +99,18 @@ func getUser(ctx context.Context) *models.User { return obj.(*models.User) } +// getTargetUser reads the user from the context. +func getTargetUser(ctx context.Context) *models.User { + if ctx == nil { + return nil + } + obj := ctx.Value(targetUserKey) + if obj == nil { + return nil + } + return obj.(*models.User) +} + // getFactor reads the factor id from the context func getFactor(ctx context.Context) *models.Factor { obj := ctx.Value(factorKey) diff --git a/internal/api/external.go b/internal/api/external.go index 30ef2b4616..9f8da6a5c9 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -25,14 +25,25 @@ import ( // ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow type ExternalProviderClaims struct { AuthMicroserviceClaims - Provider string `json:"provider"` - InviteToken string `json:"invite_token,omitempty"` - Referrer string `json:"referrer,omitempty"` - FlowStateID string `json:"flow_state_id"` + Provider string `json:"provider"` + InviteToken string `json:"invite_token,omitempty"` + Referrer string `json:"referrer,omitempty"` + FlowStateID string `json:"flow_state_id"` + LinkingTargetID string `json:"linking_target_id,omitempty"` } -// ExternalProviderRedirect redirects the request to the corresponding oauth provider +// ExternalProviderRedirect redirects the request to the oauth provider func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error { + rurl, err := a.GetExternalProviderRedirectURL(w, r, nil) + if err != nil { + return err + } + http.Redirect(w, r, rurl, http.StatusFound) + return nil +} + +// GetExternalProviderRedirectURL returns the URL to start the oauth flow with the corresponding oauth provider +func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Request, linkingTargetUser *models.User) (string, error) { ctx := r.Context() db := a.db.WithContext(ctx) config := a.config @@ -45,7 +56,7 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e p, err := a.Provider(ctx, providerType, scopes) if err != nil { - return badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return "", badRequestError("Unsupported provider: %+v", err).WithInternalError(err) } inviteToken := query.Get("invite_token") @@ -53,9 +64,9 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e _, userErr := models.FindUserByConfirmationToken(db, inviteToken) if userErr != nil { if models.IsNotFoundError(userErr) { - return notFoundError(userErr.Error()) + return "", notFoundError(userErr.Error()) } - return internalServerError("Database error finding user").WithInternalError(userErr) + return "", internalServerError("Database error finding user").WithInternalError(userErr) } } @@ -63,7 +74,7 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e log := observability.GetLogEntry(r) log.WithField("provider", providerType).Info("Redirecting to external provider") if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil { - return err + return "", err } flowType := getFlowFromChallenge(codeChallenge) @@ -71,19 +82,19 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e if flowType == models.PKCEFlow { codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod) if err != nil { - return err + return "", err } flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth) if err != nil { - return err + return "", err } if err := a.db.Create(flowState); err != nil { - return err + return "", err } flowStateID = flowState.ID.String() } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, ExternalProviderClaims{ + claims := ExternalProviderClaims{ AuthMicroserviceClaims: AuthMicroserviceClaims{ StandardClaims: jwt.StandardClaims{ ExpiresAt: time.Now().Add(5 * time.Minute).Unix(), @@ -95,10 +106,17 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e InviteToken: inviteToken, Referrer: redirectURL, FlowStateID: flowStateID, - }) + } + + if linkingTargetUser != nil { + // this means that the user is performing manual linking + claims.LinkingTargetID = linkingTargetUser.ID.String() + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte(config.JWT.Secret)) if err != nil { - return internalServerError("Error creating state").WithInternalError(err) + return "", internalServerError("Error creating state").WithInternalError(err) } authUrlParams := make([]oauth2.AuthCodeOption, 0) @@ -115,20 +133,15 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e } } - var authURL string + authURL := p.AuthCodeURL(tokenString, authUrlParams...) switch externalProvider := p.(type) { case *provider.TwitterProvider: - authURL = externalProvider.AuthCodeURL(tokenString, authUrlParams...) - err := storage.StoreInSession(providerType, externalProvider.Marshal(), r, w) - if err != nil { - return internalServerError("Error storing request token in session").WithInternalError(err) + if err := storage.StoreInSession(providerType, externalProvider.Marshal(), r, w); err != nil { + return "", internalServerError("Error storing request token in session").WithInternalError(err) } - default: - authURL = p.AuthCodeURL(tokenString, authUrlParams...) } - http.Redirect(w, r, authURL, http.StatusFound) - return nil + return authURL, nil } // ExternalProviderCallback handles the callback endpoint in the external oauth provider flow @@ -142,37 +155,41 @@ func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) e return nil } +func (a *API) handleOAuthCallback(w http.ResponseWriter, r *http.Request) (*OAuthProviderData, error) { + ctx := r.Context() + providerType := getExternalProviderType(ctx) + + var oAuthResponseData *OAuthProviderData + var err error + switch providerType { + case "twitter": + // future OAuth1.0 providers will use this method + oAuthResponseData, err = a.oAuth1Callback(ctx, r, providerType) + default: + oAuthResponseData, err = a.oAuthCallback(ctx, r, providerType) + } + if err != nil { + return nil, err + } + return oAuthResponseData, nil +} + func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() db := a.db.WithContext(ctx) config := a.config - providerType := getExternalProviderType(ctx) - var userData *provider.UserProvidedData - var providerAccessToken string - var providerRefreshToken string var grantParams models.GrantParams - var err error - grantParams.FillGrantParams(r) - if providerType == "twitter" { - // future OAuth1.0 providers will use this method - oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType) - if err != nil { - return err - } - userData = oAuthResponseData.userData - providerAccessToken = oAuthResponseData.token - } else { - oAuthResponseData, err := a.oAuthCallback(ctx, r, providerType) - if err != nil { - return err - } - userData = oAuthResponseData.userData - providerAccessToken = oAuthResponseData.token - providerRefreshToken = oAuthResponseData.refreshToken + providerType := getExternalProviderType(ctx) + data, err := a.handleOAuthCallback(w, r) + if err != nil { + return err } + userData := data.userData + providerAccessToken := data.token + providerRefreshToken := data.refreshToken var flowState *models.FlowState // if there's a non-empty FlowStateID we perform PKCE Flow @@ -187,8 +204,11 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re var token *AccessTokenResponse err = db.Transaction(func(tx *storage.Connection) error { var terr error - inviteToken := getInviteToken(ctx) - if inviteToken != "" { + if targetUser := getTargetUser(ctx); targetUser != nil { + if user, terr = a.linkIdentityToUser(ctx, tx, userData, providerType); terr != nil { + return terr + } + } else if inviteToken := getInviteToken(ctx); inviteToken != "" { if user, terr = a.processInvite(r, ctx, tx, userData, inviteToken, providerType); terr != nil { return terr } @@ -479,6 +499,20 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont if claims.FlowStateID != "" { ctx = withFlowStateID(ctx, claims.FlowStateID) } + if claims.LinkingTargetID != "" { + linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID) + if err != nil { + return nil, badRequestError("invalid target user id") + } + u, err := models.FindUserByID(a.db, linkingTargetUserID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, notFoundError("Linking target user not found") + } + return nil, internalServerError("Database error loading user").WithInternalError(err) + } + ctx = withTargetUser(ctx, u) + } ctx = withExternalProviderType(ctx, claims.Provider) return withSignature(ctx, state), nil } diff --git a/internal/api/identity.go b/internal/api/identity.go index 769dec752b..b3b69faef8 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -1,10 +1,13 @@ package api import ( + "context" "net/http" + "github.com/fatih/structs" "github.com/go-chi/chi" "github.com/gofrs/uuid" + "github.com/supabase/gotrue/internal/api/provider" "github.com/supabase/gotrue/internal/models" "github.com/supabase/gotrue/internal/storage" ) @@ -29,7 +32,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { user := getUser(ctx) if len(user.Identities) <= 1 { - return badRequestError("Cannot unlink identity from user. User must have at least 1 identity after unlinking") + return badRequestError("User must have at least 1 identity after unlinking") } var identityToBeDeleted *models.Identity for i := range user.Identities { @@ -65,3 +68,43 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { return sendJSON(w, http.StatusOK, map[string]interface{}{}) } + +func (a *API) LinkIdentity(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) + rurl, err := a.GetExternalProviderRedirectURL(w, r, user) + if err != nil { + return err + } + skipHTTPRedirect := r.URL.Query().Get("skip_http_redirect") == "true" + if skipHTTPRedirect { + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "url": rurl, + }) + } + http.Redirect(w, r, rurl, http.StatusFound) + return nil +} + +func (a *API) linkIdentityToUser(ctx context.Context, tx *storage.Connection, userData *provider.UserProvidedData, providerType string) (*models.User, error) { + targetUser := getTargetUser(ctx) + identity, terr := models.FindIdentityByIdAndProvider(tx, userData.Metadata.Subject, providerType) + if terr != nil { + if !models.IsNotFoundError(terr) { + return nil, internalServerError("Database error finding identity for linking").WithInternalError(terr) + } + } + if identity != nil { + if identity.UserID == targetUser.ID { + return nil, badRequestError("Identity is already linked") + } + return nil, badRequestError("Identity is already linked to another user") + } + if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil { + return nil, terr + } + if terr := targetUser.UpdateAppMetaDataProviders(tx); terr != nil { + return nil, terr + } + return targetUser, nil +} diff --git a/internal/api/identity_test.go b/internal/api/identity_test.go new file mode 100644 index 0000000000..a14b770a70 --- /dev/null +++ b/internal/api/identity_test.go @@ -0,0 +1,77 @@ +package api + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/gotrue/internal/api/provider" + "github.com/supabase/gotrue/internal/conf" + "github.com/supabase/gotrue/internal/models" +) + +type IdentityTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestIdentity(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + ts := &IdentityTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + suite.Run(t, ts) +} + +func (ts *IdentityTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + + // Create identity + i, err := models.NewIdentity(u, "email", map[string]interface{}{ + "sub": u.ID.String(), + "email": u.GetEmail(), + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(i)) +} + +func (ts *IdentityTestSuite) TestLinkIdentityToUser() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + ctx := withTargetUser(context.Background(), u) + + // link a valid identity + testValidUserData := &provider.UserProvidedData{ + Metadata: &provider.Claims{ + Subject: "test_subject", + }, + } + u, err = ts.API.linkIdentityToUser(ctx, ts.API.db, testValidUserData, "test") + require.NoError(ts.T(), err) + + // load associated identities for the user + ts.API.db.Load(u, "Identities") + require.Len(ts.T(), u.Identities, 2) + require.Equal(ts.T(), u.AppMetaData["provider"], "email") + require.Equal(ts.T(), u.AppMetaData["providers"], []string{"email", "test"}) + + // link an already existing identity + testExistingUserData := &provider.UserProvidedData{ + Metadata: &provider.Claims{ + Subject: u.ID.String(), + }, + } + u, err = ts.API.linkIdentityToUser(ctx, ts.API.db, testExistingUserData, "email") + require.ErrorIs(ts.T(), err, badRequestError("Identity is already linked")) + require.Nil(ts.T(), u) +} diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 0ccb51a5a7..d529b89e03 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -233,6 +233,14 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont return ctx, nil } +func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + if !a.config.Security.ManualLinkingEnabled { + return nil, notFoundError("Manual linking is disabled") + } + return ctx, nil +} + func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 4025bd53a7..0e76455f1b 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -386,6 +386,7 @@ type SecurityConfiguration struct { RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"` RefreshTokenReuseInterval int `json:"refresh_token_reuse_interval" split_words:"true"` UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"` + ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"` } func (c *SecurityConfiguration) Validate() error {