diff --git a/internal/dbtools/membership_enumeration.go b/internal/dbtools/membership_enumeration.go index 49db895..a9544d5 100644 --- a/internal/dbtools/membership_enumeration.go +++ b/internal/dbtools/membership_enumeration.go @@ -7,6 +7,7 @@ import ( "github.com/metal-toolbox/governor-api/internal/models" "github.com/volatiletech/null/v8" + "github.com/volatiletech/sqlboiler/v4/boil" "github.com/volatiletech/sqlboiler/v4/queries" "github.com/volatiletech/sqlboiler/v4/queries/qm" ) @@ -168,7 +169,7 @@ type EnumeratedMembership struct { } // GetMembershipsForUser returns a fully enumerated list of memberships for a user, optionally with sqlboiler's generated models populated -func GetMembershipsForUser(ctx context.Context, db *sql.DB, userID string, shouldPopulateAllModels bool) ([]EnumeratedMembership, error) { +func GetMembershipsForUser(ctx context.Context, db boil.ContextExecutor, userID string, shouldPopulateAllModels bool) ([]EnumeratedMembership, error) { enumeratedMemberships := []EnumeratedMembership{} err := queries.Raw(membershipsByUserQuery, userID).Bind(ctx, db, &enumeratedMemberships) @@ -189,7 +190,7 @@ func GetMembershipsForUser(ctx context.Context, db *sql.DB, userID string, shoul } // GetMembersOfGroup returns a fully enumerated list of memberships in a group, optionally with sqlboiler's generated models populated -func GetMembersOfGroup(ctx context.Context, db *sql.DB, groupID string, shouldPopulateAllModels bool) ([]EnumeratedMembership, error) { +func GetMembersOfGroup(ctx context.Context, db boil.ContextExecutor, groupID string, shouldPopulateAllModels bool) ([]EnumeratedMembership, error) { enumeratedMemberships := []EnumeratedMembership{} err := queries.Raw(membershipsByGroupQuery, groupID).Bind(ctx, db, &enumeratedMemberships) @@ -210,7 +211,7 @@ func GetMembersOfGroup(ctx context.Context, db *sql.DB, groupID string, shouldPo } // GetAllGroupMemberships returns a fully enumerated list of all memberships in the database, optionally with sqlboiler's generated models populated (use with caution, potentially lots of data) -func GetAllGroupMemberships(ctx context.Context, db *sql.DB, shouldPopulateAllModels bool) ([]EnumeratedMembership, error) { +func GetAllGroupMemberships(ctx context.Context, db boil.ContextExecutor, shouldPopulateAllModels bool) ([]EnumeratedMembership, error) { enumeratedMemberships := []EnumeratedMembership{} err := queries.Raw(allMembershipsQuery).Bind(ctx, db, &enumeratedMemberships) @@ -231,7 +232,7 @@ func GetAllGroupMemberships(ctx context.Context, db *sql.DB, shouldPopulateAllMo } // HierarchyWouldCreateCycle returns true if a given new parent->member relationship would create a cycle in the database -func HierarchyWouldCreateCycle(ctx context.Context, db *sql.DB, parentGroupID, memberGroupID string) (bool, error) { +func HierarchyWouldCreateCycle(ctx context.Context, db boil.ContextExecutor, parentGroupID, memberGroupID string) (bool, error) { hierarchies := make(map[string][]string) hierarchyRows, err := models.GroupHierarchies().All(ctx, db) @@ -306,7 +307,7 @@ func FindMemberDiff(before, after []EnumeratedMembership) []EnumeratedMembership return uniqueMembersAfter } -func populateModels(ctx context.Context, db *sql.DB, memberships []EnumeratedMembership) ([]EnumeratedMembership, error) { +func populateModels(ctx context.Context, db boil.ContextExecutor, memberships []EnumeratedMembership) ([]EnumeratedMembership, error) { groupIDSet := make(map[string]bool) userIDSet := make(map[string]bool) diff --git a/pkg/api/v1alpha1/group_membership.go b/pkg/api/v1alpha1/group_membership.go index 71d499b..2b6b9cf 100644 --- a/pkg/api/v1alpha1/group_membership.go +++ b/pkg/api/v1alpha1/group_membership.go @@ -182,7 +182,7 @@ func (r *Router) addGroupMember(c *gin.Context) { return } - membershipsBefore, err := dbtools.GetMembershipsForUser(c, r.DB.DB, user.ID, false) + membershipsBefore, err := dbtools.GetMembershipsForUser(c, tx, user.ID, false) if err != nil { msg := "failed to compute new effective memberships: " + err.Error() @@ -232,7 +232,7 @@ func (r *Router) addGroupMember(c *gin.Context) { return } - membershipsAfter, err := dbtools.GetMembershipsForUser(c, r.DB.DB, user.ID, false) + membershipsAfter, err := dbtools.GetMembershipsForUser(c, tx, user.ID, false) if err != nil { msg := "failed to compute new effective memberships: " + err.Error() @@ -518,7 +518,7 @@ func (r *Router) removeGroupMember(c *gin.Context) { return } - membershipsBefore, err := dbtools.GetMembershipsForUser(c, r.DB.DB, user.ID, false) + membershipsBefore, err := dbtools.GetMembershipsForUser(c, tx, user.ID, false) if err != nil { msg := "failed to compute new effective memberships: " + err.Error() @@ -568,7 +568,7 @@ func (r *Router) removeGroupMember(c *gin.Context) { return } - membershipsAfter, err := dbtools.GetMembershipsForUser(c, r.DB.DB, user.ID, false) + membershipsAfter, err := dbtools.GetMembershipsForUser(c, tx, user.ID, false) if err != nil { msg := "failed to compute new effective memberships: " + err.Error() @@ -1028,7 +1028,7 @@ func (r *Router) processGroupRequest(c *gin.Context) { return } - membershipsBefore, err := dbtools.GetMembershipsForUser(c, r.DB.DB, user.ID, false) + membershipsBefore, err := dbtools.GetMembershipsForUser(c, tx, user.ID, false) if err != nil { msg := "failed to compute new effective memberships: " + err.Error() @@ -1090,7 +1090,7 @@ func (r *Router) processGroupRequest(c *gin.Context) { return } - membershipsAfter, err := dbtools.GetMembershipsForUser(c, r.DB.DB, user.ID, false) + membershipsAfter, err := dbtools.GetMembershipsForUser(c, tx, user.ID, false) if err != nil { msg := "failed to compute new effective memberships: " + err.Error()