Skip to content

Commit

Permalink
Merge pull request #66 from NorskHelsenett/fix/error--in-domain-provi…
Browse files Browse the repository at this point in the history
…ders

Fix/error  in domain providers
  • Loading branch information
havardelnan authored Apr 12, 2024
2 parents 589eaf6 + c88b054 commit 3800cb6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pkg/auth/userauth/activedirectory/ad.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func splitUserId(userId string) (string, string, error) {
if len(parts) != 2 {
return "", "", fmt.Errorf("invalid userId: %s", userId)
}
return parts[1], parts[0], nil
return parts[0], parts[1], nil
}

func checkUserAccountControl(userAccountControl string) error {
Expand Down
26 changes: 26 additions & 0 deletions pkg/auth/userauth/msgraph/msgraph.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package msgraph

import (
"context"
"fmt"
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/NorskHelsenett/ror/pkg/helpers/kvcachehelper"
Expand Down Expand Up @@ -90,6 +92,8 @@ func (g *MsGraphClient) GetUser(userId string) (*identitymodels.User, error) {
}
}

addDomainpartToGroups(&groupnames, userId)

ret = &identitymodels.User{
Email: *user.GetUserPrincipalName(),
Name: *user.GetDisplayName(),
Expand All @@ -100,6 +104,19 @@ func (g *MsGraphClient) GetUser(userId string) (*identitymodels.User, error) {
return ret, nil
}

func addDomainpartToGroups(groupnames *[]string, userId string) {

_, domain, err := splitUserId(userId)
if err != nil {
domain = ""
}

// TODO: Add check if domainpart is allready part of the group name
for i, group := range *groupnames {
(*groupnames)[i] = group + "@" + domain
}
}

// getUser gets a user from the graph api
func (g *MsGraphClient) getUser(userId string, userChan chan<- models.Userable, errorChan chan<- error) {
user, err := g.Client.Users().ByUserId(userId).Get(context.Background(), nil)
Expand Down Expand Up @@ -143,6 +160,7 @@ func (g *MsGraphClient) getGroupDisplayNames(groups []string, groupCache CacheIn
}

}

return groupNames, nil
}

Expand All @@ -163,3 +181,11 @@ func (g *MsGraphClient) getGroupDisplayName(groupId string, groupsNameChan chan<
groupCache.Add(groupId, *group.GetDisplayName())
groupsNameChan <- *group.GetDisplayName()
}

func splitUserId(userId string) (string, string, error) {
parts := strings.Split(userId, "@")
if len(parts) != 2 {
return "", "", fmt.Errorf("invalid userId: %s", userId)
}
return parts[0], parts[1], nil
}

0 comments on commit 3800cb6

Please sign in to comment.