From 69c5a49238790252bd9213d712d928ca6a359dde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Gonz=C3=A1lez=20Di=20Antonio?= Date: Sun, 3 Nov 2024 16:55:50 +0100 Subject: [PATCH] chore: clean and improve code --- internal/scim/operations.go | 5 +- internal/scim/operations_test.go | 11 +++ internal/scim/scim.go | 78 ++++++++++---------- internal/scim/scim_test.go | 122 ++++++++++++++++++++++++++++++- mocks/scim/scim_mocks.go | 30 -------- 5 files changed, 174 insertions(+), 72 deletions(-) diff --git a/internal/scim/operations.go b/internal/scim/operations.go index 8fc303e0..5b7e1429 100644 --- a/internal/scim/operations.go +++ b/internal/scim/operations.go @@ -8,7 +8,7 @@ import ( // patchGroupOperations assembles the operations for patch groups // bases in the limits of operations we can execute in a single request. func patchGroupOperations(op, path string, pvs []patchValue, gms *model.GroupMembers) []*aws.PatchGroupRequest { - patchOperations := []*aws.PatchGroupRequest{} + patchOperations := make([]*aws.PatchGroupRequest, 0) if len(pvs) > MaxPatchGroupMembersPerRequest { for i := 0; i < len(pvs); i += MaxPatchGroupMembersPerRequest { @@ -33,9 +33,11 @@ func patchGroupOperations(op, path string, pvs []patchValue, gms *model.GroupMem }, }, } + patchOperations = append(patchOperations, patchGroupRequest) } } else { + patchGroupRequest := &aws.PatchGroupRequest{ Group: aws.Group{ ID: gms.Group.SCIMID, @@ -52,6 +54,7 @@ func patchGroupOperations(op, path string, pvs []patchValue, gms *model.GroupMem }, }, } + patchOperations = append(patchOperations, patchGroupRequest) } diff --git a/internal/scim/operations_test.go b/internal/scim/operations_test.go index 5d075311..4ececf14 100644 --- a/internal/scim/operations_test.go +++ b/internal/scim/operations_test.go @@ -134,3 +134,14 @@ func Test_patchGroupOperations(t *testing.T) { }) } } + +func Benchmark_patchGroupOperations(b *testing.B) { + for i := 0; i < b.N; i++ { + patchGroupOperations("add", "members", patchValueGenerator(1, 350), &model.GroupMembers{ + Group: &model.Group{ + SCIMID: "016722b2be-ee23ed58-6e4e-4b2f-a94a-3ace8456a36e", + Name: "group 1", + }, + }) + } +} diff --git a/internal/scim/scim.go b/internal/scim/scim.go index 69e8791d..ab7d54be 100644 --- a/internal/scim/scim.go +++ b/internal/scim/scim.go @@ -18,9 +18,6 @@ type AWSSCIMProvider interface { // ListUsers lists users in SCIM Provider ListUsers(ctx context.Context, filter string) (*aws.ListUsersResponse, error) - // CreateUser creates a user in SCIM Provider - CreateUser(ctx context.Context, u *aws.CreateUserRequest) (*aws.CreateUserResponse, error) - // CreateOrGetUser creates a user in SCIM Provider CreateOrGetUser(ctx context.Context, u *aws.CreateUserRequest) (*aws.CreateUserResponse, error) @@ -39,9 +36,6 @@ type AWSSCIMProvider interface { // ListGroups lists groups in SCIM Provider ListGroups(ctx context.Context, filter string) (*aws.ListGroupsResponse, error) - // CreateGroup creates a group in SCIM Provider - CreateGroup(ctx context.Context, g *aws.CreateGroupRequest) (*aws.CreateGroupResponse, error) - // CreateOrGetGroup creates a group in SCIM Provider CreateOrGetGroup(ctx context.Context, g *aws.CreateGroupRequest) (*aws.CreateGroupResponse, error) @@ -81,17 +75,17 @@ func (s *Provider) GetGroups(ctx context.Context) (*model.GroupsResult, error) { groups := make([]*model.Group, len(groupsResponse.Resources)) for i, group := range groupsResponse.Resources { - e := model.GroupBuilder(). + g := model.GroupBuilder(). WithSCIMID(group.ID). WithName(group.DisplayName). WithIPID(group.ExternalID). Build() - groups[i] = e + groups[i] = g + } groupsResult := model.GroupsResultBuilder().WithResources(groups).Build() - slog.Debug("scim: GetGroups()", "groups", len(groups)) return groupsResult, nil @@ -99,6 +93,10 @@ func (s *Provider) GetGroups(ctx context.Context) (*model.GroupsResult, error) { // CreateGroups creates groups in SCIM Provider func (s *Provider) CreateGroups(ctx context.Context, gr *model.GroupsResult) (*model.GroupsResult, error) { + if gr == nil { + return nil, fmt.Errorf("scim: error creating groups, groups result is nil") + } + groups := make([]*model.Group, len(gr.Resources)) for i, group := range gr.Resources { @@ -114,18 +112,17 @@ func (s *Provider) CreateGroups(ctx context.Context, gr *model.GroupsResult) (*m return nil, fmt.Errorf("scim: error creating group: %w", err) } - e := model.GroupBuilder(). + g := model.GroupBuilder(). WithSCIMID(r.ID). WithName(group.Name). WithIPID(group.IPID). WithEmail(group.Email). Build() - groups[i] = e + groups[i] = g } groupsResult := model.GroupsResultBuilder().WithResources(groups).Build() - slog.Debug("scim: CreateGroups()", "groups", len(groups)) return groupsResult, nil @@ -162,14 +159,14 @@ func (s *Provider) UpdateGroups(ctx context.Context, gr *model.GroupsResult) (*m } // return the same group - e := model.GroupBuilder(). + g := model.GroupBuilder(). WithSCIMID(group.SCIMID). WithName(group.Name). WithIPID(group.IPID). WithEmail(group.Email). Build() - groups[i] = e + groups[i] = g } groupsResult := model.GroupsResultBuilder().WithResources(groups).Build() @@ -200,8 +197,8 @@ func (s *Provider) GetUsers(ctx context.Context) (*model.UsersResult, error) { users := make([]*model.User, len(usersResponse.Resources)) for i, user := range usersResponse.Resources { - e := buildUser(user) - users[i] = e + u := buildUser(user) + users[i] = u } usersResult := model.UsersResultBuilder().WithResources(users).Build() @@ -286,13 +283,13 @@ type patchValue struct { // CreateGroupsMembers creates groups members in SCIM Provider given a list of groups members func (s *Provider) CreateGroupsMembers(ctx context.Context, gmr *model.GroupsMembersResult) (*model.GroupsMembersResult, error) { - groupsMembers := make([]*model.GroupMembers, 0) + groupsMembers := make([]*model.GroupMembers, len(gmr.Resources)) - for _, groupMembers := range gmr.Resources { - members := make([]*model.Member, 0) - membersIDValue := []patchValue{} + for i, groupMembers := range gmr.Resources { + members := make([]*model.Member, len(groupMembers.Resources)) + membersIDValue := make([]patchValue, len(groupMembers.Resources)) - for _, member := range groupMembers.Resources { + for j, member := range groupMembers.Resources { if member.SCIMID == "" { u, err := s.scim.GetUserByUserName(ctx, member.Email) if err != nil { @@ -301,11 +298,11 @@ func (s *Provider) CreateGroupsMembers(ctx context.Context, gmr *model.GroupsMem member.SCIMID = u.ID } - membersIDValue = append(membersIDValue, patchValue{ + membersIDValue[j] = patchValue{ Value: member.SCIMID, - }) + } - e := model.MemberBuilder(). + m := model.MemberBuilder(). WithIPID(member.IPID). WithSCIMID(member.SCIMID). WithEmail(member.Email). @@ -313,16 +310,15 @@ func (s *Provider) CreateGroupsMembers(ctx context.Context, gmr *model.GroupsMem Build() slog.Warn("adding member to group", "group", groupMembers.Group.Name, "email", member.Email) - members = append(members, e) - + members[j] = m } - e := model.GroupMembersBuilder(). + gm := model.GroupMembersBuilder(). WithGroup(groupMembers.Group). WithResources(members). Build() - groupsMembers = append(groupsMembers, e) + groupsMembers[i] = gm patchOperations := patchGroupOperations("add", "members", membersIDValue, groupMembers) @@ -397,9 +393,9 @@ func (s *Provider) GetGroupsMembers(ctx context.Context, gr *model.GroupsResult) } for _, gr := range lgr.Resources { - members := make([]*model.Member, 0) + members := make([]*model.Member, len(gr.Members)) - for _, member := range gr.Members { + for j, member := range gr.Members { u, err := s.scim.GetUser(ctx, member.Value) if err != nil { return nil, fmt.Errorf("scim: error getting user: %s, error %w", member.Value, err) @@ -410,15 +406,15 @@ func (s *Provider) GetGroupsMembers(ctx context.Context, gr *model.GroupsResult) WithEmail(u.Emails[0].Value). Build() - members = append(members, m) + members[j] = m } - e := model.GroupMembersBuilder(). + gms := model.GroupMembersBuilder(). WithGroup(group). WithResources(members). Build() - groupMembers = append(groupMembers, e) + groupMembers = append(groupMembers, gms) } } @@ -431,22 +427,23 @@ func (s *Provider) GetGroupsMembers(ctx context.Context, gr *model.GroupsResult) // GetGroupsMembersBruteForce returns a list of groups and their members from the SCIM Provider // NOTE: this is an bad alternative to the method GetGroupsMembers, because read the note in the method. func (s *Provider) GetGroupsMembersBruteForce(ctx context.Context, gr *model.GroupsResult, ur *model.UsersResult) (*model.GroupsMembersResult, error) { - groupMembers := make([]*model.GroupMembers, 0) + groupMembers := make([]*model.GroupMembers, len(gr.Resources)) // brute force implemented here thanks to the fxxckin' aws sso scim api - for _, group := range gr.Resources { + for i, group := range gr.Resources { members := make([]*model.Member, 0) for _, user := range ur.Resources { // https://docs.aws.amazon.com/singlesignon/latest/developerguide/listgroups.html - f := fmt.Sprintf("id eq %q and members eq %q", group.SCIMID, user.SCIMID) - lgr, err := s.scim.ListGroups(ctx, f) + filter := fmt.Sprintf("id eq %q and members eq %q", group.SCIMID, user.SCIMID) + lgr, err := s.scim.ListGroups(ctx, filter) if err != nil { return nil, fmt.Errorf("scim: error listing groups: %w", err) } - if lgr.TotalResults > 0 { // crazy thing of the AWS SSO SCIM API, it doesn't return the member into the Resources array + // AWS SSO SCIM API, it doesn't return the member into the Resources array + if lgr.TotalResults > 0 { m := model.MemberBuilder(). WithIPID(user.IPID). WithSCIMID(user.SCIMID). @@ -460,12 +457,13 @@ func (s *Provider) GetGroupsMembersBruteForce(ctx context.Context, gr *model.Gro members = append(members, m) } } - e := model.GroupMembersBuilder(). + + gms := model.GroupMembersBuilder(). WithGroup(group). WithResources(members). Build() - groupMembers = append(groupMembers, e) + groupMembers[i] = gms } slog.Debug("scim: GetGroupsMembersBruteForce()", "groups_members", len(groupMembers)) diff --git a/internal/scim/scim_test.go b/internal/scim/scim_test.go index 0d7d3439..a07cbb8f 100644 --- a/internal/scim/scim_test.go +++ b/internal/scim/scim_test.go @@ -165,11 +165,26 @@ func TestCreateGroups(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() + t.Run("Should return a error when model.GroupsResult is nil", func(t *testing.T) { + mockSCIM := mocks.NewMockAWSSCIMProvider(mockCtrl) + + svc, err := NewProvider(mockSCIM) + if err != nil { + t.Fatalf("error creating provider: %v", err) + } + gr, err := svc.CreateGroups(context.TODO(), nil) + assert.Error(t, err) + assert.Nil(t, gr) + }) + t.Run("Should do nothing with empty GroupsResult", func(t *testing.T) { mockSCIM := mocks.NewMockAWSSCIMProvider(mockCtrl) empty := &model.GroupsResult{} - svc, _ := NewProvider(mockSCIM) + svc, err := NewProvider(mockSCIM) + if err != nil { + t.Fatalf("error creating provider: %v", err) + } gr, err := svc.CreateGroups(context.TODO(), empty) assert.NoError(t, err) assert.NotNil(t, gr) @@ -1677,6 +1692,7 @@ func TestGetGroupsMembers(t *testing.T) { }, }, } + filter := fmt.Sprintf("displayName eq %q", grp.Resources[0].Name) lgr := &aws.ListGroupsResponse{ Resources: []*aws.Group{ @@ -1713,6 +1729,110 @@ func TestGetGroupsMembers(t *testing.T) { assert.Error(t, err) assert.Nil(t, got) }) + + t.Run("Should call ListGroups and GetUser 2 time and no return error", func(t *testing.T) { + mockSCIM := mocks.NewMockAWSSCIMProvider(mockCtrl) + grp := &model.GroupsResult{ + Items: 2, + Resources: []*model.Group{ + { + IPID: "1", + Name: "group 1", + Email: "group.1@mail.com", + }, + { + IPID: "2", + Name: "group 2", + Email: "group.2@mail.com", + }, + }, + } + filter1 := fmt.Sprintf("displayName eq %q", grp.Resources[0].Name) + filter2 := fmt.Sprintf("displayName eq %q", grp.Resources[1].Name) + + lgr1 := &aws.ListGroupsResponse{ + Resources: []*aws.Group{ + { + ID: "1", + DisplayName: grp.Resources[0].Name, + Members: []*aws.Member{ + { + Value: "1", + }, + }, + }, + }, + } + + lgr2 := &aws.ListGroupsResponse{ + Resources: []*aws.Group{ + { + ID: "2", + DisplayName: grp.Resources[1].Name, + Members: []*aws.Member{ + { + Value: "2", + }, + }, + }, + }, + } + + gur1 := &aws.GetUserResponse{ + Emails: []aws.Email{ + { + Value: "user.1@mail.com", + }, + { + Value: "user.2@mail.com", + }, + }, + } + + gur2 := &aws.GetUserResponse{ + Emails: []aws.Email{ + { + Value: "user.3@mail.com", + }, + { + Value: "user.4@mail.com", + }, + }, + } + + ctx := context.TODO() + mockSCIM.EXPECT().ListGroups(ctx, filter1).Return(lgr1, nil).Times(1) + mockSCIM.EXPECT().ListGroups(ctx, filter2).Return(lgr2, nil).Times(1) + mockSCIM.EXPECT().GetUser(ctx, lgr1.Resources[0].Members[0].Value).Return(gur1, nil).Times(1) + mockSCIM.EXPECT().GetUser(ctx, lgr2.Resources[0].Members[0].Value).Return(gur2, nil).Times(1) + + gr := &model.GroupsResult{ + Items: 2, + Resources: []*model.Group{ + { + IPID: "1", + SCIMID: "1", + Name: "group 1", + Email: "group.1@mail.com", + }, + { + IPID: "2", + SCIMID: "2", + Name: "group 2", + Email: "group.2@mail.com", + }, + }, + } + + svc, _ := NewProvider(mockSCIM) + got, err := svc.GetGroupsMembers(ctx, gr) + assert.NoError(t, err) + assert.NotNil(t, got) + + assert.Equal(t, 2, len(got.Resources)) + assert.Equal(t, 1, len(got.Resources[0].Resources)) + assert.Equal(t, 1, len(got.Resources[1].Resources)) + }) } func TestGetGroupsMembersBruteForce(t *testing.T) { diff --git a/mocks/scim/scim_mocks.go b/mocks/scim/scim_mocks.go index 386245e5..279a7fe8 100644 --- a/mocks/scim/scim_mocks.go +++ b/mocks/scim/scim_mocks.go @@ -35,21 +35,6 @@ func (m *MockAWSSCIMProvider) EXPECT() *MockAWSSCIMProviderMockRecorder { return m.recorder } -// CreateGroup mocks base method. -func (m *MockAWSSCIMProvider) CreateGroup(ctx context.Context, g *aws.CreateGroupRequest) (*aws.CreateGroupResponse, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateGroup", ctx, g) - ret0, _ := ret[0].(*aws.CreateGroupResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateGroup indicates an expected call of CreateGroup. -func (mr *MockAWSSCIMProviderMockRecorder) CreateGroup(ctx, g interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGroup", reflect.TypeOf((*MockAWSSCIMProvider)(nil).CreateGroup), ctx, g) -} - // CreateOrGetGroup mocks base method. func (m *MockAWSSCIMProvider) CreateOrGetGroup(ctx context.Context, g *aws.CreateGroupRequest) (*aws.CreateGroupResponse, error) { m.ctrl.T.Helper() @@ -80,21 +65,6 @@ func (mr *MockAWSSCIMProviderMockRecorder) CreateOrGetUser(ctx, u interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrGetUser", reflect.TypeOf((*MockAWSSCIMProvider)(nil).CreateOrGetUser), ctx, u) } -// CreateUser mocks base method. -func (m *MockAWSSCIMProvider) CreateUser(ctx context.Context, u *aws.CreateUserRequest) (*aws.CreateUserResponse, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateUser", ctx, u) - ret0, _ := ret[0].(*aws.CreateUserResponse) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateUser indicates an expected call of CreateUser. -func (mr *MockAWSSCIMProviderMockRecorder) CreateUser(ctx, u interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUser", reflect.TypeOf((*MockAWSSCIMProvider)(nil).CreateUser), ctx, u) -} - // DeleteGroup mocks base method. func (m *MockAWSSCIMProvider) DeleteGroup(ctx context.Context, id string) error { m.ctrl.T.Helper()