Skip to content

Commit

Permalink
MM-57507 Improve OAuth Flow (#593)
Browse files Browse the repository at this point in the history
* update OAuth flow

* update for new flow

* MM-5701: reenable reattached plugin tests (#563)

* update OAuth flow

* update connection url

* fix bad merge

* remove commented code

* fix tests, cleanup

* fix test

* lint fixes

* Update command_test.go

* update to welcome message, add message when changing primary

* update test

* update more tests

* revert test and fix another

* Revert "update more tests"

This reverts commit 156d4a9.

* Revert "update test"

This reverts commit 231402e.

* Reimplement test changesest""

This reverts commit b07061c.

* another test fix

* ensure teams user id unique

---------

Co-authored-by: Jesse Hallam <[email protected]>
Co-authored-by: Miguel de la Cruz <[email protected]>
  • Loading branch information
3 people authored Apr 17, 2024
1 parent ff19941 commit f58b482
Show file tree
Hide file tree
Showing 13 changed files with 295 additions and 129 deletions.
77 changes: 50 additions & 27 deletions server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ const (
QueryParamPage = "page"
QueryParamPerPage = "per_page"
QueryParamPrimaryPlatform = "primary_platform"
QueryParamChannelID = "channel_id"
QueryParamPostID = "post_id"

APIChoosePrimaryPlatform = "/choose-primary-platform"
)
Expand Down Expand Up @@ -84,6 +86,7 @@ func NewAPI(p *Plugin, store store.Store) *API {
router.HandleFunc("/notify-connect", api.notifyConnect).Methods("GET")
router.HandleFunc(APIChoosePrimaryPlatform, api.choosePrimaryPlatform).Methods(http.MethodGet)
router.HandleFunc("/stats/site", api.siteStats).Methods("GET")
router.HandleFunc("/primary-platform", api.primaryPlatform).Methods("GET")

// iFrame support
router.HandleFunc("/iframe/mattermostTab", api.iFrame).Methods("GET")
Expand Down Expand Up @@ -332,8 +335,9 @@ func (a *API) connect(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodOptions {
return
}
query := r.URL.Query()
userID := r.Header.Get("Mattermost-User-ID")
connectBot := r.URL.Query().Has("isBot")
connectBot := query.Has("isBot")
if connectBot {
if !a.p.API.HasPermissionTo(userID, model.PermissionManageSystem) {
a.p.API.LogWarn("Attempt to connect the bot account, by non system admin.", "user_id", userID)
Expand All @@ -343,13 +347,20 @@ func (a *API) connect(w http.ResponseWriter, r *http.Request) {
userID = a.p.GetBotUserID()
}

channelID := query.Get(QueryParamChannelID)
postID := query.Get(QueryParamPostID)
if channelID == "" || postID == "" {
a.p.API.LogWarn("Missing channelID or postID from query paramaeters", "channelID", channelID, "postID", postID)
http.Error(w, "Missing required query parameters.", http.StatusBadRequest)
}

if storedToken, _ := a.p.store.GetTokenForMattermostUser(userID); storedToken != nil {
a.p.API.LogWarn("The account is already connected to MS Teams", "user_id", userID)
http.Error(w, "Error in trying to connect the account, please try again.", http.StatusInternalServerError)
return
}

state := fmt.Sprintf("%s_%s", model.NewId(), userID)
state := fmt.Sprintf("%s_%s_%s_%s", model.NewId(), userID, postID, channelID)
if err := a.store.StoreOAuth2State(state); err != nil {
a.p.API.LogWarn("Error in storing the OAuth state", "error", err.Error())
http.Error(w, "Error in trying to connect the account, please try again.", http.StatusInternalServerError)
Expand All @@ -368,6 +379,34 @@ func (a *API) connect(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, connectURL, http.StatusSeeOther)
}

func (a *API) primaryPlatform(w http.ResponseWriter, r *http.Request) {
bundlePath, err := a.p.API.GetBundlePath()
if err != nil {
a.p.API.LogWarn("Failed to get bundle path.", "error", err.Error())
return
}

t, err := template.ParseFiles(filepath.Join(bundlePath, "assets/info-page/index.html"))
if err != nil {
a.p.API.LogError("unable to parse the template", "error", err.Error())
http.Error(w, "unable to view the primary platform selection page", http.StatusInternalServerError)
}

err = t.Execute(w, struct {
ServerURL string
APIEndPoint string
QueryParamPrimaryPlatform string
}{
ServerURL: a.p.GetURL(),
APIEndPoint: APIChoosePrimaryPlatform,
QueryParamPrimaryPlatform: QueryParamPrimaryPlatform,
})
if err != nil {
a.p.API.LogError("unable to execute the template", "error", err.Error())
http.Error(w, "unable to view the primary platform selection page", http.StatusInternalServerError)
}
}

func (a *API) notifyConnect(w http.ResponseWriter, r *http.Request) {
userID := r.Header.Get("Mattermost-User-ID")

Expand Down Expand Up @@ -405,7 +444,7 @@ func (a *API) oauthRedirectHandler(w http.ResponseWriter, r *http.Request) {
state := r.URL.Query().Get("state")

stateArr := strings.Split(state, "_")
if len(stateArr) != 2 {
if len(stateArr) != 4 {
http.Error(w, "Invalid state", http.StatusBadRequest)
return
}
Expand Down Expand Up @@ -524,31 +563,15 @@ func (a *API) oauthRedirectHandler(w http.ResponseWriter, r *http.Request) {

_, _ = a.p.updateAutomutingOnUserConnect(mmUserID)

bundlePath, err := a.p.API.GetBundlePath()
if err != nil {
a.p.API.LogWarn("Failed to get bundle path.", "error", err.Error())
return
}

t, err := template.ParseFiles(filepath.Join(bundlePath, "assets/info-page/index.html"))
if err != nil {
a.p.API.LogError("unable to parse the template", "error", err.Error())
http.Error(w, "unable to view the primary platform selection page", http.StatusInternalServerError)
}

err = t.Execute(w, struct {
ServerURL string
APIEndPoint string
QueryParamPrimaryPlatform string
}{
ServerURL: a.p.GetURL(),
APIEndPoint: APIChoosePrimaryPlatform,
QueryParamPrimaryPlatform: QueryParamPrimaryPlatform,
})
if err != nil {
a.p.API.LogError("unable to execute the template", "error", err.Error())
http.Error(w, "unable to view the primary platform selection page", http.StatusInternalServerError)
const userConnectedMessage = "Welcome to Mattermost for Microsoft Teams! Your conversations with MS Teams users are now synchronized."
post := &model.Post{
Id: stateArr[2],
Message: userConnectedMessage,
ChannelId: stateArr[3],
}
_ = a.p.GetAPI().UpdateEphemeralPost(mmUser.Id, post)
connectURL := a.p.GetURL() + "/primary-platform"
http.Redirect(w, r, connectURL, http.StatusSeeOther)
}

func (a *API) getConnectedUsers(w http.ResponseWriter, r *http.Request) {
Expand Down
5 changes: 3 additions & 2 deletions server/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -799,10 +799,11 @@ func TestConnect(t *testing.T) {
test.SetupStore(plugin.store.(*storemocks.Store))

w := httptest.NewRecorder()
endPoint := "/connect"
endPoint := "/connect?"
if test.isBot {
endPoint += "?isBot"
endPoint += "isBot&"
}
endPoint += "channel_id=123&post_id=456"
r := httptest.NewRequest(http.MethodGet, endPoint, nil)
r.Header.Add("Mattermost-User-Id", testutils.GetUserID())
plugin.ServeHTTP(nil, w, r)
Expand Down
4 changes: 4 additions & 0 deletions server/automute_preferences.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ func (p *Plugin) updateAutomutingOnPreferencesChanged(_ *plugin.Context, prefere
continue
}

p.notifyUserTeamsPrimary(userID)

if _, err := p.enableAutomute(userID); err != nil {
p.API.LogWarn(
"Unable to mute channels for a user who set their primary platform to Teams",
Expand All @@ -33,6 +35,8 @@ func (p *Plugin) updateAutomutingOnPreferencesChanged(_ *plugin.Context, prefere
}

for _, userID := range userIDsToDisable {
p.notifyUserMattermostPrimary(userID)

_, err := p.disableAutomute(userID)
if err != nil {
p.API.LogWarn(
Expand Down
75 changes: 45 additions & 30 deletions server/automute_preferences_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,52 @@ package main

import (
"testing"
"time"

"github.com/mattermost/mattermost-plugin-msteams/server/store/storemodels"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/plugin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

func TestUpdateAutomutingOnPreferencesChanged(t *testing.T) {
th := setupTestHelper(t)

team := th.SetupTeam(t)

setup := func(t *testing.T) (*Plugin, *model.User, *model.Channel, *model.Channel, *model.Channel) {
t.Helper()
th.Reset(t)

p := newAutomuteTestPlugin(t)
user := th.SetupUser(t, team)
err := th.p.store.SetUserInfo(user.Id, "team_user_id", &oauth2.Token{AccessToken: "token", Expiry: time.Now().Add(10 * time.Minute)})
require.NoError(t, err)

user := &model.User{Id: model.NewId()}
mockUserConnected(p, user.Id)
linkedChannel := th.SetupPublicChannel(t, team, WithMembers(user))

linkedChannel, appErr := p.API.CreateChannel(&model.Channel{Id: model.NewId(), Type: model.ChannelTypeOpen})
require.Nil(t, appErr)
_, appErr = p.API.AddUserToChannel(linkedChannel.Id, user.Id, "")
require.Nil(t, appErr)
mockLinkedChannel(p, linkedChannel)
channelLink := storemodels.ChannelLink{
MattermostTeamID: team.Id,
MattermostChannelID: linkedChannel.Id,
MSTeamsTeam: model.NewId(),
MSTeamsChannel: model.NewId(),
Creator: user.Id,
}
err = th.p.store.StoreChannelLink(&channelLink)
require.NoError(t, err)

unlinkedChannel, appErr := p.API.CreateChannel(&model.Channel{Id: model.NewId(), Type: model.ChannelTypeOpen})
require.Nil(t, appErr)
_, appErr = p.API.AddUserToChannel(unlinkedChannel.Id, user.Id, "")
require.Nil(t, appErr)
mockUnlinkedChannel(p, unlinkedChannel)
unlinkedChannel := th.SetupPublicChannel(t, team, WithMembers(user))

dmChannel, appErr := p.API.GetDirectChannel(user.Id, model.NewId())
otherUser := th.SetupUser(t, team)
dmChannel, appErr := th.p.API.GetDirectChannel(user.Id, otherUser.Id)
require.Nil(t, appErr)

assertChannelNotAutomuted(t, p, linkedChannel.Id, user.Id)
assertChannelNotAutomuted(t, p, unlinkedChannel.Id, user.Id)
assertChannelNotAutomuted(t, p, dmChannel.Id, user.Id)
assertChannelNotAutomuted(t, th.p, linkedChannel.Id, user.Id)
assertChannelNotAutomuted(t, th.p, unlinkedChannel.Id, user.Id)
assertChannelNotAutomuted(t, th.p, dmChannel.Id, user.Id)

return p, user, linkedChannel, unlinkedChannel, dmChannel
return th.p, user, linkedChannel, unlinkedChannel, dmChannel
}

t.Run("should mute linked channels when their primary platform changes from MM to MS Teams", func(t *testing.T) {
Expand All @@ -57,6 +67,7 @@ func TestUpdateAutomutingOnPreferencesChanged(t *testing.T) {
assertChannelNotAutomuted(t, p, linkedChannel.Id, user.Id)
assertChannelNotAutomuted(t, p, unlinkedChannel.Id, user.Id)
assertChannelNotAutomuted(t, p, dmChannel.Id, user.Id)
th.assertDMFromUser(t, p.userID, user.Id, userChoseMattermostPrimaryMessage)

p.PreferencesHaveChanged(&plugin.Context{}, []model.Preference{
{
Expand All @@ -72,6 +83,7 @@ func TestUpdateAutomutingOnPreferencesChanged(t *testing.T) {
assertChannelAutomuted(t, p, linkedChannel.Id, user.Id)
assertChannelNotAutomuted(t, p, unlinkedChannel.Id, user.Id)
assertChannelAutomuted(t, p, dmChannel.Id, user.Id)
th.assertDMFromUser(t, p.userID, user.Id, userChoseTeamsPrimaryMessage)
})

t.Run("should unmute linked channels when their primary platform changes from MS Teams to MM", func(t *testing.T) {
Expand All @@ -91,6 +103,7 @@ func TestUpdateAutomutingOnPreferencesChanged(t *testing.T) {
assertChannelAutomuted(t, p, linkedChannel.Id, user.Id)
assertChannelNotAutomuted(t, p, unlinkedChannel.Id, user.Id)
assertChannelAutomuted(t, p, dmChannel.Id, user.Id)
th.assertDMFromUser(t, p.userID, user.Id, userChoseTeamsPrimaryMessage)

p.PreferencesHaveChanged(&plugin.Context{}, []model.Preference{
{
Expand All @@ -106,6 +119,7 @@ func TestUpdateAutomutingOnPreferencesChanged(t *testing.T) {
assertChannelNotAutomuted(t, p, linkedChannel.Id, user.Id)
assertChannelNotAutomuted(t, p, unlinkedChannel.Id, user.Id)
assertChannelNotAutomuted(t, p, dmChannel.Id, user.Id)
th.assertDMFromUser(t, p.userID, user.Id, userChoseMattermostPrimaryMessage)
})

t.Run("should do nothing when unrelated preferences change", func(t *testing.T) {
Expand All @@ -119,20 +133,21 @@ func TestUpdateAutomutingOnPreferencesChanged(t *testing.T) {
Value: "full",
},
})
th.assertNoDMFromUser(t, p.userID, user.Id)
})

t.Run("should do nothing when an unconnected user turns on automuting", func(t *testing.T) {
p, _, linkedChannel, unlinkedChannel, _ := setup(t)

unconnectedUser := &model.User{Id: model.NewId()}
mockUserNotConnected(p, unconnectedUser.Id)
unconnectedUser := th.SetupUser(t, team)
otherUser := th.SetupUser(t, team)

_, appErr := p.API.AddUserToChannel(linkedChannel.Id, unconnectedUser.Id, "")
require.Nil(t, appErr)
_, appErr = p.API.AddUserToChannel(unlinkedChannel.Id, unconnectedUser.Id, "")
require.Nil(t, appErr)

dmChannel, appErr := p.API.GetDirectChannel(unconnectedUser.Id, model.NewId())
dmChannel, appErr := p.API.GetDirectChannel(unconnectedUser.Id, otherUser.Id)
require.Nil(t, appErr)

p.PreferencesHaveChanged(&plugin.Context{}, []model.Preference{
Expand All @@ -149,21 +164,21 @@ func TestUpdateAutomutingOnPreferencesChanged(t *testing.T) {
assertChannelNotAutomuted(t, p, linkedChannel.Id, unconnectedUser.Id)
assertChannelNotAutomuted(t, p, unlinkedChannel.Id, unconnectedUser.Id)
assertChannelNotAutomuted(t, p, dmChannel.Id, unconnectedUser.Id)
th.assertNoDMFromUser(t, p.userID, unconnectedUser.Id)
})

t.Run("should not affect other users when a connected user turns on automuting", func(t *testing.T) {
p, user, linkedChannel, unlinkedChannel, _ := setup(t)

connectedUser := &model.User{Id: model.NewId()}
mockUserConnected(p, connectedUser.Id)
connectedUser := th.SetupUser(t, team)
th.ConnectUser(t, connectedUser.Id, "t"+connectedUser.Id)

_, appErr := p.API.AddUserToChannel(linkedChannel.Id, connectedUser.Id, "")
require.Nil(t, appErr)
_, appErr = p.API.AddUserToChannel(unlinkedChannel.Id, connectedUser.Id, "")
require.Nil(t, appErr)

unconnectedUser := &model.User{Id: model.NewId()}
mockUserNotConnected(p, unconnectedUser.Id)
unconnectedUser := th.SetupUser(t, team)

_, appErr = p.API.AddUserToChannel(linkedChannel.Id, unconnectedUser.Id, "")
require.Nil(t, appErr)
Expand All @@ -180,6 +195,7 @@ func TestUpdateAutomutingOnPreferencesChanged(t *testing.T) {
})

assertUserHasAutomuteEnabled(t, p, user.Id)
th.assertDMFromUser(t, p.userID, user.Id, userChoseTeamsPrimaryMessage)

assertChannelAutomuted(t, p, linkedChannel.Id, user.Id)
assertChannelNotAutomuted(t, p, unlinkedChannel.Id, user.Id)
Expand All @@ -201,12 +217,9 @@ func TestUpdateAutomutingOnPreferencesChanged(t *testing.T) {
numChannels := 1000
channels := make([]*model.Channel, numChannels)
for i := 0; i < numChannels; i++ {
channel, appErr := p.API.CreateChannel(&model.Channel{Id: model.NewId(), Type: model.ChannelTypeOpen})
require.Nil(t, appErr)
_, appErr = p.API.AddUserToChannel(channel.Id, user.Id, "")
require.Nil(t, appErr)
channel := th.SetupPublicChannel(t, team, WithMembers(user))

mockLinkedChannel(p, channel)
th.LinkChannel(t, team, channel, user)

channels[i] = channel
}
Expand All @@ -225,6 +238,7 @@ func TestUpdateAutomutingOnPreferencesChanged(t *testing.T) {
})

assertUserHasAutomuteEnabled(t, p, user.Id)
th.assertDMFromUser(t, p.userID, user.Id, userChoseTeamsPrimaryMessage)

for _, channel := range channels {
assertChannelAutomuted(t, p, channel.Id, user.Id)
Expand All @@ -240,6 +254,7 @@ func TestUpdateAutomutingOnPreferencesChanged(t *testing.T) {
})

assertUserHasAutomuteDisabled(t, p, user.Id)
th.assertDMFromUser(t, p.userID, user.Id, userChoseMattermostPrimaryMessage)

for _, channel := range channels {
assertChannelNotAutomuted(t, p, channel.Id, user.Id)
Expand Down
Loading

0 comments on commit f58b482

Please sign in to comment.