diff --git a/server/.golangci.yml b/server/.golangci.yml index d37b229915830..7b70008418241 100644 --- a/server/.golangci.yml +++ b/server/.golangci.yml @@ -81,26 +81,19 @@ issues: channels/api4/file_test.go|\ channels/api4/group.go|\ channels/api4/group_local.go|\ - channels/api4/group_test.go|\ channels/api4/handlers_test.go|\ channels/api4/import_test.go|\ - channels/api4/integration_action.go|\ channels/api4/integration_action_test.go|\ channels/api4/ip_filtering.go|\ channels/api4/ip_filtering_test.go|\ - channels/api4/job.go|\ channels/api4/job_test.go|\ - channels/api4/ldap.go|\ channels/api4/license.go|\ channels/api4/license_local.go|\ channels/api4/license_test.go|\ - channels/api4/oauth.go|\ channels/api4/oauth_test.go|\ channels/api4/outgoing_oauth_connection_test.go|\ - channels/api4/permission.go|\ channels/api4/plugin.go|\ channels/api4/plugin_test.go|\ - channels/api4/post.go|\ channels/api4/post_test.go|\ channels/api4/preference_test.go|\ channels/api4/reaction.go|\ diff --git a/server/channels/api4/group_test.go b/server/channels/api4/group_test.go index 35eaeac27da7a..66f541aa94b89 100644 --- a/server/channels/api4/group_test.go +++ b/server/channels/api4/group_test.go @@ -60,7 +60,8 @@ func TestGetGroup(t *testing.T) { require.Error(t, err) CheckBadRequestStatus(t, response) - th.SystemAdminClient.Logout(context.Background()) + _, err = th.SystemAdminClient.Logout(context.Background()) + require.NoError(t, err) _, response, err = th.SystemAdminClient.GetGroup(context.Background(), group.Id, "") require.Error(t, err) CheckUnauthorizedStatus(t, response) @@ -168,7 +169,8 @@ func TestCreateGroup(t *testing.T) { require.Error(t, err) CheckBadRequestStatus(t, response) - th.SystemAdminClient.Logout(context.Background()) + _, err = th.SystemAdminClient.Logout(context.Background()) + require.NoError(t, err) _, response, err = th.SystemAdminClient.CreateGroup(context.Background(), g) require.Error(t, err) CheckUnauthorizedStatus(t, response) @@ -347,7 +349,8 @@ func TestPatchGroup(t *testing.T) { require.Error(t, err) CheckBadRequestStatus(t, response) - th.SystemAdminClient.Logout(context.Background()) + _, err = th.SystemAdminClient.Logout(context.Background()) + require.NoError(t, err) _, response, err = th.SystemAdminClient.PatchGroup(context.Background(), group.Id, gp) require.Error(t, err) CheckUnauthorizedStatus(t, response) @@ -440,7 +443,8 @@ func TestLinkGroupTeam(t *testing.T) { }) t.Run("System manager without invite_user are allowed to link", func(t *testing.T) { - th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password) + _, _, err = th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password) + require.NoError(t, err) groupSyncable, response, err = th.SystemManagerClient.LinkGroupSyncable(context.Background(), g.Id, th.BasicTeam.Id, model.GroupSyncableTypeTeam, patch) require.NoError(t, err) CheckCreatedStatus(t, response) @@ -564,7 +568,8 @@ func TestLinkGroupChannel(t *testing.T) { }) t.Run("System manager without invite_user are allowed to link", func(t *testing.T) { - th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password) + _, _, err = th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password) + require.NoError(t, err) groupSyncable, response, err = th.SystemManagerClient.LinkGroupSyncable(context.Background(), g.Id, th.BasicChannel.Id, model.GroupSyncableTypeChannel, patch) require.NoError(t, err) CheckCreatedStatus(t, response) @@ -684,7 +689,8 @@ func TestUnlinkGroupTeam(t *testing.T) { }) t.Run("System manager without invite_user are allowed to link", func(t *testing.T) { - th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password) + _, _, err = th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password) + require.NoError(t, err) response, err = th.SystemManagerClient.UnlinkGroupSyncable(context.Background(), g.Id, th.BasicTeam.Id, model.GroupSyncableTypeTeam) require.NoError(t, err) CheckOKStatus(t, response) @@ -802,7 +808,8 @@ func TestUnlinkGroupChannel(t *testing.T) { }) t.Run("System manager without invite_user are allowed to link", func(t *testing.T) { - th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password) + _, _, err = th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password) + require.NoError(t, err) response, err = th.SystemManagerClient.UnlinkGroupSyncable(context.Background(), g.Id, th.BasicChannel.Id, model.GroupSyncableTypeChannel) require.NoError(t, err) CheckOKStatus(t, response) @@ -881,7 +888,8 @@ func TestGetGroupTeam(t *testing.T) { require.Error(t, err) CheckBadRequestStatus(t, response) - th.SystemAdminClient.Logout(context.Background()) + _, err = th.SystemAdminClient.Logout(context.Background()) + require.NoError(t, err) _, response, err = th.SystemAdminClient.GetGroupSyncable(context.Background(), g.Id, th.BasicTeam.Id, model.GroupSyncableTypeTeam, "") require.Error(t, err) CheckUnauthorizedStatus(t, response) @@ -943,7 +951,8 @@ func TestGetGroupChannel(t *testing.T) { require.Error(t, err) CheckBadRequestStatus(t, response) - th.SystemAdminClient.Logout(context.Background()) + _, err = th.SystemAdminClient.Logout(context.Background()) + require.NoError(t, err) _, response, err = th.SystemAdminClient.GetGroupSyncable(context.Background(), g.Id, th.BasicChannel.Id, model.GroupSyncableTypeChannel, "") require.Error(t, err) CheckUnauthorizedStatus(t, response) @@ -996,7 +1005,8 @@ func TestGetGroupTeams(t *testing.T) { assert.Len(t, groupSyncables, 10) - th.SystemAdminClient.Logout(context.Background()) + _, err = th.SystemAdminClient.Logout(context.Background()) + require.NoError(t, err) _, response, err = th.SystemAdminClient.GetGroupSyncables(context.Background(), g.Id, model.GroupSyncableTypeTeam, "") require.Error(t, err) CheckUnauthorizedStatus(t, response) @@ -1048,7 +1058,8 @@ func TestGetGroupChannels(t *testing.T) { assert.Len(t, groupSyncables, 10) - th.SystemAdminClient.Logout(context.Background()) + _, err = th.SystemAdminClient.Logout(context.Background()) + require.NoError(t, err) _, response, err = th.SystemAdminClient.GetGroupSyncables(context.Background(), g.Id, model.GroupSyncableTypeChannel, "") require.Error(t, err) CheckUnauthorizedStatus(t, response) @@ -1120,7 +1131,8 @@ func TestPatchGroupTeam(t *testing.T) { require.Error(t, err) CheckBadRequestStatus(t, response) - th.SystemAdminClient.Logout(context.Background()) + _, err = th.SystemAdminClient.Logout(context.Background()) + require.NoError(t, err) _, response, err = th.SystemAdminClient.PatchGroupSyncable(context.Background(), g.Id, th.BasicTeam.Id, model.GroupSyncableTypeTeam, patch) require.Error(t, err) CheckUnauthorizedStatus(t, response) @@ -1203,7 +1215,8 @@ func TestPatchGroupChannel(t *testing.T) { require.Error(t, err) CheckBadRequestStatus(t, response) - th.SystemAdminClient.Logout(context.Background()) + _, err = th.SystemAdminClient.Logout(context.Background()) + require.NoError(t, err) _, response, err = th.SystemAdminClient.PatchGroupSyncable(context.Background(), g.Id, th.BasicChannel.Id, model.GroupSyncableTypeChannel, patch) require.Error(t, err) CheckUnauthorizedStatus(t, response) @@ -1380,8 +1393,10 @@ func TestGetGroupsAssociatedToChannelsByTeam(t *testing.T) { t.Run("should return forbidden when the user doesn't have the right permissions", func(t *testing.T) { require.Nil(t, th.App.RemoveUserFromTeam(th.Context, th.BasicTeam.Id, th.BasicUser.Id, th.SystemAdminUser.Id)) - defer th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, th.BasicUser.Id, th.SystemAdminUser.Id) - + defer func() { + _, _, appErr := th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, th.BasicUser.Id, th.SystemAdminUser.Id) + require.Nil(t, appErr) + }() groups, resp, err := th.Client.GetGroupsAssociatedToChannelsByTeam(context.Background(), th.BasicTeam.Id, opts) require.Error(t, err) CheckForbiddenStatus(t, resp) @@ -1426,7 +1441,8 @@ func TestGetGroupsByTeam(t *testing.T) { CheckBadRequestStatus(t, response) }) - th.App.Srv().RemoveLicense() + appErr := th.App.Srv().RemoveLicense() + require.Nil(t, appErr) th.TestForSystemAdminAndLocal(t, func(t *testing.T, client *model.Client4) { _, _, response, err := client.GetGroupsByTeam(context.Background(), th.BasicTeam.Id, opts) @@ -1487,7 +1503,10 @@ func TestGetGroupsByTeam(t *testing.T) { t.Run("user can't fetch groups if it's not part of the team", func(t *testing.T) { require.Nil(t, th.App.RemoveUserFromTeam(th.Context, th.BasicTeam.Id, th.BasicUser.Id, th.SystemAdminUser.Id)) - defer th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, th.BasicUser.Id, th.SystemAdminUser.Id) + defer func() { + _, _, appErr := th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, th.BasicUser.Id, th.SystemAdminUser.Id) + require.Nil(t, appErr) + }() groups, _, response, err := th.Client.GetGroupsByTeam(context.Background(), th.BasicTeam.Id, opts) require.Error(t, err) @@ -1580,7 +1599,8 @@ func TestGetGroups(t *testing.T) { assert.Equal(t, groups[0].Id, group.Id) // delete group, should still return - th.App.DeleteGroup(group.Id) + _, appErr = th.App.DeleteGroup(group.Id) + require.Nil(t, appErr) groups, _, err = th.Client.GetGroups(context.Background(), opts) assert.NoError(t, err) assert.Len(t, groups, 1) @@ -1721,14 +1741,18 @@ func TestGetGroupsByUserId(t *testing.T) { assert.ElementsMatch(t, []*model.Group{group1, group2}, groups) // test permissions - th.Client.Logout(context.Background()) - th.Client.Login(context.Background(), th.BasicUser.Email, th.BasicUser.Password) + _, err = th.Client.Logout(context.Background()) + require.NoError(t, err) + _, _, err = th.Client.Login(context.Background(), th.BasicUser.Email, th.BasicUser.Password) + require.NoError(t, err) _, response, err = th.Client.GetGroupsByUserId(context.Background(), user1.Id) require.Error(t, err) CheckForbiddenStatus(t, response) - th.Client.Logout(context.Background()) - th.Client.Login(context.Background(), user1.Email, user1.Password) + _, err = th.Client.Logout(context.Background()) + require.NoError(t, err) + _, _, err = th.Client.Login(context.Background(), user1.Email, user1.Password) + require.NoError(t, err) groups, _, err = th.Client.GetGroupsByUserId(context.Background(), user1.Id) require.NoError(t, err) assert.ElementsMatch(t, []*model.Group{group1, group2}, groups) @@ -1823,7 +1847,8 @@ func TestGetGroupStats(t *testing.T) { th.App.Srv().SetLicense(model.NewTestLicense("ldap")) t.Run("Requires manage system permission to access group stats", func(t *testing.T) { - th.Client.Login(context.Background(), th.BasicUser.Email, th.BasicUser.Password) + _, _, err := th.Client.Login(context.Background(), th.BasicUser.Email, th.BasicUser.Password) + require.NoError(t, err) _, response, err := th.Client.GetGroupStats(context.Background(), group.Id) require.Error(t, err) CheckForbiddenStatus(t, response) diff --git a/server/channels/api4/integration_action.go b/server/channels/api4/integration_action.go index eae516efabc8c..7076009b81c45 100644 --- a/server/channels/api4/integration_action.go +++ b/server/channels/api4/integration_action.go @@ -136,5 +136,7 @@ func submitDialog(c *Context, w http.ResponseWriter, r *http.Request) { b, _ := json.Marshal(resp) - w.Write(b) + if _, err := w.Write(b); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } } diff --git a/server/channels/api4/job.go b/server/channels/api4/job.go index e577c7943c824..a794b1874f46a 100644 --- a/server/channels/api4/job.go +++ b/server/channels/api4/job.go @@ -210,7 +210,9 @@ func getJobs(c *Context, w http.ResponseWriter, r *http.Request) { c.Err = model.NewAppError("getJobs", "api.marshal_error", nil, "", http.StatusInternalServerError).Wrap(err) return } - w.Write(js) + if _, err := w.Write(js); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } } func getJobsByType(c *Context, w http.ResponseWriter, r *http.Request) { @@ -241,7 +243,9 @@ func getJobsByType(c *Context, w http.ResponseWriter, r *http.Request) { return } - w.Write(js) + if _, err := w.Write(js); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } } func cancelJob(c *Context, w http.ResponseWriter, r *http.Request) { diff --git a/server/channels/api4/ldap.go b/server/channels/api4/ldap.go index 0669332328089..8d2d03b86014b 100644 --- a/server/channels/api4/ldap.go +++ b/server/channels/api4/ldap.go @@ -140,7 +140,9 @@ func getLdapGroups(c *Context, w http.ResponseWriter, r *http.Request) { return } - w.Write(b) + if _, err := w.Write(b); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } } func linkLdapGroup(c *Context, w http.ResponseWriter, r *http.Request) { @@ -240,7 +242,9 @@ func linkLdapGroup(c *Context, w http.ResponseWriter, r *http.Request) { auditRec.Success() w.WriteHeader(status) - w.Write(b) + if _, err := w.Write(b); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } } func unlinkLdapGroup(c *Context, w http.ResponseWriter, r *http.Request) { diff --git a/server/channels/api4/oauth.go b/server/channels/api4/oauth.go index a9f108100cb19..d65b73f23c57d 100644 --- a/server/channels/api4/oauth.go +++ b/server/channels/api4/oauth.go @@ -153,7 +153,9 @@ func getOAuthApps(c *Context, w http.ResponseWriter, r *http.Request) { return } - w.Write(js) + if _, err := w.Write(js); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } } func getOAuthApp(c *Context, w http.ResponseWriter, r *http.Request) { @@ -308,5 +310,7 @@ func getAuthorizedOAuthApps(c *Context, w http.ResponseWriter, r *http.Request) return } - w.Write(js) + if _, err := w.Write(js); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } } diff --git a/server/channels/api4/permission.go b/server/channels/api4/permission.go index 2489aee2f0280..2566dff461c85 100644 --- a/server/channels/api4/permission.go +++ b/server/channels/api4/permission.go @@ -8,6 +8,7 @@ import ( "net/http" "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/mlog" ) func (api *API) InitPermissions() { @@ -25,5 +26,7 @@ func appendAncillaryPermissionsPost(c *Context, w http.ResponseWriter, r *http.R c.SetJSONEncodingError(err) return } - w.Write(b) + if _, err := w.Write(b); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } } diff --git a/server/channels/api4/post.go b/server/channels/api4/post.go index d41f58e3f2142..559fec1c4df4a 100644 --- a/server/channels/api4/post.go +++ b/server/channels/api4/post.go @@ -1127,7 +1127,9 @@ func acknowledgePost(c *Context, w http.ResponseWriter, r *http.Request) { return } - w.Write(js) + if _, err := w.Write(js); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } } func unacknowledgePost(c *Context, w http.ResponseWriter, r *http.Request) { @@ -1284,7 +1286,9 @@ func getFileInfosForPost(c *Context, w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "max-age=2592000, private") w.Header().Set(model.HeaderEtagServer, model.GetEtagForFileInfos(infos)) - w.Write(js) + if _, err := w.Write(js); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } } func getPostInfo(c *Context, w http.ResponseWriter, r *http.Request) { @@ -1305,7 +1309,9 @@ func getPostInfo(c *Context, w http.ResponseWriter, r *http.Request) { return } - w.Write(js) + if _, err := w.Write(js); err != nil { + c.Logger.Warn("Error while writing response", mlog.Err(err)) + } } func hasPermittedWranglerRole(c *Context, user *model.User, channelMember *model.ChannelMember) bool { diff --git a/server/channels/api4/status_test.go b/server/channels/api4/status_test.go index 418a4b781962d..62a7c4ee649c7 100644 --- a/server/channels/api4/status_test.go +++ b/server/channels/api4/status_test.go @@ -5,13 +5,13 @@ package api4 import ( "context" + "strings" "testing" "time" + "github.com/mattermost/mattermost/server/public/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/mattermost/mattermost/server/public/model" ) func TestGetUserStatus(t *testing.T) { @@ -243,3 +243,175 @@ func TestUpdateUserStatus(t *testing.T) { CheckUnauthorizedStatus(t, resp) }) } + +func TestUpdateUserCustomStatus(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + client := th.Client + + t.Run("set custom status", func(t *testing.T) { + toUpdateCustomStatus := &model.CustomStatus{ + Emoji: "calendar", // Use a valid emoji name + Text: "My custom status", + } + _, resp, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus) + require.NoError(t, err) + CheckOKStatus(t, resp) + + user, _, err := client.GetUser(context.Background(), th.BasicUser.Id, "") + require.NoError(t, err) + customStatus := user.GetCustomStatus() + require.NotNil(t, customStatus) + assert.Equal(t, toUpdateCustomStatus.Emoji, customStatus.Emoji) + assert.Equal(t, toUpdateCustomStatus.Text, customStatus.Text) + }) + + t.Run("update custom status with duration", func(t *testing.T) { + expiresAt := time.Now().Add(1 * time.Hour) + toUpdateCustomStatus := &model.CustomStatus{ + Emoji: "palm_tree", // Use a valid emoji name + Text: "On vacation", + Duration: "date_and_time", + ExpiresAt: expiresAt, + } + _, resp, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus) + require.NoError(t, err) + CheckOKStatus(t, resp) + + user, _, err := client.GetUser(context.Background(), th.BasicUser.Id, "") + require.NoError(t, err) + customStatus := user.GetCustomStatus() + require.NotNil(t, customStatus) + assert.Equal(t, toUpdateCustomStatus.Emoji, customStatus.Emoji) + assert.Equal(t, toUpdateCustomStatus.Text, customStatus.Text) + assert.Equal(t, toUpdateCustomStatus.Duration, customStatus.Duration) + + require.NotNil(t, customStatus.ExpiresAt, "Expected ExpiresAt to be set") + // Check that ExpiresAt is within 5 seconds of the expected time + assert.WithinDuration(t, expiresAt, customStatus.ExpiresAt, 5*time.Second) + }) + + t.Run("attempt to set custom status when disabled", func(t *testing.T) { + th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.EnableCustomUserStatuses = false }) + defer th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.EnableCustomUserStatuses = true }) + + toUpdateCustomStatus := &model.CustomStatus{ + Emoji: "palm_tree", + Text: "My custom status", + } + _, resp, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus) + require.Error(t, err) + CheckNotImplementedStatus(t, resp) + + // Assert that the error ID is "api.custom_status.disabled" + if appErr, ok := err.(*model.AppError); ok { + assert.Equal(t, "api.custom_status.disabled", appErr.Id) + } else { + t.Errorf("expected *model.AppError, got %T", err) + } + }) + + t.Run("attempt to set custom status for another user", func(t *testing.T) { + toUpdateCustomStatus := &model.CustomStatus{ + Emoji: "palm_tree", + Text: "My custom status", + } + _, resp, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser2.Id, toUpdateCustomStatus) + require.Error(t, err) + CheckForbiddenStatus(t, resp) + }) + + t.Run("attempt to set custom status with invalid data", func(t *testing.T) { + toUpdateCustomStatus := &model.CustomStatus{ + Emoji: "invalid_emoji", + Text: strings.Repeat("a", 101), // Exceeds max length + Duration: "invalid_duration", + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + _, resp, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus) + require.Error(t, err) + CheckBadRequestStatus(t, resp) + }) + + t.Run("attempt to set custom status as non-authenticated user", func(t *testing.T) { + client.Logout(context.Background()) + toUpdateCustomStatus := &model.CustomStatus{ + Emoji: "palm_tree", + Text: "My custom status", + } + _, resp, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus) + require.Error(t, err) + CheckUnauthorizedStatus(t, resp) + }) +} + +func TestRemoveUserCustomStatus(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + client := th.Client + + t.Run("remove custom status successfully", func(t *testing.T) { + toUpdateCustomStatus := &model.CustomStatus{ + Emoji: "calendar", + Text: "My custom status", + } + _, _, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus) + require.NoError(t, err) + + resp, err := client.RemoveUserCustomStatus(context.Background(), th.BasicUser.Id) + require.NoError(t, err) + CheckOKStatus(t, resp) + + user, _, err := client.GetUser(context.Background(), th.BasicUser.Id, "") + require.NoError(t, err) + customStatus := user.GetCustomStatus() + assert.Nil(t, customStatus) + }) + + t.Run("attempt to remove custom status when disabled", func(t *testing.T) { + th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.EnableCustomUserStatuses = false }) + defer th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.EnableCustomUserStatuses = true }) + + resp, err := client.RemoveUserCustomStatus(context.Background(), th.BasicUser.Id) + require.Error(t, err) + CheckNotImplementedStatus(t, resp) + }) + + t.Run("attempt to remove custom status for another user", func(t *testing.T) { + resp, err := client.RemoveUserCustomStatus(context.Background(), th.BasicUser2.Id) + require.Error(t, err) + CheckForbiddenStatus(t, resp) + }) + + t.Run("attempt to remove custom status as non-authenticated user", func(t *testing.T) { + client.Logout(context.Background()) + resp, err := client.RemoveUserCustomStatus(context.Background(), th.BasicUser.Id) + require.Error(t, err) + CheckUnauthorizedStatus(t, resp) + }) + + t.Run("remove non-existent custom status", func(t *testing.T) { + th.LoginBasic() + resp, err := client.RemoveUserCustomStatus(context.Background(), th.BasicUser.Id) + require.NoError(t, err) + CheckOKStatus(t, resp) + }) + + t.Run("remove custom status with system admin", func(t *testing.T) { + toUpdateCustomStatus := &model.CustomStatus{ + Emoji: "calendar", + Text: "My custom status", + } + _, _, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus) + require.NoError(t, err) + + resp, err := th.SystemAdminClient.RemoveUserCustomStatus(context.Background(), th.BasicUser.Id) + require.NoError(t, err) + CheckOKStatus(t, resp) + + user, _, err := client.GetUser(context.Background(), th.BasicUser.Id, "") + require.NoError(t, err) + customStatus := user.GetCustomStatus() + assert.Nil(t, customStatus) + }) +} diff --git a/server/channels/app/channel.go b/server/channels/app/channel.go index 3a763e7633112..8664c33c2e2be 100644 --- a/server/channels/app/channel.go +++ b/server/channels/app/channel.go @@ -1298,7 +1298,7 @@ func (a *App) UpdateChannelMemberNotifyProps(c request.CTX, data map[string]stri a.invalidateCacheForChannelMembersNotifyProps(member.ChannelId) // Notify the clients that the member notify props changed - err = a.sendUpdateChannelMemberNotifyPropsEvent(member) + err = a.sendUpdateChannelMemberEvent(member) if err != nil { return nil, model.NewAppError("UpdateChannelMemberNotifyProps", "api.marshal_error", nil, "", http.StatusInternalServerError).Wrap(err) } @@ -1339,7 +1339,7 @@ func (a *App) PatchChannelMembersNotifyProps(c request.CTX, members []*model.Cha // Notify clients that their notify props have changed for _, member := range updated { - err := a.sendUpdateChannelMemberNotifyPropsEvent(member) + err := a.sendUpdateChannelMemberEvent(member) if err != nil { c.Logger().Warn("Failed to send WebSocket event for updated channel member notify props", mlog.Err(err)) } @@ -1348,7 +1348,7 @@ func (a *App) PatchChannelMembersNotifyProps(c request.CTX, members []*model.Cha return updated, nil } -func (a *App) sendUpdateChannelMemberNotifyPropsEvent(member *model.ChannelMember) error { +func (a *App) sendUpdateChannelMemberEvent(member *model.ChannelMember) error { evt := model.NewWebSocketEvent(model.WebsocketEventChannelMemberUpdated, "", "", member.UserId, nil, "") memberJSON, jsonErr := json.Marshal(member) if jsonErr != nil { @@ -1962,9 +1962,11 @@ func (a *App) GetAllChannels(c request.CTX, page, perPage int, opts model.Channe opts.ExcludeChannelNames = a.DefaultChannelNames(c) } storeOpts := store.ChannelSearchOpts{ - ExcludeChannelNames: opts.ExcludeChannelNames, NotAssociatedToGroup: opts.NotAssociatedToGroup, IncludeDeleted: opts.IncludeDeleted, + ExcludeChannelNames: opts.ExcludeChannelNames, + GroupConstrained: opts.GroupConstrained, + ExcludeGroupConstrained: opts.ExcludeGroupConstrained, ExcludePolicyConstrained: opts.ExcludePolicyConstrained, IncludePolicyID: opts.IncludePolicyID, } @@ -1981,9 +1983,13 @@ func (a *App) GetAllChannelsCount(c request.CTX, opts model.ChannelSearchOpts) ( opts.ExcludeChannelNames = a.DefaultChannelNames(c) } storeOpts := store.ChannelSearchOpts{ - ExcludeChannelNames: opts.ExcludeChannelNames, - NotAssociatedToGroup: opts.NotAssociatedToGroup, - IncludeDeleted: opts.IncludeDeleted, + NotAssociatedToGroup: opts.NotAssociatedToGroup, + IncludeDeleted: opts.IncludeDeleted, + ExcludeChannelNames: opts.ExcludeChannelNames, + GroupConstrained: opts.GroupConstrained, + ExcludeGroupConstrained: opts.ExcludeGroupConstrained, + ExcludePolicyConstrained: opts.ExcludePolicyConstrained, + IncludePolicyID: opts.IncludePolicyID, } count, err := a.Srv().Store().Channel().GetAllChannelsCount(storeOpts) if err != nil { diff --git a/server/channels/app/syncables.go b/server/channels/app/syncables.go index 45d293cd9830e..d4c3cdeaa5ba5 100644 --- a/server/channels/app/syncables.go +++ b/server/channels/app/syncables.go @@ -234,26 +234,47 @@ func (a *App) SyncSyncableRoles(rctx request.CTX, syncableID string, syncableTyp switch syncableType { case model.GroupSyncableTypeTeam: - nErr := a.Srv().Store().Team().UpdateMembersRole(syncableID, permittedAdmins) - if nErr != nil { - return model.NewAppError("App.SyncSyncableRoles", "app.update_error", nil, "", http.StatusInternalServerError).Wrap(nErr) + var updatedMembers []*model.TeamMember + updatedMembers, err = a.Srv().Store().Team().UpdateMembersRole(syncableID, permittedAdmins) + if err != nil { + return model.NewAppError("App.SyncSyncableRoles", "app.update_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + for _, member := range updatedMembers { + a.ClearSessionCacheForUser(member.UserId) + + if appErr := a.sendUpdatedTeamMemberEvent(member); appErr != nil { + rctx.Logger().Warn("Error sending channel member updated websocket event", mlog.Err(appErr)) + } } - return nil case model.GroupSyncableTypeChannel: - nErr := a.Srv().Store().Channel().UpdateMembersRole(syncableID, permittedAdmins) - if nErr != nil { - return model.NewAppError("App.SyncSyncableRoles", "app.update_error", nil, "", http.StatusInternalServerError).Wrap(nErr) + var updatedMembers []*model.ChannelMember + updatedMembers, err = a.Srv().Store().Channel().UpdateMembersRole(syncableID, permittedAdmins) + if err != nil { + return model.NewAppError("App.SyncSyncableRoles", "app.update_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + for _, member := range updatedMembers { + a.ClearSessionCacheForUser(member.UserId) + + if appErr := a.sendUpdateChannelMemberEvent(member); appErr != nil { + rctx.Logger().Warn("Error sending channel member updated websocket event", mlog.Err(appErr)) + } } - return nil default: return model.NewAppError("App.SyncSyncableRoles", "groups.unsupported_syncable_type", map[string]any{"Value": syncableType}, "", http.StatusInternalServerError) } + + return nil } // SyncRolesAndMembership updates the SchemeAdmin status and membership of all of the members of the given // syncable. func (a *App) SyncRolesAndMembership(rctx request.CTX, syncableID string, syncableType model.GroupSyncableType, includeRemovedMembers bool) { - a.SyncSyncableRoles(rctx, syncableID, syncableType) + appErr := a.SyncSyncableRoles(rctx, syncableID, syncableType) + if appErr != nil { + rctx.Logger().Warn("Error syncing syncable roles", mlog.Err(appErr)) + } lastJob, _ := a.Srv().Store().Job().GetNewestJobByStatusAndType(model.JobStatusSuccess, model.JobTypeLdapSync) var since int64 @@ -272,9 +293,6 @@ func (a *App) SyncRolesAndMembership(rctx request.CTX, syncableID string, syncab if err := a.deleteGroupConstrainedTeamMemberships(rctx, &syncableID); err != nil { rctx.Logger().Warn("Error deleting group constrained team memberships", mlog.Err(err)) } - if err := a.ClearTeamMembersCache(syncableID); err != nil { - rctx.Logger().Warn("Error clearing team members cache", mlog.Err(err)) - } case model.GroupSyncableTypeChannel: params.ScopedChannelID = &syncableID if err := a.createDefaultChannelMemberships(rctx, params); err != nil { @@ -283,8 +301,5 @@ func (a *App) SyncRolesAndMembership(rctx request.CTX, syncableID string, syncab if err := a.deleteGroupConstrainedChannelMemberships(rctx, &syncableID); err != nil { rctx.Logger().Warn("Error deleting group constrained team memberships", mlog.Err(err)) } - if err := a.ClearChannelMembersCache(rctx, syncableID); err != nil { - rctx.Logger().Warn("Error clearing channel members cache", mlog.Err(err)) - } } } diff --git a/server/channels/app/team.go b/server/channels/app/team.go index 415965fe6ec77..72ce428660f82 100644 --- a/server/channels/app/team.go +++ b/server/channels/app/team.go @@ -469,7 +469,7 @@ func (a *App) UpdateTeamMemberRoles(c request.CTX, teamID string, userID string, a.ClearSessionCacheForUser(userID) - if appErr := a.sendUpdatedMemberRoleEvent(userID, member); appErr != nil { + if appErr := a.sendUpdatedTeamMemberEvent(member); appErr != nil { return nil, appErr } @@ -512,15 +512,15 @@ func (a *App) UpdateTeamMemberSchemeRoles(c request.CTX, teamID string, userID s a.ClearSessionCacheForUser(userID) - if appErr := a.sendUpdatedMemberRoleEvent(userID, member); appErr != nil { + if appErr := a.sendUpdatedTeamMemberEvent(member); appErr != nil { return nil, appErr } return member, nil } -func (a *App) sendUpdatedMemberRoleEvent(userID string, member *model.TeamMember) *model.AppError { - message := model.NewWebSocketEvent(model.WebsocketEventMemberroleUpdated, "", "", userID, nil, "") +func (a *App) sendUpdatedTeamMemberEvent(member *model.TeamMember) *model.AppError { + message := model.NewWebSocketEvent(model.WebsocketEventMemberroleUpdated, "", "", member.UserId, nil, "") tmJSON, jsonErr := json.Marshal(member) if jsonErr != nil { return model.NewAppError("sendUpdatedMemberRoleEvent", "api.marshal_error", nil, "", http.StatusInternalServerError).Wrap(jsonErr) diff --git a/server/channels/app/user.go b/server/channels/app/user.go index 62f9d0004e0ad..c17155e70094b 100644 --- a/server/channels/app/user.go +++ b/server/channels/app/user.go @@ -2443,7 +2443,7 @@ func (a *App) PromoteGuestToUser(c request.CTX, user *model.User, requestorId st } for _, member := range teamMembers { - a.sendUpdatedMemberRoleEvent(user.Id, member) + a.sendUpdatedTeamMemberEvent(member) channelMembers, appErr := a.GetChannelMembersForUser(c, member.TeamId, user.Id) if appErr != nil { @@ -2487,7 +2487,7 @@ func (a *App) DemoteUserToGuest(c request.CTX, user *model.User) *model.AppError } for _, member := range teamMembers { - a.sendUpdatedMemberRoleEvent(user.Id, member) + a.sendUpdatedTeamMemberEvent(member) channelMembers, appErr := a.GetChannelMembersForUser(c, member.TeamId, user.Id) if appErr != nil { diff --git a/server/channels/store/opentracinglayer/opentracinglayer.go b/server/channels/store/opentracinglayer/opentracinglayer.go index ec5ff4791ee6f..e05d3ac0cbb86 100644 --- a/server/channels/store/opentracinglayer/opentracinglayer.go +++ b/server/channels/store/opentracinglayer/opentracinglayer.go @@ -2563,7 +2563,7 @@ func (s *OpenTracingLayerChannelStore) UpdateMemberNotifyProps(channelID string, return result, err } -func (s *OpenTracingLayerChannelStore) UpdateMembersRole(channelID string, userIDs []string) error { +func (s *OpenTracingLayerChannelStore) UpdateMembersRole(channelID string, userIDs []string) ([]*model.ChannelMember, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.UpdateMembersRole") s.Root.Store.SetContext(newCtx) @@ -2572,13 +2572,13 @@ func (s *OpenTracingLayerChannelStore) UpdateMembersRole(channelID string, userI }() defer span.Finish() - err := s.ChannelStore.UpdateMembersRole(channelID, userIDs) + result, err := s.ChannelStore.UpdateMembersRole(channelID, userIDs) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) } - return err + return result, err } func (s *OpenTracingLayerChannelStore) UpdateMultipleMembers(members []*model.ChannelMember) ([]*model.ChannelMember, error) { @@ -10554,7 +10554,7 @@ func (s *OpenTracingLayerTeamStore) UpdateMember(rctx request.CTX, member *model return result, err } -func (s *OpenTracingLayerTeamStore) UpdateMembersRole(teamID string, userIDs []string) error { +func (s *OpenTracingLayerTeamStore) UpdateMembersRole(teamID string, adminIDs []string) ([]*model.TeamMember, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "TeamStore.UpdateMembersRole") s.Root.Store.SetContext(newCtx) @@ -10563,13 +10563,13 @@ func (s *OpenTracingLayerTeamStore) UpdateMembersRole(teamID string, userIDs []s }() defer span.Finish() - err := s.TeamStore.UpdateMembersRole(teamID, userIDs) + result, err := s.TeamStore.UpdateMembersRole(teamID, adminIDs) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) } - return err + return result, err } func (s *OpenTracingLayerTeamStore) UpdateMultipleMembers(members []*model.TeamMember) ([]*model.TeamMember, error) { diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index e04b8c5a36e2b..3db4b2d39ef4a 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -2840,21 +2840,21 @@ func (s *RetryLayerChannelStore) UpdateMemberNotifyProps(channelID string, userI } -func (s *RetryLayerChannelStore) UpdateMembersRole(channelID string, userIDs []string) error { +func (s *RetryLayerChannelStore) UpdateMembersRole(channelID string, userIDs []string) ([]*model.ChannelMember, error) { tries := 0 for { - err := s.ChannelStore.UpdateMembersRole(channelID, userIDs) + result, err := s.ChannelStore.UpdateMembersRole(channelID, userIDs) if err == nil { - return nil + return result, nil } if !isRepeatableError(err) { - return err + return result, err } tries++ if tries >= 3 { err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") - return err + return result, err } timepkg.Sleep(100 * timepkg.Millisecond) } @@ -12071,21 +12071,21 @@ func (s *RetryLayerTeamStore) UpdateMember(rctx request.CTX, member *model.TeamM } -func (s *RetryLayerTeamStore) UpdateMembersRole(teamID string, userIDs []string) error { +func (s *RetryLayerTeamStore) UpdateMembersRole(teamID string, adminIDs []string) ([]*model.TeamMember, error) { tries := 0 for { - err := s.TeamStore.UpdateMembersRole(teamID, userIDs) + result, err := s.TeamStore.UpdateMembersRole(teamID, adminIDs) if err == nil { - return nil + return result, nil } if !isRepeatableError(err) { - return err + return result, err } tries++ if tries >= 3 { err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") - return err + return result, err } timepkg.Sleep(100 * timepkg.Millisecond) } diff --git a/server/channels/store/sqlstore/channel_store.go b/server/channels/store/sqlstore/channel_store.go index dce8d6ddcfc11..ec81005e3836a 100644 --- a/server/channels/store/sqlstore/channel_store.go +++ b/server/channels/store/sqlstore/channel_store.go @@ -7,6 +7,7 @@ import ( "context" "database/sql" "fmt" + "slices" "sort" "strconv" "strings" @@ -1190,6 +1191,15 @@ func (s SqlChannelStore) getAllChannelsQuery(opts store.ChannelSearchOpts, forCo query = query.Where("c.Id NOT IN (SELECT ChannelId FROM GroupChannels WHERE GroupChannels.GroupId = ? AND GroupChannels.DeleteAt = 0)", opts.NotAssociatedToGroup) } + if opts.GroupConstrained { + query = query.Where(sq.Eq{"c.GroupConstrained": true}) + } else if opts.ExcludeGroupConstrained { + query = query.Where(sq.Or{ + sq.NotEq{"c.GroupConstrained": true}, + sq.Eq{"c.GroupConstrained": nil}, + }) + } + if len(opts.ExcludeChannelNames) > 0 { query = query.Where(sq.NotEq{"c.Name": opts.ExcludeChannelNames}) } @@ -4161,27 +4171,76 @@ func (s SqlChannelStore) UserBelongsToChannels(userId string, channelIds []strin return c > 0, nil } -// TODO: parameterize userIDs -func (s SqlChannelStore) UpdateMembersRole(channelID string, userIDs []string) error { - sql := fmt.Sprintf(` - UPDATE - ChannelMembers - SET - SchemeAdmin = CASE WHEN UserId IN ('%s') THEN - TRUE - ELSE - FALSE - END - WHERE - ChannelId = ? - AND (SchemeGuest = false OR SchemeGuest IS NULL) - `, strings.Join(userIDs, "', '")) +// UpdateMembersRole updates all the members of channelID in the adminIDs string array to be admins and sets all other +// users as not being admin. +// It returns the list of userIDs whose roles got updated. +// +// TODO: parameterize adminIDs +func (s SqlChannelStore) UpdateMembersRole(channelID string, adminIDs []string) (_ []*model.ChannelMember, err error) { + transaction, err := s.GetMasterX().Beginx() + if err != nil { + return nil, err + } + defer finalizeTransactionX(transaction, &err) - if _, err := s.GetMasterX().Exec(sql, channelID); err != nil { - return errors.Wrap(err, "failed to update ChannelMembers") + // On MySQL it's not possible to update a table and select from it in the same query. + // A SELECT and a UPDATE query are needed. + // Once we only support PostgreSQL, this can be done in a single query using RETURNING. + query, args, err := s.getQueryBuilder(). + Select("*"). + From("ChannelMembers"). + Where(sq.Eq{"ChannelID": channelID}). + Where(sq.Or{sq.Eq{"SchemeGuest": false}, sq.Expr("SchemeGuest IS NULL")}). + Where( + sq.Or{ + // New admins + sq.And{ + sq.Eq{"SchemeAdmin": false}, + sq.Eq{"UserId": adminIDs}, + }, + // Demoted admins + sq.And{ + sq.Eq{"SchemeAdmin": true}, + sq.NotEq{"UserId": adminIDs}, + }, + }, + ).ToSql() + if err != nil { + return nil, errors.Wrap(err, "channel_tosql") } - return nil + var updatedMembers []*model.ChannelMember + if err = transaction.Select(&updatedMembers, query, args...); err != nil { + return nil, errors.Wrap(err, "failed to get list of updated users") + } + + // Update SchemeAdmin field as the data from the SQL is not updated yet + for _, member := range updatedMembers { + if slices.Contains(adminIDs, member.UserId) { + member.SchemeAdmin = true + } else { + member.SchemeAdmin = false + } + } + + query, args, err = s.getQueryBuilder(). + Update("ChannelMembers"). + Set("SchemeAdmin", sq.Case().When(sq.Eq{"UserId": adminIDs}, "true").Else("false")). + Where(sq.Eq{"ChannelId": channelID}). + Where(sq.Or{sq.Eq{"SchemeGuest": false}, sq.Expr("SchemeGuest IS NULL")}).ToSql() + if err != nil { + return nil, errors.Wrap(err, "team_tosql") + } + + if _, err = transaction.Exec(query, args...); err != nil { + return nil, errors.Wrap(err, "failed to update ChannelMembers") + } + + if err = transaction.Commit(); err != nil { + return nil, errors.Wrap(err, "commit_transaction") + } + + return updatedMembers, nil } func (s SqlChannelStore) GroupSyncedChannelCount() (int64, error) { diff --git a/server/channels/store/sqlstore/team_store.go b/server/channels/store/sqlstore/team_store.go index 92674434c1cc3..f564fd1e62730 100644 --- a/server/channels/store/sqlstore/team_store.go +++ b/server/channels/store/sqlstore/team_store.go @@ -6,6 +6,7 @@ package sqlstore import ( "database/sql" "fmt" + "slices" "strings" sq "github.com/mattermost/squirrel" @@ -1591,23 +1592,74 @@ func (s SqlTeamStore) UserBelongsToTeams(userId string, teamIds []string) (bool, return c > 0, nil } -// UpdateMembersRole updates all the members of teamID in the userIds string array to be admins and sets all other +// UpdateMembersRole updates all the members of teamID in the adminIDs string array to be admins and sets all other // users as not being admin. -func (s SqlTeamStore) UpdateMembersRole(teamID string, userIDs []string) error { +// It returns the list of userIDs whose roles got updated. +func (s SqlTeamStore) UpdateMembersRole(teamID string, adminIDs []string) (_ []*model.TeamMember, err error) { + transaction, err := s.GetMasterX().Beginx() + if err != nil { + return nil, err + } + defer finalizeTransactionX(transaction, &err) + + // On MySQL it's not possible to update a table and select from it in the same query. + // A SELECT and a UPDATE query are needed. + // Once we only support PostgreSQL, this can be done in a single query using RETURNING. query, args, err := s.getQueryBuilder(). + Select("*"). + From("TeamMembers"). + Where(sq.Eq{"TeamId": teamID, "DeleteAt": 0}). + Where(sq.Or{sq.Eq{"SchemeGuest": false}, sq.Expr("SchemeGuest IS NULL")}). + Where( + sq.Or{ + // New admins + sq.And{ + sq.Eq{"SchemeAdmin": false}, + sq.Eq{"UserId": adminIDs}, + }, + // Demoted admins + sq.And{ + sq.Eq{"SchemeAdmin": true}, + sq.NotEq{"UserId": adminIDs}, + }, + }, + ).ToSql() + if err != nil { + return nil, errors.Wrap(err, "team_tosql") + } + + var updatedMembers []*model.TeamMember + if err = transaction.Select(&updatedMembers, query, args...); err != nil { + return nil, errors.Wrap(err, "failed to get list of updated users") + } + + // Update SchemeAdmin field as the data from the SQL is not updated yet + for _, member := range updatedMembers { + if slices.Contains(adminIDs, member.UserId) { + member.SchemeAdmin = true + } else { + member.SchemeAdmin = false + } + } + + query, args, err = s.getQueryBuilder(). Update("TeamMembers"). - Set("SchemeAdmin", sq.Case().When(sq.Eq{"UserId": userIDs}, "true").Else("false")). + Set("SchemeAdmin", sq.Case().When(sq.Eq{"UserId": adminIDs}, "true").Else("false")). Where(sq.Eq{"TeamId": teamID, "DeleteAt": 0}). Where(sq.Or{sq.Eq{"SchemeGuest": false}, sq.Expr("SchemeGuest IS NULL")}).ToSql() if err != nil { - return errors.Wrap(err, "team_tosql") + return nil, errors.Wrap(err, "team_tosql") } - if _, err = s.GetMasterX().Exec(query, args...); err != nil { - return errors.Wrap(err, "failed to update TeamMembers") + if _, err = transaction.Exec(query, args...); err != nil { + return nil, errors.Wrap(err, "failed to update TeamMembers") } - return nil + if err = transaction.Commit(); err != nil { + return nil, errors.Wrap(err, "commit_transaction") + } + + return updatedMembers, nil } func applyTeamMemberViewRestrictionsFilter(query sq.SelectBuilder, restrictions *model.ViewUsersRestrictions) sq.SelectBuilder { diff --git a/server/channels/store/store.go b/server/channels/store/store.go index 443363efeb39a..0463e500fda18 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -168,7 +168,8 @@ type TeamStore interface { // UpdateMembersRole sets all of the given team members to admins and all of the other members of the team to // non-admin members. - UpdateMembersRole(teamID string, userIDs []string) error + // It returns the list of userIDs whose roles got updated. + UpdateMembersRole(teamID string, adminIDs []string) ([]*model.TeamMember, error) // GroupSyncedTeamCount returns the count of non-deleted group-constrained teams. GroupSyncedTeamCount() (int64, error) @@ -300,7 +301,8 @@ type ChannelStore interface { // UpdateMembersRole sets all of the given team members to admins and all of the other members of the team to // non-admin members. - UpdateMembersRole(channelID string, userIDs []string) error + // It returns the list of userIDs whose roles got updated. + UpdateMembersRole(channelID string, userIDs []string) ([]*model.ChannelMember, error) // GroupSyncedChannelCount returns the count of non-deleted group-constrained channels. GroupSyncedChannelCount() (int64, error) diff --git a/server/channels/store/storetest/channel_store.go b/server/channels/store/storetest/channel_store.go index a9877e0eaab6b..5a226902adf37 100644 --- a/server/channels/store/storetest/channel_store.go +++ b/server/channels/store/storetest/channel_store.go @@ -3854,6 +3854,7 @@ func testChannelStoreGetAllChannels(t *testing.T, rctx request.CTX, ss store.Sto c1.DisplayName = "Channel1" + model.NewId() c1.Name = NewTestId() c1.Type = model.ChannelTypeOpen + c1.GroupConstrained = model.NewPointer(true) _, nErr := ss.Channel().Save(rctx, &c1, -1) require.NoError(t, nErr) @@ -3939,6 +3940,19 @@ func testChannelStoreGetAllChannels(t *testing.T, rctx request.CTX, ss store.Sto list, nErr = ss.Channel().GetAllChannels(0, 10, store.ChannelSearchOpts{NotAssociatedToGroup: group.Id}) require.NoError(t, nErr) assert.Len(t, list, 1) + assert.Equal(t, c3.Id, list[0].Id) + + // GroupConstrained + list, nErr = ss.Channel().GetAllChannels(0, 10, store.ChannelSearchOpts{GroupConstrained: true}) + require.NoError(t, nErr) + require.Len(t, list, 1) + assert.Equal(t, c1.Id, list[0].Id) + + // ExcludeGroupConstrained + list, nErr = ss.Channel().GetAllChannels(0, 10, store.ChannelSearchOpts{ExcludeGroupConstrained: true}) + require.NoError(t, nErr) + require.Len(t, list, 1) + assert.Equal(t, c3.Id, list[0].Id) // Exclude channel names list, nErr = ss.Channel().GetAllChannels(0, 10, store.ChannelSearchOpts{ExcludeChannelNames: []string{c1.Name}}) diff --git a/server/channels/store/storetest/group_store.go b/server/channels/store/storetest/group_store.go index 8bb0d465f015d..3e8a2f671220d 100644 --- a/server/channels/store/storetest/group_store.go +++ b/server/channels/store/storetest/group_store.go @@ -81,7 +81,7 @@ func TestGroupStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("AdminRoleGroupsForSyncableMember_Team", func(t *testing.T) { groupTestAdminRoleGroupsForSyncableMemberTeam(t, rctx, ss) }) t.Run("PermittedSyncableAdmins_Team", func(t *testing.T) { groupTestPermittedSyncableAdminsTeam(t, rctx, ss) }) t.Run("PermittedSyncableAdmins_Channel", func(t *testing.T) { groupTestPermittedSyncableAdminsChannel(t, rctx, ss) }) - t.Run("UpdateMembersRole_Team", func(t *testing.T) { groupTestpUpdateMembersRoleTeam(t, rctx, ss) }) + t.Run("UpdateMembersRole_Team", func(t *testing.T) { groupTestUpdateMembersRoleTeam(t, rctx, ss) }) t.Run("UpdateMembersRole_Channel", func(t *testing.T) { groupTestpUpdateMembersRoleChannel(t, rctx, ss) }) t.Run("GroupCount", func(t *testing.T) { groupTestGroupCount(t, rctx, ss) }) @@ -4739,7 +4739,7 @@ func groupTestPermittedSyncableAdminsChannel(t *testing.T, rctx request.CTX, ss require.ElementsMatch(t, []string{user3.Id}, actualUserIDs) } -func groupTestpUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.Store) { +func groupTestUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.Store) { team := &model.Team{ DisplayName: "Name", Description: "Some description", @@ -4759,6 +4759,7 @@ func groupTestpUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.St } user1, err = ss.User().Save(rctx, user1) require.NoError(t, err) + t.Log("Created user1", user1.Id) user2 := &model.User{ Email: MakeEmail(), @@ -4766,6 +4767,7 @@ func groupTestpUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.St } user2, err = ss.User().Save(rctx, user2) require.NoError(t, err) + t.Log("Created user2", user2.Id) user3 := &model.User{ Email: MakeEmail(), @@ -4773,6 +4775,7 @@ func groupTestpUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.St } user3, err = ss.User().Save(rctx, user3) require.NoError(t, err) + t.Log("Created user3", user3.Id) user4 := &model.User{ Email: MakeEmail(), @@ -4780,6 +4783,7 @@ func groupTestpUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.St } user4, err = ss.User().Save(rctx, user4) require.NoError(t, err) + t.Log("Created user4", user4.Id) for _, user := range []*model.User{user1, user2, user3} { _, nErr := ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: team.Id, UserId: user.Id}, 9999) @@ -4790,53 +4794,73 @@ func groupTestpUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.St require.NoError(t, nErr) tests := []struct { - testName string - inUserIDs []string - targetSchemeAdminValue bool + testName string + newAdmins []string + expectedUpdatedUsers []string }{ { - "Given users are admins", + "Two new admins", + []string{user1.Id, user2.Id}, []string{user1.Id, user2.Id}, - true, }, { - "Given users are members", + "Demote one admin", + []string{user1.Id}, []string{user2.Id}, - false, }, { - "Non-given users are admins", - []string{user2.Id}, - false, + "Operation is idempotent", + []string{user1.Id}, + nil, }, { - "Non-given users are members", - []string{user2.Id}, - false, + "Promote a team member", + []string{user1.Id, user3.Id}, + []string{user3.Id}, + }, + { + "Guests never get promoted", + []string{user1.Id, user3.Id, user4.Id}, + nil, }, } for _, tt := range tests { t.Run(tt.testName, func(t *testing.T) { - err = ss.Team().UpdateMembersRole(team.Id, tt.inUserIDs) + var updatedMembers []*model.TeamMember + updatedMembers, err = ss.Team().UpdateMembersRole(team.Id, tt.newAdmins) require.NoError(t, err) - members, err := ss.Team().GetMembers(team.Id, 0, 100, nil) - require.NoError(t, err) - require.GreaterOrEqual(t, len(members), 4) // sanity check for team membership + var updatedUserIDs []string + for _, member := range updatedMembers { + assert.False(t, member.SchemeGuest, fmt.Sprintf("userID: %s", member.UserId)) - for _, member := range members { - if slices.Contains(tt.inUserIDs, member.UserId) { - require.True(t, member.SchemeAdmin) + if slices.Contains(tt.newAdmins, member.UserId) { + assert.True(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId)) } else { - require.False(t, member.SchemeAdmin) + assert.False(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId)) } + updatedUserIDs = append(updatedUserIDs, member.UserId) + } + assert.ElementsMatch(t, tt.expectedUpdatedUsers, updatedUserIDs) + + members, err := ss.Team().GetMembers(team.Id, 0, 100, nil) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(members), 4) // sanity check for team membership + + for _, member := range members { // Ensure guest account never changes. if member.UserId == user4.Id { - require.False(t, member.SchemeUser) - require.False(t, member.SchemeAdmin) - require.True(t, member.SchemeGuest) + assert.False(t, member.SchemeUser, fmt.Sprintf("userID: %s", member.UserId)) + assert.False(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId)) + assert.True(t, member.SchemeGuest, fmt.Sprintf("userID: %s", member.UserId)) + } else { + if slices.Contains(tt.newAdmins, member.UserId) { + assert.True(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId)) + } else { + assert.False(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId)) + } } } }) @@ -4859,6 +4883,7 @@ func groupTestpUpdateMembersRoleChannel(t *testing.T, rctx request.CTX, ss store } user1, err = ss.User().Save(rctx, user1) require.NoError(t, err) + t.Log("Created user1", user1.Id) user2 := &model.User{ Email: MakeEmail(), @@ -4866,6 +4891,7 @@ func groupTestpUpdateMembersRoleChannel(t *testing.T, rctx request.CTX, ss store } user2, err = ss.User().Save(rctx, user2) require.NoError(t, err) + t.Log("Created user2", user2.Id) user3 := &model.User{ Email: MakeEmail(), @@ -4873,6 +4899,7 @@ func groupTestpUpdateMembersRoleChannel(t *testing.T, rctx request.CTX, ss store } user3, err = ss.User().Save(rctx, user3) require.NoError(t, err) + t.Log("Created user3", user3.Id) user4 := &model.User{ Email: MakeEmail(), @@ -4880,6 +4907,7 @@ func groupTestpUpdateMembersRoleChannel(t *testing.T, rctx request.CTX, ss store } user4, err = ss.User().Save(rctx, user4) require.NoError(t, err) + t.Log("Created user4", user4.Id) for _, user := range []*model.User{user1, user2, user3} { _, err = ss.Channel().SaveMember(rctx, &model.ChannelMember{ @@ -4899,54 +4927,73 @@ func groupTestpUpdateMembersRoleChannel(t *testing.T, rctx request.CTX, ss store require.NoError(t, err) tests := []struct { - testName string - inUserIDs []string - targetSchemeAdminValue bool + testName string + newAdmins []string + expectedUpdatedUsers []string }{ { - "Given users are admins", + "Two new admins", + []string{user1.Id, user2.Id}, []string{user1.Id, user2.Id}, - true, }, { - "Given users are members", + "Demote one admin", + []string{user1.Id}, []string{user2.Id}, - false, }, { - "Non-given users are admins", - []string{user2.Id}, - false, + "Operation is idempotent", + []string{user1.Id}, + nil, }, { - "Non-given users are members", - []string{user2.Id}, - false, + "Promote a team member", + []string{user1.Id, user3.Id}, + []string{user3.Id}, + }, + { + "Guests never get promoted", + []string{user1.Id, user3.Id, user4.Id}, + nil, }, } for _, tt := range tests { t.Run(tt.testName, func(t *testing.T) { - err = ss.Channel().UpdateMembersRole(channel.Id, tt.inUserIDs) + var updatedMemmbers []*model.ChannelMember + updatedMemmbers, err = ss.Channel().UpdateMembersRole(channel.Id, tt.newAdmins) require.NoError(t, err) - members, err := ss.Channel().GetMembers(channel.Id, 0, 100) - require.NoError(t, err) + var updatedUserIDs []string + for _, member := range updatedMemmbers { + assert.False(t, member.SchemeGuest, fmt.Sprintf("userID: %s", member.UserId)) - require.GreaterOrEqual(t, len(members), 4) // sanity check for channel membership - - for _, member := range members { - if slices.Contains(tt.inUserIDs, member.UserId) { - require.True(t, member.SchemeAdmin) + if slices.Contains(tt.newAdmins, member.UserId) { + assert.True(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId)) } else { - require.False(t, member.SchemeAdmin) + assert.False(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId)) } + updatedUserIDs = append(updatedUserIDs, member.UserId) + } + assert.ElementsMatch(t, tt.expectedUpdatedUsers, updatedUserIDs) + + members, err := ss.Channel().GetMembers(channel.Id, 0, 100) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(members), 4) // sanity check for channel membership + + for _, member := range members { // Ensure guest account never changes. if member.UserId == user4.Id { - require.False(t, member.SchemeUser) - require.False(t, member.SchemeAdmin) - require.True(t, member.SchemeGuest) + assert.False(t, member.SchemeUser, fmt.Sprintf("userID: %s", member.UserId)) + assert.False(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId)) + assert.True(t, member.SchemeGuest, fmt.Sprintf("userID: %s", member.UserId)) + } else { + if slices.Contains(tt.newAdmins, member.UserId) { + assert.True(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId)) + } else { + assert.False(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId)) + } } } }) diff --git a/server/channels/store/storetest/mocks/ChannelStore.go b/server/channels/store/storetest/mocks/ChannelStore.go index 2661df49ce81a..489151e9cbe75 100644 --- a/server/channels/store/storetest/mocks/ChannelStore.go +++ b/server/channels/store/storetest/mocks/ChannelStore.go @@ -2915,21 +2915,33 @@ func (_m *ChannelStore) UpdateMemberNotifyProps(channelID string, userID string, } // UpdateMembersRole provides a mock function with given fields: channelID, userIDs -func (_m *ChannelStore) UpdateMembersRole(channelID string, userIDs []string) error { +func (_m *ChannelStore) UpdateMembersRole(channelID string, userIDs []string) ([]*model.ChannelMember, error) { ret := _m.Called(channelID, userIDs) if len(ret) == 0 { panic("no return value specified for UpdateMembersRole") } - var r0 error - if rf, ok := ret.Get(0).(func(string, []string) error); ok { + var r0 []*model.ChannelMember + var r1 error + if rf, ok := ret.Get(0).(func(string, []string) ([]*model.ChannelMember, error)); ok { + return rf(channelID, userIDs) + } + if rf, ok := ret.Get(0).(func(string, []string) []*model.ChannelMember); ok { r0 = rf(channelID, userIDs) } else { - r0 = ret.Error(0) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.ChannelMember) + } } - return r0 + if rf, ok := ret.Get(1).(func(string, []string) error); ok { + r1 = rf(channelID, userIDs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // UpdateMultipleMembers provides a mock function with given fields: members diff --git a/server/channels/store/storetest/mocks/TeamStore.go b/server/channels/store/storetest/mocks/TeamStore.go index 2c23ba3077fab..60493056c62a7 100644 --- a/server/channels/store/storetest/mocks/TeamStore.go +++ b/server/channels/store/storetest/mocks/TeamStore.go @@ -1336,22 +1336,34 @@ func (_m *TeamStore) UpdateMember(rctx request.CTX, member *model.TeamMember) (* return r0, r1 } -// UpdateMembersRole provides a mock function with given fields: teamID, userIDs -func (_m *TeamStore) UpdateMembersRole(teamID string, userIDs []string) error { - ret := _m.Called(teamID, userIDs) +// UpdateMembersRole provides a mock function with given fields: teamID, adminIDs +func (_m *TeamStore) UpdateMembersRole(teamID string, adminIDs []string) ([]*model.TeamMember, error) { + ret := _m.Called(teamID, adminIDs) if len(ret) == 0 { panic("no return value specified for UpdateMembersRole") } - var r0 error - if rf, ok := ret.Get(0).(func(string, []string) error); ok { - r0 = rf(teamID, userIDs) + var r0 []*model.TeamMember + var r1 error + if rf, ok := ret.Get(0).(func(string, []string) ([]*model.TeamMember, error)); ok { + return rf(teamID, adminIDs) + } + if rf, ok := ret.Get(0).(func(string, []string) []*model.TeamMember); ok { + r0 = rf(teamID, adminIDs) } else { - r0 = ret.Error(0) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.TeamMember) + } } - return r0 + if rf, ok := ret.Get(1).(func(string, []string) error); ok { + r1 = rf(teamID, adminIDs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // UpdateMultipleMembers provides a mock function with given fields: members diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index ef4bf76962019..61dde64b5bc67 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -2366,10 +2366,10 @@ func (s *TimerLayerChannelStore) UpdateMemberNotifyProps(channelID string, userI return result, err } -func (s *TimerLayerChannelStore) UpdateMembersRole(channelID string, userIDs []string) error { +func (s *TimerLayerChannelStore) UpdateMembersRole(channelID string, userIDs []string) ([]*model.ChannelMember, error) { start := time.Now() - err := s.ChannelStore.UpdateMembersRole(channelID, userIDs) + result, err := s.ChannelStore.UpdateMembersRole(channelID, userIDs) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { @@ -2379,7 +2379,7 @@ func (s *TimerLayerChannelStore) UpdateMembersRole(channelID string, userIDs []s } s.Root.Metrics.ObserveStoreMethodDuration("ChannelStore.UpdateMembersRole", success, elapsed) } - return err + return result, err } func (s *TimerLayerChannelStore) UpdateMultipleMembers(members []*model.ChannelMember) ([]*model.ChannelMember, error) { @@ -9495,10 +9495,10 @@ func (s *TimerLayerTeamStore) UpdateMember(rctx request.CTX, member *model.TeamM return result, err } -func (s *TimerLayerTeamStore) UpdateMembersRole(teamID string, userIDs []string) error { +func (s *TimerLayerTeamStore) UpdateMembersRole(teamID string, adminIDs []string) ([]*model.TeamMember, error) { start := time.Now() - err := s.TeamStore.UpdateMembersRole(teamID, userIDs) + result, err := s.TeamStore.UpdateMembersRole(teamID, adminIDs) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { @@ -9508,7 +9508,7 @@ func (s *TimerLayerTeamStore) UpdateMembersRole(teamID string, userIDs []string) } s.Root.Metrics.ObserveStoreMethodDuration("TeamStore.UpdateMembersRole", success, elapsed) } - return err + return result, err } func (s *TimerLayerTeamStore) UpdateMultipleMembers(members []*model.TeamMember) ([]*model.TeamMember, error) { diff --git a/webapp/channels/src/components/channel_select/index.ts b/webapp/channels/src/components/channel_select/index.ts index 08d71804dd207..faac736179ade 100644 --- a/webapp/channels/src/components/channel_select/index.ts +++ b/webapp/channels/src/components/channel_select/index.ts @@ -17,7 +17,8 @@ const getMyChannelsSorted = createSelector( getMyChannels, getCurrentUserLocale, (channels, locale) => { - return [...channels].sort(sortChannelsByTypeAndDisplayName.bind(null, locale)); + const activeChannels = channels.filter((channel) => channel.delete_at === 0); + return [...activeChannels].sort(sortChannelsByTypeAndDisplayName.bind(null, locale)); }, ); diff --git a/webapp/channels/src/components/root/root_redirect/index.ts b/webapp/channels/src/components/root/root_redirect/index.ts index d361a82c98ec2..1309717653c09 100644 --- a/webapp/channels/src/components/root/root_redirect/index.ts +++ b/webapp/channels/src/components/root/root_redirect/index.ts @@ -7,6 +7,7 @@ import type {Dispatch} from 'redux'; import {getFirstAdminSetupComplete} from 'mattermost-redux/actions/general'; import {getIsOnboardingFlowEnabled} from 'mattermost-redux/selectors/entities/preferences'; +import {getActiveTeamsList} from 'mattermost-redux/selectors/entities/teams'; import {getCurrentUserId, isCurrentUserSystemAdmin, isFirstAdmin} from 'mattermost-redux/selectors/entities/users'; import type {GlobalState} from 'types/store'; @@ -15,6 +16,7 @@ import RootRedirect from './root_redirect'; function mapStateToProps(state: GlobalState) { const onboardingFlowEnabled = getIsOnboardingFlowEnabled(state); + const teams = getActiveTeamsList(state); let isElegibleForFirstAdmingOnboarding = onboardingFlowEnabled; if (isElegibleForFirstAdmingOnboarding) { isElegibleForFirstAdmingOnboarding = isCurrentUserSystemAdmin(state); @@ -23,6 +25,7 @@ function mapStateToProps(state: GlobalState) { currentUserId: getCurrentUserId(state), isElegibleForFirstAdmingOnboarding, isFirstAdmin: isFirstAdmin(state), + areThereTeams: Boolean(teams.length), }; } diff --git a/webapp/channels/src/components/root/root_redirect/root_redirect.test.tsx b/webapp/channels/src/components/root/root_redirect/root_redirect.test.tsx new file mode 100644 index 0000000000000..8bb5278ec9647 --- /dev/null +++ b/webapp/channels/src/components/root/root_redirect/root_redirect.test.tsx @@ -0,0 +1,173 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +import {createMemoryHistory} from 'history'; +import React from 'react'; +import type {RouteComponentProps} from 'react-router-dom'; +import {Redirect} from 'react-router-dom'; + +import {getFirstAdminSetupComplete as getFirstAdminSetupCompleteAction} from 'mattermost-redux/actions/general'; + +import * as GlobalActions from 'actions/global_actions'; + +import {renderWithContext, waitFor} from 'tests/react_testing_utils'; + +import RootRedirect from './root_redirect'; +import type {Props} from './root_redirect'; + +jest.mock('actions/global_actions', () => ({ + redirectUserToDefaultTeam: jest.fn(), +})); + +jest.mock('mattermost-redux/actions/general', () => ({ + getFirstAdminSetupComplete: jest.fn(() => + Promise.resolve({ + data: true, + }), + ), +})); + +jest.mock('react-router-dom', () => { + const actual = jest.requireActual('react-router-dom'); + return { + ...actual, + Redirect: jest.fn(() => null), + }; +}); + +describe('components/RootRedirect', () => { + const baseProps: Props = { + currentUserId: '', + isElegibleForFirstAdmingOnboarding: false, + isFirstAdmin: false, + areThereTeams: false, + actions: { + getFirstAdminSetupComplete: getFirstAdminSetupCompleteAction as jest.Mock, + }, + }; + + const defaultProps = { + ...baseProps, + location: { + pathname: '/', + }, + } as Props & RouteComponentProps; + + afterEach(() => { + jest.clearAllMocks(); + }); + + test('should redirect to /login when currentUserId is empty', () => { + renderWithContext(); + + expect(Redirect).toHaveBeenCalledTimes(1); + expect(Redirect).toHaveBeenCalledWith( + expect.objectContaining({ + to: expect.objectContaining({ + pathname: '/login', + }), + }), + {}, + ); + }); + + test('should call GlobalActions.redirectUserToDefaultTeam when user is logged in and not eligible for first admin onboarding', () => { + const props = { + ...defaultProps, + currentUserId: 'test-user-id', + isElegibleForFirstAdmingOnboarding: false, + }; + + renderWithContext(); + + expect(GlobalActions.redirectUserToDefaultTeam).toHaveBeenCalledTimes(1); + }); + + test('should redirect to preparing-workspace when eligible for first admin onboarding and no teams created', async () => { + const history = createMemoryHistory({initialEntries: ['/']}); + const mockHistoryPush = jest.spyOn(history, 'push'); + + const props = { + currentUserId: 'test-user-id', + isElegibleForFirstAdmingOnboarding: true, + isFirstAdmin: true, + areThereTeams: false, + actions: { + getFirstAdminSetupComplete: jest.fn().mockResolvedValue({data: false}), + }, + }; + + renderWithContext(, {}, {history}); + + expect(props.actions.getFirstAdminSetupComplete).toHaveBeenCalledTimes(1); + + await waitFor(() => { + expect(mockHistoryPush).toHaveBeenCalledWith('/preparing-workspace'); + }); + }); + + test('should NOT redirect to preparing-workspace when there are teams created, even if system value for first admin onboarding complete is false', async () => { + const history = createMemoryHistory({initialEntries: ['/']}); + + const props = { + ...defaultProps, + currentUserId: 'test-user-id', + isElegibleForFirstAdmingOnboarding: true, + isFirstAdmin: true, + areThereTeams: true, + actions: { + getFirstAdminSetupComplete: jest.fn().mockResolvedValue({data: false}), + }, + }; + + renderWithContext(, {}, {history}); + + expect(props.actions.getFirstAdminSetupComplete).toHaveBeenCalledTimes(1); + + await waitFor(() => { + expect(GlobalActions.redirectUserToDefaultTeam).toHaveBeenCalledTimes(1); + }); + }); + + test('should redirect to default team when first admin setup is complete', async () => { + const props = { + ...defaultProps, + currentUserId: 'test-user-id', + isElegibleForFirstAdmingOnboarding: true, + isFirstAdmin: true, + areThereTeams: false, + actions: { + getFirstAdminSetupComplete: jest.fn().mockResolvedValue({data: true}), + }, + }; + + renderWithContext(); + + expect(props.actions.getFirstAdminSetupComplete).toHaveBeenCalledTimes(1); + + await waitFor(() => { + expect(GlobalActions.redirectUserToDefaultTeam).toHaveBeenCalledTimes(1); + }); + }); + + test('should redirect to default team when not first admin or teams exist', async () => { + const props = { + ...defaultProps, + currentUserId: 'test-user-id', + isElegibleForFirstAdmingOnboarding: true, + isFirstAdmin: false, + areThereTeams: true, + actions: { + getFirstAdminSetupComplete: jest.fn().mockResolvedValue({data: false}), + }, + }; + + renderWithContext(); + + expect(props.actions.getFirstAdminSetupComplete).toHaveBeenCalledTimes(1); + + await waitFor(() => { + expect(GlobalActions.redirectUserToDefaultTeam).toHaveBeenCalledTimes(1); + }); + }); +}); diff --git a/webapp/channels/src/components/root/root_redirect/root_redirect.tsx b/webapp/channels/src/components/root/root_redirect/root_redirect.tsx index c528a91c3aeb9..3a1db8624e92c 100644 --- a/webapp/channels/src/components/root/root_redirect/root_redirect.tsx +++ b/webapp/channels/src/components/root/root_redirect/root_redirect.tsx @@ -13,6 +13,7 @@ export type Props = { currentUserId: string; location?: Location; isFirstAdmin: boolean; + areThereTeams: boolean; actions: { getFirstAdminSetupComplete: () => Promise; }; @@ -26,7 +27,7 @@ export default function RootRedirect(props: Props) { if (props.isElegibleForFirstAdmingOnboarding) { props.actions.getFirstAdminSetupComplete().then((firstAdminCompletedSignup) => { // root.tsx ensures admin profiles are eventually loaded - if (firstAdminCompletedSignup.data === false && props.isFirstAdmin) { + if (firstAdminCompletedSignup.data === false && props.isFirstAdmin && !props.areThereTeams) { history.push('/preparing-workspace'); } else { GlobalActions.redirectUserToDefaultTeam();