diff --git a/server/.golangci.yml b/server/.golangci.yml
index d37b229915830..7b70008418241 100644
--- a/server/.golangci.yml
+++ b/server/.golangci.yml
@@ -81,26 +81,19 @@ issues:
channels/api4/file_test.go|\
channels/api4/group.go|\
channels/api4/group_local.go|\
- channels/api4/group_test.go|\
channels/api4/handlers_test.go|\
channels/api4/import_test.go|\
- channels/api4/integration_action.go|\
channels/api4/integration_action_test.go|\
channels/api4/ip_filtering.go|\
channels/api4/ip_filtering_test.go|\
- channels/api4/job.go|\
channels/api4/job_test.go|\
- channels/api4/ldap.go|\
channels/api4/license.go|\
channels/api4/license_local.go|\
channels/api4/license_test.go|\
- channels/api4/oauth.go|\
channels/api4/oauth_test.go|\
channels/api4/outgoing_oauth_connection_test.go|\
- channels/api4/permission.go|\
channels/api4/plugin.go|\
channels/api4/plugin_test.go|\
- channels/api4/post.go|\
channels/api4/post_test.go|\
channels/api4/preference_test.go|\
channels/api4/reaction.go|\
diff --git a/server/channels/api4/group_test.go b/server/channels/api4/group_test.go
index 35eaeac27da7a..66f541aa94b89 100644
--- a/server/channels/api4/group_test.go
+++ b/server/channels/api4/group_test.go
@@ -60,7 +60,8 @@ func TestGetGroup(t *testing.T) {
require.Error(t, err)
CheckBadRequestStatus(t, response)
- th.SystemAdminClient.Logout(context.Background())
+ _, err = th.SystemAdminClient.Logout(context.Background())
+ require.NoError(t, err)
_, response, err = th.SystemAdminClient.GetGroup(context.Background(), group.Id, "")
require.Error(t, err)
CheckUnauthorizedStatus(t, response)
@@ -168,7 +169,8 @@ func TestCreateGroup(t *testing.T) {
require.Error(t, err)
CheckBadRequestStatus(t, response)
- th.SystemAdminClient.Logout(context.Background())
+ _, err = th.SystemAdminClient.Logout(context.Background())
+ require.NoError(t, err)
_, response, err = th.SystemAdminClient.CreateGroup(context.Background(), g)
require.Error(t, err)
CheckUnauthorizedStatus(t, response)
@@ -347,7 +349,8 @@ func TestPatchGroup(t *testing.T) {
require.Error(t, err)
CheckBadRequestStatus(t, response)
- th.SystemAdminClient.Logout(context.Background())
+ _, err = th.SystemAdminClient.Logout(context.Background())
+ require.NoError(t, err)
_, response, err = th.SystemAdminClient.PatchGroup(context.Background(), group.Id, gp)
require.Error(t, err)
CheckUnauthorizedStatus(t, response)
@@ -440,7 +443,8 @@ func TestLinkGroupTeam(t *testing.T) {
})
t.Run("System manager without invite_user are allowed to link", func(t *testing.T) {
- th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password)
+ _, _, err = th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password)
+ require.NoError(t, err)
groupSyncable, response, err = th.SystemManagerClient.LinkGroupSyncable(context.Background(), g.Id, th.BasicTeam.Id, model.GroupSyncableTypeTeam, patch)
require.NoError(t, err)
CheckCreatedStatus(t, response)
@@ -564,7 +568,8 @@ func TestLinkGroupChannel(t *testing.T) {
})
t.Run("System manager without invite_user are allowed to link", func(t *testing.T) {
- th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password)
+ _, _, err = th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password)
+ require.NoError(t, err)
groupSyncable, response, err = th.SystemManagerClient.LinkGroupSyncable(context.Background(), g.Id, th.BasicChannel.Id, model.GroupSyncableTypeChannel, patch)
require.NoError(t, err)
CheckCreatedStatus(t, response)
@@ -684,7 +689,8 @@ func TestUnlinkGroupTeam(t *testing.T) {
})
t.Run("System manager without invite_user are allowed to link", func(t *testing.T) {
- th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password)
+ _, _, err = th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password)
+ require.NoError(t, err)
response, err = th.SystemManagerClient.UnlinkGroupSyncable(context.Background(), g.Id, th.BasicTeam.Id, model.GroupSyncableTypeTeam)
require.NoError(t, err)
CheckOKStatus(t, response)
@@ -802,7 +808,8 @@ func TestUnlinkGroupChannel(t *testing.T) {
})
t.Run("System manager without invite_user are allowed to link", func(t *testing.T) {
- th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password)
+ _, _, err = th.SystemManagerClient.Login(context.Background(), th.SystemManagerUser.Email, th.SystemManagerUser.Password)
+ require.NoError(t, err)
response, err = th.SystemManagerClient.UnlinkGroupSyncable(context.Background(), g.Id, th.BasicChannel.Id, model.GroupSyncableTypeChannel)
require.NoError(t, err)
CheckOKStatus(t, response)
@@ -881,7 +888,8 @@ func TestGetGroupTeam(t *testing.T) {
require.Error(t, err)
CheckBadRequestStatus(t, response)
- th.SystemAdminClient.Logout(context.Background())
+ _, err = th.SystemAdminClient.Logout(context.Background())
+ require.NoError(t, err)
_, response, err = th.SystemAdminClient.GetGroupSyncable(context.Background(), g.Id, th.BasicTeam.Id, model.GroupSyncableTypeTeam, "")
require.Error(t, err)
CheckUnauthorizedStatus(t, response)
@@ -943,7 +951,8 @@ func TestGetGroupChannel(t *testing.T) {
require.Error(t, err)
CheckBadRequestStatus(t, response)
- th.SystemAdminClient.Logout(context.Background())
+ _, err = th.SystemAdminClient.Logout(context.Background())
+ require.NoError(t, err)
_, response, err = th.SystemAdminClient.GetGroupSyncable(context.Background(), g.Id, th.BasicChannel.Id, model.GroupSyncableTypeChannel, "")
require.Error(t, err)
CheckUnauthorizedStatus(t, response)
@@ -996,7 +1005,8 @@ func TestGetGroupTeams(t *testing.T) {
assert.Len(t, groupSyncables, 10)
- th.SystemAdminClient.Logout(context.Background())
+ _, err = th.SystemAdminClient.Logout(context.Background())
+ require.NoError(t, err)
_, response, err = th.SystemAdminClient.GetGroupSyncables(context.Background(), g.Id, model.GroupSyncableTypeTeam, "")
require.Error(t, err)
CheckUnauthorizedStatus(t, response)
@@ -1048,7 +1058,8 @@ func TestGetGroupChannels(t *testing.T) {
assert.Len(t, groupSyncables, 10)
- th.SystemAdminClient.Logout(context.Background())
+ _, err = th.SystemAdminClient.Logout(context.Background())
+ require.NoError(t, err)
_, response, err = th.SystemAdminClient.GetGroupSyncables(context.Background(), g.Id, model.GroupSyncableTypeChannel, "")
require.Error(t, err)
CheckUnauthorizedStatus(t, response)
@@ -1120,7 +1131,8 @@ func TestPatchGroupTeam(t *testing.T) {
require.Error(t, err)
CheckBadRequestStatus(t, response)
- th.SystemAdminClient.Logout(context.Background())
+ _, err = th.SystemAdminClient.Logout(context.Background())
+ require.NoError(t, err)
_, response, err = th.SystemAdminClient.PatchGroupSyncable(context.Background(), g.Id, th.BasicTeam.Id, model.GroupSyncableTypeTeam, patch)
require.Error(t, err)
CheckUnauthorizedStatus(t, response)
@@ -1203,7 +1215,8 @@ func TestPatchGroupChannel(t *testing.T) {
require.Error(t, err)
CheckBadRequestStatus(t, response)
- th.SystemAdminClient.Logout(context.Background())
+ _, err = th.SystemAdminClient.Logout(context.Background())
+ require.NoError(t, err)
_, response, err = th.SystemAdminClient.PatchGroupSyncable(context.Background(), g.Id, th.BasicChannel.Id, model.GroupSyncableTypeChannel, patch)
require.Error(t, err)
CheckUnauthorizedStatus(t, response)
@@ -1380,8 +1393,10 @@ func TestGetGroupsAssociatedToChannelsByTeam(t *testing.T) {
t.Run("should return forbidden when the user doesn't have the right permissions", func(t *testing.T) {
require.Nil(t, th.App.RemoveUserFromTeam(th.Context, th.BasicTeam.Id, th.BasicUser.Id, th.SystemAdminUser.Id))
- defer th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, th.BasicUser.Id, th.SystemAdminUser.Id)
-
+ defer func() {
+ _, _, appErr := th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, th.BasicUser.Id, th.SystemAdminUser.Id)
+ require.Nil(t, appErr)
+ }()
groups, resp, err := th.Client.GetGroupsAssociatedToChannelsByTeam(context.Background(), th.BasicTeam.Id, opts)
require.Error(t, err)
CheckForbiddenStatus(t, resp)
@@ -1426,7 +1441,8 @@ func TestGetGroupsByTeam(t *testing.T) {
CheckBadRequestStatus(t, response)
})
- th.App.Srv().RemoveLicense()
+ appErr := th.App.Srv().RemoveLicense()
+ require.Nil(t, appErr)
th.TestForSystemAdminAndLocal(t, func(t *testing.T, client *model.Client4) {
_, _, response, err := client.GetGroupsByTeam(context.Background(), th.BasicTeam.Id, opts)
@@ -1487,7 +1503,10 @@ func TestGetGroupsByTeam(t *testing.T) {
t.Run("user can't fetch groups if it's not part of the team", func(t *testing.T) {
require.Nil(t, th.App.RemoveUserFromTeam(th.Context, th.BasicTeam.Id, th.BasicUser.Id, th.SystemAdminUser.Id))
- defer th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, th.BasicUser.Id, th.SystemAdminUser.Id)
+ defer func() {
+ _, _, appErr := th.App.AddUserToTeam(th.Context, th.BasicTeam.Id, th.BasicUser.Id, th.SystemAdminUser.Id)
+ require.Nil(t, appErr)
+ }()
groups, _, response, err := th.Client.GetGroupsByTeam(context.Background(), th.BasicTeam.Id, opts)
require.Error(t, err)
@@ -1580,7 +1599,8 @@ func TestGetGroups(t *testing.T) {
assert.Equal(t, groups[0].Id, group.Id)
// delete group, should still return
- th.App.DeleteGroup(group.Id)
+ _, appErr = th.App.DeleteGroup(group.Id)
+ require.Nil(t, appErr)
groups, _, err = th.Client.GetGroups(context.Background(), opts)
assert.NoError(t, err)
assert.Len(t, groups, 1)
@@ -1721,14 +1741,18 @@ func TestGetGroupsByUserId(t *testing.T) {
assert.ElementsMatch(t, []*model.Group{group1, group2}, groups)
// test permissions
- th.Client.Logout(context.Background())
- th.Client.Login(context.Background(), th.BasicUser.Email, th.BasicUser.Password)
+ _, err = th.Client.Logout(context.Background())
+ require.NoError(t, err)
+ _, _, err = th.Client.Login(context.Background(), th.BasicUser.Email, th.BasicUser.Password)
+ require.NoError(t, err)
_, response, err = th.Client.GetGroupsByUserId(context.Background(), user1.Id)
require.Error(t, err)
CheckForbiddenStatus(t, response)
- th.Client.Logout(context.Background())
- th.Client.Login(context.Background(), user1.Email, user1.Password)
+ _, err = th.Client.Logout(context.Background())
+ require.NoError(t, err)
+ _, _, err = th.Client.Login(context.Background(), user1.Email, user1.Password)
+ require.NoError(t, err)
groups, _, err = th.Client.GetGroupsByUserId(context.Background(), user1.Id)
require.NoError(t, err)
assert.ElementsMatch(t, []*model.Group{group1, group2}, groups)
@@ -1823,7 +1847,8 @@ func TestGetGroupStats(t *testing.T) {
th.App.Srv().SetLicense(model.NewTestLicense("ldap"))
t.Run("Requires manage system permission to access group stats", func(t *testing.T) {
- th.Client.Login(context.Background(), th.BasicUser.Email, th.BasicUser.Password)
+ _, _, err := th.Client.Login(context.Background(), th.BasicUser.Email, th.BasicUser.Password)
+ require.NoError(t, err)
_, response, err := th.Client.GetGroupStats(context.Background(), group.Id)
require.Error(t, err)
CheckForbiddenStatus(t, response)
diff --git a/server/channels/api4/integration_action.go b/server/channels/api4/integration_action.go
index eae516efabc8c..7076009b81c45 100644
--- a/server/channels/api4/integration_action.go
+++ b/server/channels/api4/integration_action.go
@@ -136,5 +136,7 @@ func submitDialog(c *Context, w http.ResponseWriter, r *http.Request) {
b, _ := json.Marshal(resp)
- w.Write(b)
+ if _, err := w.Write(b); err != nil {
+ c.Logger.Warn("Error while writing response", mlog.Err(err))
+ }
}
diff --git a/server/channels/api4/job.go b/server/channels/api4/job.go
index e577c7943c824..a794b1874f46a 100644
--- a/server/channels/api4/job.go
+++ b/server/channels/api4/job.go
@@ -210,7 +210,9 @@ func getJobs(c *Context, w http.ResponseWriter, r *http.Request) {
c.Err = model.NewAppError("getJobs", "api.marshal_error", nil, "", http.StatusInternalServerError).Wrap(err)
return
}
- w.Write(js)
+ if _, err := w.Write(js); err != nil {
+ c.Logger.Warn("Error while writing response", mlog.Err(err))
+ }
}
func getJobsByType(c *Context, w http.ResponseWriter, r *http.Request) {
@@ -241,7 +243,9 @@ func getJobsByType(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
- w.Write(js)
+ if _, err := w.Write(js); err != nil {
+ c.Logger.Warn("Error while writing response", mlog.Err(err))
+ }
}
func cancelJob(c *Context, w http.ResponseWriter, r *http.Request) {
diff --git a/server/channels/api4/ldap.go b/server/channels/api4/ldap.go
index 0669332328089..8d2d03b86014b 100644
--- a/server/channels/api4/ldap.go
+++ b/server/channels/api4/ldap.go
@@ -140,7 +140,9 @@ func getLdapGroups(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
- w.Write(b)
+ if _, err := w.Write(b); err != nil {
+ c.Logger.Warn("Error while writing response", mlog.Err(err))
+ }
}
func linkLdapGroup(c *Context, w http.ResponseWriter, r *http.Request) {
@@ -240,7 +242,9 @@ func linkLdapGroup(c *Context, w http.ResponseWriter, r *http.Request) {
auditRec.Success()
w.WriteHeader(status)
- w.Write(b)
+ if _, err := w.Write(b); err != nil {
+ c.Logger.Warn("Error while writing response", mlog.Err(err))
+ }
}
func unlinkLdapGroup(c *Context, w http.ResponseWriter, r *http.Request) {
diff --git a/server/channels/api4/oauth.go b/server/channels/api4/oauth.go
index a9f108100cb19..d65b73f23c57d 100644
--- a/server/channels/api4/oauth.go
+++ b/server/channels/api4/oauth.go
@@ -153,7 +153,9 @@ func getOAuthApps(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
- w.Write(js)
+ if _, err := w.Write(js); err != nil {
+ c.Logger.Warn("Error while writing response", mlog.Err(err))
+ }
}
func getOAuthApp(c *Context, w http.ResponseWriter, r *http.Request) {
@@ -308,5 +310,7 @@ func getAuthorizedOAuthApps(c *Context, w http.ResponseWriter, r *http.Request)
return
}
- w.Write(js)
+ if _, err := w.Write(js); err != nil {
+ c.Logger.Warn("Error while writing response", mlog.Err(err))
+ }
}
diff --git a/server/channels/api4/permission.go b/server/channels/api4/permission.go
index 2489aee2f0280..2566dff461c85 100644
--- a/server/channels/api4/permission.go
+++ b/server/channels/api4/permission.go
@@ -8,6 +8,7 @@ import (
"net/http"
"github.com/mattermost/mattermost/server/public/model"
+ "github.com/mattermost/mattermost/server/public/shared/mlog"
)
func (api *API) InitPermissions() {
@@ -25,5 +26,7 @@ func appendAncillaryPermissionsPost(c *Context, w http.ResponseWriter, r *http.R
c.SetJSONEncodingError(err)
return
}
- w.Write(b)
+ if _, err := w.Write(b); err != nil {
+ c.Logger.Warn("Error while writing response", mlog.Err(err))
+ }
}
diff --git a/server/channels/api4/post.go b/server/channels/api4/post.go
index d41f58e3f2142..559fec1c4df4a 100644
--- a/server/channels/api4/post.go
+++ b/server/channels/api4/post.go
@@ -1127,7 +1127,9 @@ func acknowledgePost(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
- w.Write(js)
+ if _, err := w.Write(js); err != nil {
+ c.Logger.Warn("Error while writing response", mlog.Err(err))
+ }
}
func unacknowledgePost(c *Context, w http.ResponseWriter, r *http.Request) {
@@ -1284,7 +1286,9 @@ func getFileInfosForPost(c *Context, w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "max-age=2592000, private")
w.Header().Set(model.HeaderEtagServer, model.GetEtagForFileInfos(infos))
- w.Write(js)
+ if _, err := w.Write(js); err != nil {
+ c.Logger.Warn("Error while writing response", mlog.Err(err))
+ }
}
func getPostInfo(c *Context, w http.ResponseWriter, r *http.Request) {
@@ -1305,7 +1309,9 @@ func getPostInfo(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
- w.Write(js)
+ if _, err := w.Write(js); err != nil {
+ c.Logger.Warn("Error while writing response", mlog.Err(err))
+ }
}
func hasPermittedWranglerRole(c *Context, user *model.User, channelMember *model.ChannelMember) bool {
diff --git a/server/channels/api4/status_test.go b/server/channels/api4/status_test.go
index 418a4b781962d..62a7c4ee649c7 100644
--- a/server/channels/api4/status_test.go
+++ b/server/channels/api4/status_test.go
@@ -5,13 +5,13 @@ package api4
import (
"context"
+ "strings"
"testing"
"time"
+ "github.com/mattermost/mattermost/server/public/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
-
- "github.com/mattermost/mattermost/server/public/model"
)
func TestGetUserStatus(t *testing.T) {
@@ -243,3 +243,175 @@ func TestUpdateUserStatus(t *testing.T) {
CheckUnauthorizedStatus(t, resp)
})
}
+
+func TestUpdateUserCustomStatus(t *testing.T) {
+ th := Setup(t).InitBasic()
+ defer th.TearDown()
+ client := th.Client
+
+ t.Run("set custom status", func(t *testing.T) {
+ toUpdateCustomStatus := &model.CustomStatus{
+ Emoji: "calendar", // Use a valid emoji name
+ Text: "My custom status",
+ }
+ _, resp, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus)
+ require.NoError(t, err)
+ CheckOKStatus(t, resp)
+
+ user, _, err := client.GetUser(context.Background(), th.BasicUser.Id, "")
+ require.NoError(t, err)
+ customStatus := user.GetCustomStatus()
+ require.NotNil(t, customStatus)
+ assert.Equal(t, toUpdateCustomStatus.Emoji, customStatus.Emoji)
+ assert.Equal(t, toUpdateCustomStatus.Text, customStatus.Text)
+ })
+
+ t.Run("update custom status with duration", func(t *testing.T) {
+ expiresAt := time.Now().Add(1 * time.Hour)
+ toUpdateCustomStatus := &model.CustomStatus{
+ Emoji: "palm_tree", // Use a valid emoji name
+ Text: "On vacation",
+ Duration: "date_and_time",
+ ExpiresAt: expiresAt,
+ }
+ _, resp, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus)
+ require.NoError(t, err)
+ CheckOKStatus(t, resp)
+
+ user, _, err := client.GetUser(context.Background(), th.BasicUser.Id, "")
+ require.NoError(t, err)
+ customStatus := user.GetCustomStatus()
+ require.NotNil(t, customStatus)
+ assert.Equal(t, toUpdateCustomStatus.Emoji, customStatus.Emoji)
+ assert.Equal(t, toUpdateCustomStatus.Text, customStatus.Text)
+ assert.Equal(t, toUpdateCustomStatus.Duration, customStatus.Duration)
+
+ require.NotNil(t, customStatus.ExpiresAt, "Expected ExpiresAt to be set")
+ // Check that ExpiresAt is within 5 seconds of the expected time
+ assert.WithinDuration(t, expiresAt, customStatus.ExpiresAt, 5*time.Second)
+ })
+
+ t.Run("attempt to set custom status when disabled", func(t *testing.T) {
+ th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.EnableCustomUserStatuses = false })
+ defer th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.EnableCustomUserStatuses = true })
+
+ toUpdateCustomStatus := &model.CustomStatus{
+ Emoji: "palm_tree",
+ Text: "My custom status",
+ }
+ _, resp, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus)
+ require.Error(t, err)
+ CheckNotImplementedStatus(t, resp)
+
+ // Assert that the error ID is "api.custom_status.disabled"
+ if appErr, ok := err.(*model.AppError); ok {
+ assert.Equal(t, "api.custom_status.disabled", appErr.Id)
+ } else {
+ t.Errorf("expected *model.AppError, got %T", err)
+ }
+ })
+
+ t.Run("attempt to set custom status for another user", func(t *testing.T) {
+ toUpdateCustomStatus := &model.CustomStatus{
+ Emoji: "palm_tree",
+ Text: "My custom status",
+ }
+ _, resp, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser2.Id, toUpdateCustomStatus)
+ require.Error(t, err)
+ CheckForbiddenStatus(t, resp)
+ })
+
+ t.Run("attempt to set custom status with invalid data", func(t *testing.T) {
+ toUpdateCustomStatus := &model.CustomStatus{
+ Emoji: "invalid_emoji",
+ Text: strings.Repeat("a", 101), // Exceeds max length
+ Duration: "invalid_duration",
+ ExpiresAt: time.Now().Add(-1 * time.Hour),
+ }
+ _, resp, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus)
+ require.Error(t, err)
+ CheckBadRequestStatus(t, resp)
+ })
+
+ t.Run("attempt to set custom status as non-authenticated user", func(t *testing.T) {
+ client.Logout(context.Background())
+ toUpdateCustomStatus := &model.CustomStatus{
+ Emoji: "palm_tree",
+ Text: "My custom status",
+ }
+ _, resp, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus)
+ require.Error(t, err)
+ CheckUnauthorizedStatus(t, resp)
+ })
+}
+
+func TestRemoveUserCustomStatus(t *testing.T) {
+ th := Setup(t).InitBasic()
+ defer th.TearDown()
+ client := th.Client
+
+ t.Run("remove custom status successfully", func(t *testing.T) {
+ toUpdateCustomStatus := &model.CustomStatus{
+ Emoji: "calendar",
+ Text: "My custom status",
+ }
+ _, _, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus)
+ require.NoError(t, err)
+
+ resp, err := client.RemoveUserCustomStatus(context.Background(), th.BasicUser.Id)
+ require.NoError(t, err)
+ CheckOKStatus(t, resp)
+
+ user, _, err := client.GetUser(context.Background(), th.BasicUser.Id, "")
+ require.NoError(t, err)
+ customStatus := user.GetCustomStatus()
+ assert.Nil(t, customStatus)
+ })
+
+ t.Run("attempt to remove custom status when disabled", func(t *testing.T) {
+ th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.EnableCustomUserStatuses = false })
+ defer th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.EnableCustomUserStatuses = true })
+
+ resp, err := client.RemoveUserCustomStatus(context.Background(), th.BasicUser.Id)
+ require.Error(t, err)
+ CheckNotImplementedStatus(t, resp)
+ })
+
+ t.Run("attempt to remove custom status for another user", func(t *testing.T) {
+ resp, err := client.RemoveUserCustomStatus(context.Background(), th.BasicUser2.Id)
+ require.Error(t, err)
+ CheckForbiddenStatus(t, resp)
+ })
+
+ t.Run("attempt to remove custom status as non-authenticated user", func(t *testing.T) {
+ client.Logout(context.Background())
+ resp, err := client.RemoveUserCustomStatus(context.Background(), th.BasicUser.Id)
+ require.Error(t, err)
+ CheckUnauthorizedStatus(t, resp)
+ })
+
+ t.Run("remove non-existent custom status", func(t *testing.T) {
+ th.LoginBasic()
+ resp, err := client.RemoveUserCustomStatus(context.Background(), th.BasicUser.Id)
+ require.NoError(t, err)
+ CheckOKStatus(t, resp)
+ })
+
+ t.Run("remove custom status with system admin", func(t *testing.T) {
+ toUpdateCustomStatus := &model.CustomStatus{
+ Emoji: "calendar",
+ Text: "My custom status",
+ }
+ _, _, err := client.UpdateUserCustomStatus(context.Background(), th.BasicUser.Id, toUpdateCustomStatus)
+ require.NoError(t, err)
+
+ resp, err := th.SystemAdminClient.RemoveUserCustomStatus(context.Background(), th.BasicUser.Id)
+ require.NoError(t, err)
+ CheckOKStatus(t, resp)
+
+ user, _, err := client.GetUser(context.Background(), th.BasicUser.Id, "")
+ require.NoError(t, err)
+ customStatus := user.GetCustomStatus()
+ assert.Nil(t, customStatus)
+ })
+}
diff --git a/server/channels/app/channel.go b/server/channels/app/channel.go
index 3a763e7633112..8664c33c2e2be 100644
--- a/server/channels/app/channel.go
+++ b/server/channels/app/channel.go
@@ -1298,7 +1298,7 @@ func (a *App) UpdateChannelMemberNotifyProps(c request.CTX, data map[string]stri
a.invalidateCacheForChannelMembersNotifyProps(member.ChannelId)
// Notify the clients that the member notify props changed
- err = a.sendUpdateChannelMemberNotifyPropsEvent(member)
+ err = a.sendUpdateChannelMemberEvent(member)
if err != nil {
return nil, model.NewAppError("UpdateChannelMemberNotifyProps", "api.marshal_error", nil, "", http.StatusInternalServerError).Wrap(err)
}
@@ -1339,7 +1339,7 @@ func (a *App) PatchChannelMembersNotifyProps(c request.CTX, members []*model.Cha
// Notify clients that their notify props have changed
for _, member := range updated {
- err := a.sendUpdateChannelMemberNotifyPropsEvent(member)
+ err := a.sendUpdateChannelMemberEvent(member)
if err != nil {
c.Logger().Warn("Failed to send WebSocket event for updated channel member notify props", mlog.Err(err))
}
@@ -1348,7 +1348,7 @@ func (a *App) PatchChannelMembersNotifyProps(c request.CTX, members []*model.Cha
return updated, nil
}
-func (a *App) sendUpdateChannelMemberNotifyPropsEvent(member *model.ChannelMember) error {
+func (a *App) sendUpdateChannelMemberEvent(member *model.ChannelMember) error {
evt := model.NewWebSocketEvent(model.WebsocketEventChannelMemberUpdated, "", "", member.UserId, nil, "")
memberJSON, jsonErr := json.Marshal(member)
if jsonErr != nil {
@@ -1962,9 +1962,11 @@ func (a *App) GetAllChannels(c request.CTX, page, perPage int, opts model.Channe
opts.ExcludeChannelNames = a.DefaultChannelNames(c)
}
storeOpts := store.ChannelSearchOpts{
- ExcludeChannelNames: opts.ExcludeChannelNames,
NotAssociatedToGroup: opts.NotAssociatedToGroup,
IncludeDeleted: opts.IncludeDeleted,
+ ExcludeChannelNames: opts.ExcludeChannelNames,
+ GroupConstrained: opts.GroupConstrained,
+ ExcludeGroupConstrained: opts.ExcludeGroupConstrained,
ExcludePolicyConstrained: opts.ExcludePolicyConstrained,
IncludePolicyID: opts.IncludePolicyID,
}
@@ -1981,9 +1983,13 @@ func (a *App) GetAllChannelsCount(c request.CTX, opts model.ChannelSearchOpts) (
opts.ExcludeChannelNames = a.DefaultChannelNames(c)
}
storeOpts := store.ChannelSearchOpts{
- ExcludeChannelNames: opts.ExcludeChannelNames,
- NotAssociatedToGroup: opts.NotAssociatedToGroup,
- IncludeDeleted: opts.IncludeDeleted,
+ NotAssociatedToGroup: opts.NotAssociatedToGroup,
+ IncludeDeleted: opts.IncludeDeleted,
+ ExcludeChannelNames: opts.ExcludeChannelNames,
+ GroupConstrained: opts.GroupConstrained,
+ ExcludeGroupConstrained: opts.ExcludeGroupConstrained,
+ ExcludePolicyConstrained: opts.ExcludePolicyConstrained,
+ IncludePolicyID: opts.IncludePolicyID,
}
count, err := a.Srv().Store().Channel().GetAllChannelsCount(storeOpts)
if err != nil {
diff --git a/server/channels/app/syncables.go b/server/channels/app/syncables.go
index 45d293cd9830e..d4c3cdeaa5ba5 100644
--- a/server/channels/app/syncables.go
+++ b/server/channels/app/syncables.go
@@ -234,26 +234,47 @@ func (a *App) SyncSyncableRoles(rctx request.CTX, syncableID string, syncableTyp
switch syncableType {
case model.GroupSyncableTypeTeam:
- nErr := a.Srv().Store().Team().UpdateMembersRole(syncableID, permittedAdmins)
- if nErr != nil {
- return model.NewAppError("App.SyncSyncableRoles", "app.update_error", nil, "", http.StatusInternalServerError).Wrap(nErr)
+ var updatedMembers []*model.TeamMember
+ updatedMembers, err = a.Srv().Store().Team().UpdateMembersRole(syncableID, permittedAdmins)
+ if err != nil {
+ return model.NewAppError("App.SyncSyncableRoles", "app.update_error", nil, "", http.StatusInternalServerError).Wrap(err)
+ }
+
+ for _, member := range updatedMembers {
+ a.ClearSessionCacheForUser(member.UserId)
+
+ if appErr := a.sendUpdatedTeamMemberEvent(member); appErr != nil {
+ rctx.Logger().Warn("Error sending channel member updated websocket event", mlog.Err(appErr))
+ }
}
- return nil
case model.GroupSyncableTypeChannel:
- nErr := a.Srv().Store().Channel().UpdateMembersRole(syncableID, permittedAdmins)
- if nErr != nil {
- return model.NewAppError("App.SyncSyncableRoles", "app.update_error", nil, "", http.StatusInternalServerError).Wrap(nErr)
+ var updatedMembers []*model.ChannelMember
+ updatedMembers, err = a.Srv().Store().Channel().UpdateMembersRole(syncableID, permittedAdmins)
+ if err != nil {
+ return model.NewAppError("App.SyncSyncableRoles", "app.update_error", nil, "", http.StatusInternalServerError).Wrap(err)
+ }
+
+ for _, member := range updatedMembers {
+ a.ClearSessionCacheForUser(member.UserId)
+
+ if appErr := a.sendUpdateChannelMemberEvent(member); appErr != nil {
+ rctx.Logger().Warn("Error sending channel member updated websocket event", mlog.Err(appErr))
+ }
}
- return nil
default:
return model.NewAppError("App.SyncSyncableRoles", "groups.unsupported_syncable_type", map[string]any{"Value": syncableType}, "", http.StatusInternalServerError)
}
+
+ return nil
}
// SyncRolesAndMembership updates the SchemeAdmin status and membership of all of the members of the given
// syncable.
func (a *App) SyncRolesAndMembership(rctx request.CTX, syncableID string, syncableType model.GroupSyncableType, includeRemovedMembers bool) {
- a.SyncSyncableRoles(rctx, syncableID, syncableType)
+ appErr := a.SyncSyncableRoles(rctx, syncableID, syncableType)
+ if appErr != nil {
+ rctx.Logger().Warn("Error syncing syncable roles", mlog.Err(appErr))
+ }
lastJob, _ := a.Srv().Store().Job().GetNewestJobByStatusAndType(model.JobStatusSuccess, model.JobTypeLdapSync)
var since int64
@@ -272,9 +293,6 @@ func (a *App) SyncRolesAndMembership(rctx request.CTX, syncableID string, syncab
if err := a.deleteGroupConstrainedTeamMemberships(rctx, &syncableID); err != nil {
rctx.Logger().Warn("Error deleting group constrained team memberships", mlog.Err(err))
}
- if err := a.ClearTeamMembersCache(syncableID); err != nil {
- rctx.Logger().Warn("Error clearing team members cache", mlog.Err(err))
- }
case model.GroupSyncableTypeChannel:
params.ScopedChannelID = &syncableID
if err := a.createDefaultChannelMemberships(rctx, params); err != nil {
@@ -283,8 +301,5 @@ func (a *App) SyncRolesAndMembership(rctx request.CTX, syncableID string, syncab
if err := a.deleteGroupConstrainedChannelMemberships(rctx, &syncableID); err != nil {
rctx.Logger().Warn("Error deleting group constrained team memberships", mlog.Err(err))
}
- if err := a.ClearChannelMembersCache(rctx, syncableID); err != nil {
- rctx.Logger().Warn("Error clearing channel members cache", mlog.Err(err))
- }
}
}
diff --git a/server/channels/app/team.go b/server/channels/app/team.go
index 415965fe6ec77..72ce428660f82 100644
--- a/server/channels/app/team.go
+++ b/server/channels/app/team.go
@@ -469,7 +469,7 @@ func (a *App) UpdateTeamMemberRoles(c request.CTX, teamID string, userID string,
a.ClearSessionCacheForUser(userID)
- if appErr := a.sendUpdatedMemberRoleEvent(userID, member); appErr != nil {
+ if appErr := a.sendUpdatedTeamMemberEvent(member); appErr != nil {
return nil, appErr
}
@@ -512,15 +512,15 @@ func (a *App) UpdateTeamMemberSchemeRoles(c request.CTX, teamID string, userID s
a.ClearSessionCacheForUser(userID)
- if appErr := a.sendUpdatedMemberRoleEvent(userID, member); appErr != nil {
+ if appErr := a.sendUpdatedTeamMemberEvent(member); appErr != nil {
return nil, appErr
}
return member, nil
}
-func (a *App) sendUpdatedMemberRoleEvent(userID string, member *model.TeamMember) *model.AppError {
- message := model.NewWebSocketEvent(model.WebsocketEventMemberroleUpdated, "", "", userID, nil, "")
+func (a *App) sendUpdatedTeamMemberEvent(member *model.TeamMember) *model.AppError {
+ message := model.NewWebSocketEvent(model.WebsocketEventMemberroleUpdated, "", "", member.UserId, nil, "")
tmJSON, jsonErr := json.Marshal(member)
if jsonErr != nil {
return model.NewAppError("sendUpdatedMemberRoleEvent", "api.marshal_error", nil, "", http.StatusInternalServerError).Wrap(jsonErr)
diff --git a/server/channels/app/user.go b/server/channels/app/user.go
index 62f9d0004e0ad..c17155e70094b 100644
--- a/server/channels/app/user.go
+++ b/server/channels/app/user.go
@@ -2443,7 +2443,7 @@ func (a *App) PromoteGuestToUser(c request.CTX, user *model.User, requestorId st
}
for _, member := range teamMembers {
- a.sendUpdatedMemberRoleEvent(user.Id, member)
+ a.sendUpdatedTeamMemberEvent(member)
channelMembers, appErr := a.GetChannelMembersForUser(c, member.TeamId, user.Id)
if appErr != nil {
@@ -2487,7 +2487,7 @@ func (a *App) DemoteUserToGuest(c request.CTX, user *model.User) *model.AppError
}
for _, member := range teamMembers {
- a.sendUpdatedMemberRoleEvent(user.Id, member)
+ a.sendUpdatedTeamMemberEvent(member)
channelMembers, appErr := a.GetChannelMembersForUser(c, member.TeamId, user.Id)
if appErr != nil {
diff --git a/server/channels/store/opentracinglayer/opentracinglayer.go b/server/channels/store/opentracinglayer/opentracinglayer.go
index ec5ff4791ee6f..e05d3ac0cbb86 100644
--- a/server/channels/store/opentracinglayer/opentracinglayer.go
+++ b/server/channels/store/opentracinglayer/opentracinglayer.go
@@ -2563,7 +2563,7 @@ func (s *OpenTracingLayerChannelStore) UpdateMemberNotifyProps(channelID string,
return result, err
}
-func (s *OpenTracingLayerChannelStore) UpdateMembersRole(channelID string, userIDs []string) error {
+func (s *OpenTracingLayerChannelStore) UpdateMembersRole(channelID string, userIDs []string) ([]*model.ChannelMember, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.UpdateMembersRole")
s.Root.Store.SetContext(newCtx)
@@ -2572,13 +2572,13 @@ func (s *OpenTracingLayerChannelStore) UpdateMembersRole(channelID string, userI
}()
defer span.Finish()
- err := s.ChannelStore.UpdateMembersRole(channelID, userIDs)
+ result, err := s.ChannelStore.UpdateMembersRole(channelID, userIDs)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)
}
- return err
+ return result, err
}
func (s *OpenTracingLayerChannelStore) UpdateMultipleMembers(members []*model.ChannelMember) ([]*model.ChannelMember, error) {
@@ -10554,7 +10554,7 @@ func (s *OpenTracingLayerTeamStore) UpdateMember(rctx request.CTX, member *model
return result, err
}
-func (s *OpenTracingLayerTeamStore) UpdateMembersRole(teamID string, userIDs []string) error {
+func (s *OpenTracingLayerTeamStore) UpdateMembersRole(teamID string, adminIDs []string) ([]*model.TeamMember, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "TeamStore.UpdateMembersRole")
s.Root.Store.SetContext(newCtx)
@@ -10563,13 +10563,13 @@ func (s *OpenTracingLayerTeamStore) UpdateMembersRole(teamID string, userIDs []s
}()
defer span.Finish()
- err := s.TeamStore.UpdateMembersRole(teamID, userIDs)
+ result, err := s.TeamStore.UpdateMembersRole(teamID, adminIDs)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)
}
- return err
+ return result, err
}
func (s *OpenTracingLayerTeamStore) UpdateMultipleMembers(members []*model.TeamMember) ([]*model.TeamMember, error) {
diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go
index e04b8c5a36e2b..3db4b2d39ef4a 100644
--- a/server/channels/store/retrylayer/retrylayer.go
+++ b/server/channels/store/retrylayer/retrylayer.go
@@ -2840,21 +2840,21 @@ func (s *RetryLayerChannelStore) UpdateMemberNotifyProps(channelID string, userI
}
-func (s *RetryLayerChannelStore) UpdateMembersRole(channelID string, userIDs []string) error {
+func (s *RetryLayerChannelStore) UpdateMembersRole(channelID string, userIDs []string) ([]*model.ChannelMember, error) {
tries := 0
for {
- err := s.ChannelStore.UpdateMembersRole(channelID, userIDs)
+ result, err := s.ChannelStore.UpdateMembersRole(channelID, userIDs)
if err == nil {
- return nil
+ return result, nil
}
if !isRepeatableError(err) {
- return err
+ return result, err
}
tries++
if tries >= 3 {
err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures")
- return err
+ return result, err
}
timepkg.Sleep(100 * timepkg.Millisecond)
}
@@ -12071,21 +12071,21 @@ func (s *RetryLayerTeamStore) UpdateMember(rctx request.CTX, member *model.TeamM
}
-func (s *RetryLayerTeamStore) UpdateMembersRole(teamID string, userIDs []string) error {
+func (s *RetryLayerTeamStore) UpdateMembersRole(teamID string, adminIDs []string) ([]*model.TeamMember, error) {
tries := 0
for {
- err := s.TeamStore.UpdateMembersRole(teamID, userIDs)
+ result, err := s.TeamStore.UpdateMembersRole(teamID, adminIDs)
if err == nil {
- return nil
+ return result, nil
}
if !isRepeatableError(err) {
- return err
+ return result, err
}
tries++
if tries >= 3 {
err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures")
- return err
+ return result, err
}
timepkg.Sleep(100 * timepkg.Millisecond)
}
diff --git a/server/channels/store/sqlstore/channel_store.go b/server/channels/store/sqlstore/channel_store.go
index dce8d6ddcfc11..ec81005e3836a 100644
--- a/server/channels/store/sqlstore/channel_store.go
+++ b/server/channels/store/sqlstore/channel_store.go
@@ -7,6 +7,7 @@ import (
"context"
"database/sql"
"fmt"
+ "slices"
"sort"
"strconv"
"strings"
@@ -1190,6 +1191,15 @@ func (s SqlChannelStore) getAllChannelsQuery(opts store.ChannelSearchOpts, forCo
query = query.Where("c.Id NOT IN (SELECT ChannelId FROM GroupChannels WHERE GroupChannels.GroupId = ? AND GroupChannels.DeleteAt = 0)", opts.NotAssociatedToGroup)
}
+ if opts.GroupConstrained {
+ query = query.Where(sq.Eq{"c.GroupConstrained": true})
+ } else if opts.ExcludeGroupConstrained {
+ query = query.Where(sq.Or{
+ sq.NotEq{"c.GroupConstrained": true},
+ sq.Eq{"c.GroupConstrained": nil},
+ })
+ }
+
if len(opts.ExcludeChannelNames) > 0 {
query = query.Where(sq.NotEq{"c.Name": opts.ExcludeChannelNames})
}
@@ -4161,27 +4171,76 @@ func (s SqlChannelStore) UserBelongsToChannels(userId string, channelIds []strin
return c > 0, nil
}
-// TODO: parameterize userIDs
-func (s SqlChannelStore) UpdateMembersRole(channelID string, userIDs []string) error {
- sql := fmt.Sprintf(`
- UPDATE
- ChannelMembers
- SET
- SchemeAdmin = CASE WHEN UserId IN ('%s') THEN
- TRUE
- ELSE
- FALSE
- END
- WHERE
- ChannelId = ?
- AND (SchemeGuest = false OR SchemeGuest IS NULL)
- `, strings.Join(userIDs, "', '"))
+// UpdateMembersRole updates all the members of channelID in the adminIDs string array to be admins and sets all other
+// users as not being admin.
+// It returns the list of userIDs whose roles got updated.
+//
+// TODO: parameterize adminIDs
+func (s SqlChannelStore) UpdateMembersRole(channelID string, adminIDs []string) (_ []*model.ChannelMember, err error) {
+ transaction, err := s.GetMasterX().Beginx()
+ if err != nil {
+ return nil, err
+ }
+ defer finalizeTransactionX(transaction, &err)
- if _, err := s.GetMasterX().Exec(sql, channelID); err != nil {
- return errors.Wrap(err, "failed to update ChannelMembers")
+ // On MySQL it's not possible to update a table and select from it in the same query.
+ // A SELECT and a UPDATE query are needed.
+ // Once we only support PostgreSQL, this can be done in a single query using RETURNING.
+ query, args, err := s.getQueryBuilder().
+ Select("*").
+ From("ChannelMembers").
+ Where(sq.Eq{"ChannelID": channelID}).
+ Where(sq.Or{sq.Eq{"SchemeGuest": false}, sq.Expr("SchemeGuest IS NULL")}).
+ Where(
+ sq.Or{
+ // New admins
+ sq.And{
+ sq.Eq{"SchemeAdmin": false},
+ sq.Eq{"UserId": adminIDs},
+ },
+ // Demoted admins
+ sq.And{
+ sq.Eq{"SchemeAdmin": true},
+ sq.NotEq{"UserId": adminIDs},
+ },
+ },
+ ).ToSql()
+ if err != nil {
+ return nil, errors.Wrap(err, "channel_tosql")
}
- return nil
+ var updatedMembers []*model.ChannelMember
+ if err = transaction.Select(&updatedMembers, query, args...); err != nil {
+ return nil, errors.Wrap(err, "failed to get list of updated users")
+ }
+
+ // Update SchemeAdmin field as the data from the SQL is not updated yet
+ for _, member := range updatedMembers {
+ if slices.Contains(adminIDs, member.UserId) {
+ member.SchemeAdmin = true
+ } else {
+ member.SchemeAdmin = false
+ }
+ }
+
+ query, args, err = s.getQueryBuilder().
+ Update("ChannelMembers").
+ Set("SchemeAdmin", sq.Case().When(sq.Eq{"UserId": adminIDs}, "true").Else("false")).
+ Where(sq.Eq{"ChannelId": channelID}).
+ Where(sq.Or{sq.Eq{"SchemeGuest": false}, sq.Expr("SchemeGuest IS NULL")}).ToSql()
+ if err != nil {
+ return nil, errors.Wrap(err, "team_tosql")
+ }
+
+ if _, err = transaction.Exec(query, args...); err != nil {
+ return nil, errors.Wrap(err, "failed to update ChannelMembers")
+ }
+
+ if err = transaction.Commit(); err != nil {
+ return nil, errors.Wrap(err, "commit_transaction")
+ }
+
+ return updatedMembers, nil
}
func (s SqlChannelStore) GroupSyncedChannelCount() (int64, error) {
diff --git a/server/channels/store/sqlstore/team_store.go b/server/channels/store/sqlstore/team_store.go
index 92674434c1cc3..f564fd1e62730 100644
--- a/server/channels/store/sqlstore/team_store.go
+++ b/server/channels/store/sqlstore/team_store.go
@@ -6,6 +6,7 @@ package sqlstore
import (
"database/sql"
"fmt"
+ "slices"
"strings"
sq "github.com/mattermost/squirrel"
@@ -1591,23 +1592,74 @@ func (s SqlTeamStore) UserBelongsToTeams(userId string, teamIds []string) (bool,
return c > 0, nil
}
-// UpdateMembersRole updates all the members of teamID in the userIds string array to be admins and sets all other
+// UpdateMembersRole updates all the members of teamID in the adminIDs string array to be admins and sets all other
// users as not being admin.
-func (s SqlTeamStore) UpdateMembersRole(teamID string, userIDs []string) error {
+// It returns the list of userIDs whose roles got updated.
+func (s SqlTeamStore) UpdateMembersRole(teamID string, adminIDs []string) (_ []*model.TeamMember, err error) {
+ transaction, err := s.GetMasterX().Beginx()
+ if err != nil {
+ return nil, err
+ }
+ defer finalizeTransactionX(transaction, &err)
+
+ // On MySQL it's not possible to update a table and select from it in the same query.
+ // A SELECT and a UPDATE query are needed.
+ // Once we only support PostgreSQL, this can be done in a single query using RETURNING.
query, args, err := s.getQueryBuilder().
+ Select("*").
+ From("TeamMembers").
+ Where(sq.Eq{"TeamId": teamID, "DeleteAt": 0}).
+ Where(sq.Or{sq.Eq{"SchemeGuest": false}, sq.Expr("SchemeGuest IS NULL")}).
+ Where(
+ sq.Or{
+ // New admins
+ sq.And{
+ sq.Eq{"SchemeAdmin": false},
+ sq.Eq{"UserId": adminIDs},
+ },
+ // Demoted admins
+ sq.And{
+ sq.Eq{"SchemeAdmin": true},
+ sq.NotEq{"UserId": adminIDs},
+ },
+ },
+ ).ToSql()
+ if err != nil {
+ return nil, errors.Wrap(err, "team_tosql")
+ }
+
+ var updatedMembers []*model.TeamMember
+ if err = transaction.Select(&updatedMembers, query, args...); err != nil {
+ return nil, errors.Wrap(err, "failed to get list of updated users")
+ }
+
+ // Update SchemeAdmin field as the data from the SQL is not updated yet
+ for _, member := range updatedMembers {
+ if slices.Contains(adminIDs, member.UserId) {
+ member.SchemeAdmin = true
+ } else {
+ member.SchemeAdmin = false
+ }
+ }
+
+ query, args, err = s.getQueryBuilder().
Update("TeamMembers").
- Set("SchemeAdmin", sq.Case().When(sq.Eq{"UserId": userIDs}, "true").Else("false")).
+ Set("SchemeAdmin", sq.Case().When(sq.Eq{"UserId": adminIDs}, "true").Else("false")).
Where(sq.Eq{"TeamId": teamID, "DeleteAt": 0}).
Where(sq.Or{sq.Eq{"SchemeGuest": false}, sq.Expr("SchemeGuest IS NULL")}).ToSql()
if err != nil {
- return errors.Wrap(err, "team_tosql")
+ return nil, errors.Wrap(err, "team_tosql")
}
- if _, err = s.GetMasterX().Exec(query, args...); err != nil {
- return errors.Wrap(err, "failed to update TeamMembers")
+ if _, err = transaction.Exec(query, args...); err != nil {
+ return nil, errors.Wrap(err, "failed to update TeamMembers")
}
- return nil
+ if err = transaction.Commit(); err != nil {
+ return nil, errors.Wrap(err, "commit_transaction")
+ }
+
+ return updatedMembers, nil
}
func applyTeamMemberViewRestrictionsFilter(query sq.SelectBuilder, restrictions *model.ViewUsersRestrictions) sq.SelectBuilder {
diff --git a/server/channels/store/store.go b/server/channels/store/store.go
index 443363efeb39a..0463e500fda18 100644
--- a/server/channels/store/store.go
+++ b/server/channels/store/store.go
@@ -168,7 +168,8 @@ type TeamStore interface {
// UpdateMembersRole sets all of the given team members to admins and all of the other members of the team to
// non-admin members.
- UpdateMembersRole(teamID string, userIDs []string) error
+ // It returns the list of userIDs whose roles got updated.
+ UpdateMembersRole(teamID string, adminIDs []string) ([]*model.TeamMember, error)
// GroupSyncedTeamCount returns the count of non-deleted group-constrained teams.
GroupSyncedTeamCount() (int64, error)
@@ -300,7 +301,8 @@ type ChannelStore interface {
// UpdateMembersRole sets all of the given team members to admins and all of the other members of the team to
// non-admin members.
- UpdateMembersRole(channelID string, userIDs []string) error
+ // It returns the list of userIDs whose roles got updated.
+ UpdateMembersRole(channelID string, userIDs []string) ([]*model.ChannelMember, error)
// GroupSyncedChannelCount returns the count of non-deleted group-constrained channels.
GroupSyncedChannelCount() (int64, error)
diff --git a/server/channels/store/storetest/channel_store.go b/server/channels/store/storetest/channel_store.go
index a9877e0eaab6b..5a226902adf37 100644
--- a/server/channels/store/storetest/channel_store.go
+++ b/server/channels/store/storetest/channel_store.go
@@ -3854,6 +3854,7 @@ func testChannelStoreGetAllChannels(t *testing.T, rctx request.CTX, ss store.Sto
c1.DisplayName = "Channel1" + model.NewId()
c1.Name = NewTestId()
c1.Type = model.ChannelTypeOpen
+ c1.GroupConstrained = model.NewPointer(true)
_, nErr := ss.Channel().Save(rctx, &c1, -1)
require.NoError(t, nErr)
@@ -3939,6 +3940,19 @@ func testChannelStoreGetAllChannels(t *testing.T, rctx request.CTX, ss store.Sto
list, nErr = ss.Channel().GetAllChannels(0, 10, store.ChannelSearchOpts{NotAssociatedToGroup: group.Id})
require.NoError(t, nErr)
assert.Len(t, list, 1)
+ assert.Equal(t, c3.Id, list[0].Id)
+
+ // GroupConstrained
+ list, nErr = ss.Channel().GetAllChannels(0, 10, store.ChannelSearchOpts{GroupConstrained: true})
+ require.NoError(t, nErr)
+ require.Len(t, list, 1)
+ assert.Equal(t, c1.Id, list[0].Id)
+
+ // ExcludeGroupConstrained
+ list, nErr = ss.Channel().GetAllChannels(0, 10, store.ChannelSearchOpts{ExcludeGroupConstrained: true})
+ require.NoError(t, nErr)
+ require.Len(t, list, 1)
+ assert.Equal(t, c3.Id, list[0].Id)
// Exclude channel names
list, nErr = ss.Channel().GetAllChannels(0, 10, store.ChannelSearchOpts{ExcludeChannelNames: []string{c1.Name}})
diff --git a/server/channels/store/storetest/group_store.go b/server/channels/store/storetest/group_store.go
index 8bb0d465f015d..3e8a2f671220d 100644
--- a/server/channels/store/storetest/group_store.go
+++ b/server/channels/store/storetest/group_store.go
@@ -81,7 +81,7 @@ func TestGroupStore(t *testing.T, rctx request.CTX, ss store.Store) {
t.Run("AdminRoleGroupsForSyncableMember_Team", func(t *testing.T) { groupTestAdminRoleGroupsForSyncableMemberTeam(t, rctx, ss) })
t.Run("PermittedSyncableAdmins_Team", func(t *testing.T) { groupTestPermittedSyncableAdminsTeam(t, rctx, ss) })
t.Run("PermittedSyncableAdmins_Channel", func(t *testing.T) { groupTestPermittedSyncableAdminsChannel(t, rctx, ss) })
- t.Run("UpdateMembersRole_Team", func(t *testing.T) { groupTestpUpdateMembersRoleTeam(t, rctx, ss) })
+ t.Run("UpdateMembersRole_Team", func(t *testing.T) { groupTestUpdateMembersRoleTeam(t, rctx, ss) })
t.Run("UpdateMembersRole_Channel", func(t *testing.T) { groupTestpUpdateMembersRoleChannel(t, rctx, ss) })
t.Run("GroupCount", func(t *testing.T) { groupTestGroupCount(t, rctx, ss) })
@@ -4739,7 +4739,7 @@ func groupTestPermittedSyncableAdminsChannel(t *testing.T, rctx request.CTX, ss
require.ElementsMatch(t, []string{user3.Id}, actualUserIDs)
}
-func groupTestpUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.Store) {
+func groupTestUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.Store) {
team := &model.Team{
DisplayName: "Name",
Description: "Some description",
@@ -4759,6 +4759,7 @@ func groupTestpUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.St
}
user1, err = ss.User().Save(rctx, user1)
require.NoError(t, err)
+ t.Log("Created user1", user1.Id)
user2 := &model.User{
Email: MakeEmail(),
@@ -4766,6 +4767,7 @@ func groupTestpUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.St
}
user2, err = ss.User().Save(rctx, user2)
require.NoError(t, err)
+ t.Log("Created user2", user2.Id)
user3 := &model.User{
Email: MakeEmail(),
@@ -4773,6 +4775,7 @@ func groupTestpUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.St
}
user3, err = ss.User().Save(rctx, user3)
require.NoError(t, err)
+ t.Log("Created user3", user3.Id)
user4 := &model.User{
Email: MakeEmail(),
@@ -4780,6 +4783,7 @@ func groupTestpUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.St
}
user4, err = ss.User().Save(rctx, user4)
require.NoError(t, err)
+ t.Log("Created user4", user4.Id)
for _, user := range []*model.User{user1, user2, user3} {
_, nErr := ss.Team().SaveMember(rctx, &model.TeamMember{TeamId: team.Id, UserId: user.Id}, 9999)
@@ -4790,53 +4794,73 @@ func groupTestpUpdateMembersRoleTeam(t *testing.T, rctx request.CTX, ss store.St
require.NoError(t, nErr)
tests := []struct {
- testName string
- inUserIDs []string
- targetSchemeAdminValue bool
+ testName string
+ newAdmins []string
+ expectedUpdatedUsers []string
}{
{
- "Given users are admins",
+ "Two new admins",
+ []string{user1.Id, user2.Id},
[]string{user1.Id, user2.Id},
- true,
},
{
- "Given users are members",
+ "Demote one admin",
+ []string{user1.Id},
[]string{user2.Id},
- false,
},
{
- "Non-given users are admins",
- []string{user2.Id},
- false,
+ "Operation is idempotent",
+ []string{user1.Id},
+ nil,
},
{
- "Non-given users are members",
- []string{user2.Id},
- false,
+ "Promote a team member",
+ []string{user1.Id, user3.Id},
+ []string{user3.Id},
+ },
+ {
+ "Guests never get promoted",
+ []string{user1.Id, user3.Id, user4.Id},
+ nil,
},
}
for _, tt := range tests {
t.Run(tt.testName, func(t *testing.T) {
- err = ss.Team().UpdateMembersRole(team.Id, tt.inUserIDs)
+ var updatedMembers []*model.TeamMember
+ updatedMembers, err = ss.Team().UpdateMembersRole(team.Id, tt.newAdmins)
require.NoError(t, err)
- members, err := ss.Team().GetMembers(team.Id, 0, 100, nil)
- require.NoError(t, err)
- require.GreaterOrEqual(t, len(members), 4) // sanity check for team membership
+ var updatedUserIDs []string
+ for _, member := range updatedMembers {
+ assert.False(t, member.SchemeGuest, fmt.Sprintf("userID: %s", member.UserId))
- for _, member := range members {
- if slices.Contains(tt.inUserIDs, member.UserId) {
- require.True(t, member.SchemeAdmin)
+ if slices.Contains(tt.newAdmins, member.UserId) {
+ assert.True(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId))
} else {
- require.False(t, member.SchemeAdmin)
+ assert.False(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId))
}
+ updatedUserIDs = append(updatedUserIDs, member.UserId)
+ }
+ assert.ElementsMatch(t, tt.expectedUpdatedUsers, updatedUserIDs)
+
+ members, err := ss.Team().GetMembers(team.Id, 0, 100, nil)
+ require.NoError(t, err)
+ assert.GreaterOrEqual(t, len(members), 4) // sanity check for team membership
+
+ for _, member := range members {
// Ensure guest account never changes.
if member.UserId == user4.Id {
- require.False(t, member.SchemeUser)
- require.False(t, member.SchemeAdmin)
- require.True(t, member.SchemeGuest)
+ assert.False(t, member.SchemeUser, fmt.Sprintf("userID: %s", member.UserId))
+ assert.False(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId))
+ assert.True(t, member.SchemeGuest, fmt.Sprintf("userID: %s", member.UserId))
+ } else {
+ if slices.Contains(tt.newAdmins, member.UserId) {
+ assert.True(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId))
+ } else {
+ assert.False(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId))
+ }
}
}
})
@@ -4859,6 +4883,7 @@ func groupTestpUpdateMembersRoleChannel(t *testing.T, rctx request.CTX, ss store
}
user1, err = ss.User().Save(rctx, user1)
require.NoError(t, err)
+ t.Log("Created user1", user1.Id)
user2 := &model.User{
Email: MakeEmail(),
@@ -4866,6 +4891,7 @@ func groupTestpUpdateMembersRoleChannel(t *testing.T, rctx request.CTX, ss store
}
user2, err = ss.User().Save(rctx, user2)
require.NoError(t, err)
+ t.Log("Created user2", user2.Id)
user3 := &model.User{
Email: MakeEmail(),
@@ -4873,6 +4899,7 @@ func groupTestpUpdateMembersRoleChannel(t *testing.T, rctx request.CTX, ss store
}
user3, err = ss.User().Save(rctx, user3)
require.NoError(t, err)
+ t.Log("Created user3", user3.Id)
user4 := &model.User{
Email: MakeEmail(),
@@ -4880,6 +4907,7 @@ func groupTestpUpdateMembersRoleChannel(t *testing.T, rctx request.CTX, ss store
}
user4, err = ss.User().Save(rctx, user4)
require.NoError(t, err)
+ t.Log("Created user4", user4.Id)
for _, user := range []*model.User{user1, user2, user3} {
_, err = ss.Channel().SaveMember(rctx, &model.ChannelMember{
@@ -4899,54 +4927,73 @@ func groupTestpUpdateMembersRoleChannel(t *testing.T, rctx request.CTX, ss store
require.NoError(t, err)
tests := []struct {
- testName string
- inUserIDs []string
- targetSchemeAdminValue bool
+ testName string
+ newAdmins []string
+ expectedUpdatedUsers []string
}{
{
- "Given users are admins",
+ "Two new admins",
+ []string{user1.Id, user2.Id},
[]string{user1.Id, user2.Id},
- true,
},
{
- "Given users are members",
+ "Demote one admin",
+ []string{user1.Id},
[]string{user2.Id},
- false,
},
{
- "Non-given users are admins",
- []string{user2.Id},
- false,
+ "Operation is idempotent",
+ []string{user1.Id},
+ nil,
},
{
- "Non-given users are members",
- []string{user2.Id},
- false,
+ "Promote a team member",
+ []string{user1.Id, user3.Id},
+ []string{user3.Id},
+ },
+ {
+ "Guests never get promoted",
+ []string{user1.Id, user3.Id, user4.Id},
+ nil,
},
}
for _, tt := range tests {
t.Run(tt.testName, func(t *testing.T) {
- err = ss.Channel().UpdateMembersRole(channel.Id, tt.inUserIDs)
+ var updatedMemmbers []*model.ChannelMember
+ updatedMemmbers, err = ss.Channel().UpdateMembersRole(channel.Id, tt.newAdmins)
require.NoError(t, err)
- members, err := ss.Channel().GetMembers(channel.Id, 0, 100)
- require.NoError(t, err)
+ var updatedUserIDs []string
+ for _, member := range updatedMemmbers {
+ assert.False(t, member.SchemeGuest, fmt.Sprintf("userID: %s", member.UserId))
- require.GreaterOrEqual(t, len(members), 4) // sanity check for channel membership
-
- for _, member := range members {
- if slices.Contains(tt.inUserIDs, member.UserId) {
- require.True(t, member.SchemeAdmin)
+ if slices.Contains(tt.newAdmins, member.UserId) {
+ assert.True(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId))
} else {
- require.False(t, member.SchemeAdmin)
+ assert.False(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId))
}
+ updatedUserIDs = append(updatedUserIDs, member.UserId)
+ }
+ assert.ElementsMatch(t, tt.expectedUpdatedUsers, updatedUserIDs)
+
+ members, err := ss.Channel().GetMembers(channel.Id, 0, 100)
+ require.NoError(t, err)
+ assert.GreaterOrEqual(t, len(members), 4) // sanity check for channel membership
+
+ for _, member := range members {
// Ensure guest account never changes.
if member.UserId == user4.Id {
- require.False(t, member.SchemeUser)
- require.False(t, member.SchemeAdmin)
- require.True(t, member.SchemeGuest)
+ assert.False(t, member.SchemeUser, fmt.Sprintf("userID: %s", member.UserId))
+ assert.False(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId))
+ assert.True(t, member.SchemeGuest, fmt.Sprintf("userID: %s", member.UserId))
+ } else {
+ if slices.Contains(tt.newAdmins, member.UserId) {
+ assert.True(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId))
+ } else {
+ assert.False(t, member.SchemeAdmin, fmt.Sprintf("userID: %s", member.UserId))
+ }
}
}
})
diff --git a/server/channels/store/storetest/mocks/ChannelStore.go b/server/channels/store/storetest/mocks/ChannelStore.go
index 2661df49ce81a..489151e9cbe75 100644
--- a/server/channels/store/storetest/mocks/ChannelStore.go
+++ b/server/channels/store/storetest/mocks/ChannelStore.go
@@ -2915,21 +2915,33 @@ func (_m *ChannelStore) UpdateMemberNotifyProps(channelID string, userID string,
}
// UpdateMembersRole provides a mock function with given fields: channelID, userIDs
-func (_m *ChannelStore) UpdateMembersRole(channelID string, userIDs []string) error {
+func (_m *ChannelStore) UpdateMembersRole(channelID string, userIDs []string) ([]*model.ChannelMember, error) {
ret := _m.Called(channelID, userIDs)
if len(ret) == 0 {
panic("no return value specified for UpdateMembersRole")
}
- var r0 error
- if rf, ok := ret.Get(0).(func(string, []string) error); ok {
+ var r0 []*model.ChannelMember
+ var r1 error
+ if rf, ok := ret.Get(0).(func(string, []string) ([]*model.ChannelMember, error)); ok {
+ return rf(channelID, userIDs)
+ }
+ if rf, ok := ret.Get(0).(func(string, []string) []*model.ChannelMember); ok {
r0 = rf(channelID, userIDs)
} else {
- r0 = ret.Error(0)
+ if ret.Get(0) != nil {
+ r0 = ret.Get(0).([]*model.ChannelMember)
+ }
}
- return r0
+ if rf, ok := ret.Get(1).(func(string, []string) error); ok {
+ r1 = rf(channelID, userIDs)
+ } else {
+ r1 = ret.Error(1)
+ }
+
+ return r0, r1
}
// UpdateMultipleMembers provides a mock function with given fields: members
diff --git a/server/channels/store/storetest/mocks/TeamStore.go b/server/channels/store/storetest/mocks/TeamStore.go
index 2c23ba3077fab..60493056c62a7 100644
--- a/server/channels/store/storetest/mocks/TeamStore.go
+++ b/server/channels/store/storetest/mocks/TeamStore.go
@@ -1336,22 +1336,34 @@ func (_m *TeamStore) UpdateMember(rctx request.CTX, member *model.TeamMember) (*
return r0, r1
}
-// UpdateMembersRole provides a mock function with given fields: teamID, userIDs
-func (_m *TeamStore) UpdateMembersRole(teamID string, userIDs []string) error {
- ret := _m.Called(teamID, userIDs)
+// UpdateMembersRole provides a mock function with given fields: teamID, adminIDs
+func (_m *TeamStore) UpdateMembersRole(teamID string, adminIDs []string) ([]*model.TeamMember, error) {
+ ret := _m.Called(teamID, adminIDs)
if len(ret) == 0 {
panic("no return value specified for UpdateMembersRole")
}
- var r0 error
- if rf, ok := ret.Get(0).(func(string, []string) error); ok {
- r0 = rf(teamID, userIDs)
+ var r0 []*model.TeamMember
+ var r1 error
+ if rf, ok := ret.Get(0).(func(string, []string) ([]*model.TeamMember, error)); ok {
+ return rf(teamID, adminIDs)
+ }
+ if rf, ok := ret.Get(0).(func(string, []string) []*model.TeamMember); ok {
+ r0 = rf(teamID, adminIDs)
} else {
- r0 = ret.Error(0)
+ if ret.Get(0) != nil {
+ r0 = ret.Get(0).([]*model.TeamMember)
+ }
}
- return r0
+ if rf, ok := ret.Get(1).(func(string, []string) error); ok {
+ r1 = rf(teamID, adminIDs)
+ } else {
+ r1 = ret.Error(1)
+ }
+
+ return r0, r1
}
// UpdateMultipleMembers provides a mock function with given fields: members
diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go
index ef4bf76962019..61dde64b5bc67 100644
--- a/server/channels/store/timerlayer/timerlayer.go
+++ b/server/channels/store/timerlayer/timerlayer.go
@@ -2366,10 +2366,10 @@ func (s *TimerLayerChannelStore) UpdateMemberNotifyProps(channelID string, userI
return result, err
}
-func (s *TimerLayerChannelStore) UpdateMembersRole(channelID string, userIDs []string) error {
+func (s *TimerLayerChannelStore) UpdateMembersRole(channelID string, userIDs []string) ([]*model.ChannelMember, error) {
start := time.Now()
- err := s.ChannelStore.UpdateMembersRole(channelID, userIDs)
+ result, err := s.ChannelStore.UpdateMembersRole(channelID, userIDs)
elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil {
@@ -2379,7 +2379,7 @@ func (s *TimerLayerChannelStore) UpdateMembersRole(channelID string, userIDs []s
}
s.Root.Metrics.ObserveStoreMethodDuration("ChannelStore.UpdateMembersRole", success, elapsed)
}
- return err
+ return result, err
}
func (s *TimerLayerChannelStore) UpdateMultipleMembers(members []*model.ChannelMember) ([]*model.ChannelMember, error) {
@@ -9495,10 +9495,10 @@ func (s *TimerLayerTeamStore) UpdateMember(rctx request.CTX, member *model.TeamM
return result, err
}
-func (s *TimerLayerTeamStore) UpdateMembersRole(teamID string, userIDs []string) error {
+func (s *TimerLayerTeamStore) UpdateMembersRole(teamID string, adminIDs []string) ([]*model.TeamMember, error) {
start := time.Now()
- err := s.TeamStore.UpdateMembersRole(teamID, userIDs)
+ result, err := s.TeamStore.UpdateMembersRole(teamID, adminIDs)
elapsed := float64(time.Since(start)) / float64(time.Second)
if s.Root.Metrics != nil {
@@ -9508,7 +9508,7 @@ func (s *TimerLayerTeamStore) UpdateMembersRole(teamID string, userIDs []string)
}
s.Root.Metrics.ObserveStoreMethodDuration("TeamStore.UpdateMembersRole", success, elapsed)
}
- return err
+ return result, err
}
func (s *TimerLayerTeamStore) UpdateMultipleMembers(members []*model.TeamMember) ([]*model.TeamMember, error) {
diff --git a/webapp/channels/src/components/channel_select/index.ts b/webapp/channels/src/components/channel_select/index.ts
index 08d71804dd207..faac736179ade 100644
--- a/webapp/channels/src/components/channel_select/index.ts
+++ b/webapp/channels/src/components/channel_select/index.ts
@@ -17,7 +17,8 @@ const getMyChannelsSorted = createSelector(
getMyChannels,
getCurrentUserLocale,
(channels, locale) => {
- return [...channels].sort(sortChannelsByTypeAndDisplayName.bind(null, locale));
+ const activeChannels = channels.filter((channel) => channel.delete_at === 0);
+ return [...activeChannels].sort(sortChannelsByTypeAndDisplayName.bind(null, locale));
},
);
diff --git a/webapp/channels/src/components/root/root_redirect/index.ts b/webapp/channels/src/components/root/root_redirect/index.ts
index d361a82c98ec2..1309717653c09 100644
--- a/webapp/channels/src/components/root/root_redirect/index.ts
+++ b/webapp/channels/src/components/root/root_redirect/index.ts
@@ -7,6 +7,7 @@ import type {Dispatch} from 'redux';
import {getFirstAdminSetupComplete} from 'mattermost-redux/actions/general';
import {getIsOnboardingFlowEnabled} from 'mattermost-redux/selectors/entities/preferences';
+import {getActiveTeamsList} from 'mattermost-redux/selectors/entities/teams';
import {getCurrentUserId, isCurrentUserSystemAdmin, isFirstAdmin} from 'mattermost-redux/selectors/entities/users';
import type {GlobalState} from 'types/store';
@@ -15,6 +16,7 @@ import RootRedirect from './root_redirect';
function mapStateToProps(state: GlobalState) {
const onboardingFlowEnabled = getIsOnboardingFlowEnabled(state);
+ const teams = getActiveTeamsList(state);
let isElegibleForFirstAdmingOnboarding = onboardingFlowEnabled;
if (isElegibleForFirstAdmingOnboarding) {
isElegibleForFirstAdmingOnboarding = isCurrentUserSystemAdmin(state);
@@ -23,6 +25,7 @@ function mapStateToProps(state: GlobalState) {
currentUserId: getCurrentUserId(state),
isElegibleForFirstAdmingOnboarding,
isFirstAdmin: isFirstAdmin(state),
+ areThereTeams: Boolean(teams.length),
};
}
diff --git a/webapp/channels/src/components/root/root_redirect/root_redirect.test.tsx b/webapp/channels/src/components/root/root_redirect/root_redirect.test.tsx
new file mode 100644
index 0000000000000..8bb5278ec9647
--- /dev/null
+++ b/webapp/channels/src/components/root/root_redirect/root_redirect.test.tsx
@@ -0,0 +1,173 @@
+// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
+// See LICENSE.txt for license information.
+
+import {createMemoryHistory} from 'history';
+import React from 'react';
+import type {RouteComponentProps} from 'react-router-dom';
+import {Redirect} from 'react-router-dom';
+
+import {getFirstAdminSetupComplete as getFirstAdminSetupCompleteAction} from 'mattermost-redux/actions/general';
+
+import * as GlobalActions from 'actions/global_actions';
+
+import {renderWithContext, waitFor} from 'tests/react_testing_utils';
+
+import RootRedirect from './root_redirect';
+import type {Props} from './root_redirect';
+
+jest.mock('actions/global_actions', () => ({
+ redirectUserToDefaultTeam: jest.fn(),
+}));
+
+jest.mock('mattermost-redux/actions/general', () => ({
+ getFirstAdminSetupComplete: jest.fn(() =>
+ Promise.resolve({
+ data: true,
+ }),
+ ),
+}));
+
+jest.mock('react-router-dom', () => {
+ const actual = jest.requireActual('react-router-dom');
+ return {
+ ...actual,
+ Redirect: jest.fn(() => null),
+ };
+});
+
+describe('components/RootRedirect', () => {
+ const baseProps: Props = {
+ currentUserId: '',
+ isElegibleForFirstAdmingOnboarding: false,
+ isFirstAdmin: false,
+ areThereTeams: false,
+ actions: {
+ getFirstAdminSetupComplete: getFirstAdminSetupCompleteAction as jest.Mock,
+ },
+ };
+
+ const defaultProps = {
+ ...baseProps,
+ location: {
+ pathname: '/',
+ },
+ } as Props & RouteComponentProps;
+
+ afterEach(() => {
+ jest.clearAllMocks();
+ });
+
+ test('should redirect to /login when currentUserId is empty', () => {
+ renderWithContext();
+
+ expect(Redirect).toHaveBeenCalledTimes(1);
+ expect(Redirect).toHaveBeenCalledWith(
+ expect.objectContaining({
+ to: expect.objectContaining({
+ pathname: '/login',
+ }),
+ }),
+ {},
+ );
+ });
+
+ test('should call GlobalActions.redirectUserToDefaultTeam when user is logged in and not eligible for first admin onboarding', () => {
+ const props = {
+ ...defaultProps,
+ currentUserId: 'test-user-id',
+ isElegibleForFirstAdmingOnboarding: false,
+ };
+
+ renderWithContext();
+
+ expect(GlobalActions.redirectUserToDefaultTeam).toHaveBeenCalledTimes(1);
+ });
+
+ test('should redirect to preparing-workspace when eligible for first admin onboarding and no teams created', async () => {
+ const history = createMemoryHistory({initialEntries: ['/']});
+ const mockHistoryPush = jest.spyOn(history, 'push');
+
+ const props = {
+ currentUserId: 'test-user-id',
+ isElegibleForFirstAdmingOnboarding: true,
+ isFirstAdmin: true,
+ areThereTeams: false,
+ actions: {
+ getFirstAdminSetupComplete: jest.fn().mockResolvedValue({data: false}),
+ },
+ };
+
+ renderWithContext(, {}, {history});
+
+ expect(props.actions.getFirstAdminSetupComplete).toHaveBeenCalledTimes(1);
+
+ await waitFor(() => {
+ expect(mockHistoryPush).toHaveBeenCalledWith('/preparing-workspace');
+ });
+ });
+
+ test('should NOT redirect to preparing-workspace when there are teams created, even if system value for first admin onboarding complete is false', async () => {
+ const history = createMemoryHistory({initialEntries: ['/']});
+
+ const props = {
+ ...defaultProps,
+ currentUserId: 'test-user-id',
+ isElegibleForFirstAdmingOnboarding: true,
+ isFirstAdmin: true,
+ areThereTeams: true,
+ actions: {
+ getFirstAdminSetupComplete: jest.fn().mockResolvedValue({data: false}),
+ },
+ };
+
+ renderWithContext(, {}, {history});
+
+ expect(props.actions.getFirstAdminSetupComplete).toHaveBeenCalledTimes(1);
+
+ await waitFor(() => {
+ expect(GlobalActions.redirectUserToDefaultTeam).toHaveBeenCalledTimes(1);
+ });
+ });
+
+ test('should redirect to default team when first admin setup is complete', async () => {
+ const props = {
+ ...defaultProps,
+ currentUserId: 'test-user-id',
+ isElegibleForFirstAdmingOnboarding: true,
+ isFirstAdmin: true,
+ areThereTeams: false,
+ actions: {
+ getFirstAdminSetupComplete: jest.fn().mockResolvedValue({data: true}),
+ },
+ };
+
+ renderWithContext();
+
+ expect(props.actions.getFirstAdminSetupComplete).toHaveBeenCalledTimes(1);
+
+ await waitFor(() => {
+ expect(GlobalActions.redirectUserToDefaultTeam).toHaveBeenCalledTimes(1);
+ });
+ });
+
+ test('should redirect to default team when not first admin or teams exist', async () => {
+ const props = {
+ ...defaultProps,
+ currentUserId: 'test-user-id',
+ isElegibleForFirstAdmingOnboarding: true,
+ isFirstAdmin: false,
+ areThereTeams: true,
+ actions: {
+ getFirstAdminSetupComplete: jest.fn().mockResolvedValue({data: false}),
+ },
+ };
+
+ renderWithContext();
+
+ expect(props.actions.getFirstAdminSetupComplete).toHaveBeenCalledTimes(1);
+
+ await waitFor(() => {
+ expect(GlobalActions.redirectUserToDefaultTeam).toHaveBeenCalledTimes(1);
+ });
+ });
+});
diff --git a/webapp/channels/src/components/root/root_redirect/root_redirect.tsx b/webapp/channels/src/components/root/root_redirect/root_redirect.tsx
index c528a91c3aeb9..3a1db8624e92c 100644
--- a/webapp/channels/src/components/root/root_redirect/root_redirect.tsx
+++ b/webapp/channels/src/components/root/root_redirect/root_redirect.tsx
@@ -13,6 +13,7 @@ export type Props = {
currentUserId: string;
location?: Location;
isFirstAdmin: boolean;
+ areThereTeams: boolean;
actions: {
getFirstAdminSetupComplete: () => Promise;
};
@@ -26,7 +27,7 @@ export default function RootRedirect(props: Props) {
if (props.isElegibleForFirstAdmingOnboarding) {
props.actions.getFirstAdminSetupComplete().then((firstAdminCompletedSignup) => {
// root.tsx ensures admin profiles are eventually loaded
- if (firstAdminCompletedSignup.data === false && props.isFirstAdmin) {
+ if (firstAdminCompletedSignup.data === false && props.isFirstAdmin && !props.areThereTeams) {
history.push('/preparing-workspace');
} else {
GlobalActions.redirectUserToDefaultTeam();