Skip to content
This repository was archived by the owner on Dec 13, 2018. It is now read-only.

Commit d4ece29

Browse files
committed
refactor GetAdditionalGroupsPath
This parses group file only once to process a list of groups instead of parsing once for each group. Also added an unit test for GetAdditionalGroupsPath Signed-off-by: Daniel, Dao Quang Minh <[email protected]>
1 parent 50603ca commit d4ece29

File tree

2 files changed

+136
-37
lines changed

2 files changed

+136
-37
lines changed

user/user.go

+45-37
Original file line numberDiff line numberDiff line change
@@ -349,51 +349,59 @@ func GetExecUser(userSpec string, defaults *ExecUser, passwd, group io.Reader) (
349349
return user, nil
350350
}
351351

352-
// GetAdditionalGroupsPath is a wrapper for GetAdditionalGroups. It reads data from the
353-
// given file path and uses that data as the arguments to GetAdditionalGroups.
352+
// GetAdditionalGroupsPath looks up a list of groups by name or group id
353+
// against the group file. If a group name cannot be found, an error will be
354+
// returned. If a group id cannot be found, it will be returned as-is.
354355
func GetAdditionalGroupsPath(additionalGroups []string, groupPath string) ([]int, error) {
355-
var groupIds []int
356-
357-
for _, ag := range additionalGroups {
358-
groupReader, err := os.Open(groupPath)
359-
if err != nil {
360-
return nil, fmt.Errorf("Failed to open group file: %v", err)
361-
}
362-
defer groupReader.Close()
363-
364-
groupId, err := GetAdditionalGroup(ag, groupReader)
365-
if err != nil {
366-
return nil, err
367-
}
368-
groupIds = append(groupIds, groupId)
356+
groupReader, err := os.Open(groupPath)
357+
if err != nil {
358+
return nil, fmt.Errorf("Failed to open group file: %v", err)
369359
}
360+
defer groupReader.Close()
370361

371-
return groupIds, nil
372-
}
373-
374-
// GetAdditionalGroup looks up the specified group in the passed groupReader.
375-
func GetAdditionalGroup(additionalGroup string, groupReader io.Reader) (int, error) {
376362
groups, err := ParseGroupFilter(groupReader, func(g Group) bool {
377-
return g.Name == additionalGroup || strconv.Itoa(g.Gid) == additionalGroup
363+
for _, ag := range additionalGroups {
364+
if g.Name == ag || strconv.Itoa(g.Gid) == ag {
365+
return true
366+
}
367+
}
368+
return false
378369
})
379370
if err != nil {
380-
return -1, fmt.Errorf("Unable to find additional groups %v: %v", additionalGroup, err)
371+
return nil, fmt.Errorf("Unable to find additional groups %v: %v", additionalGroups, err)
381372
}
382-
if groups != nil && len(groups) > 0 {
383-
// if we found any group entries that matched our filter, let's take the first one as "correct"
384-
return groups[0].Gid, nil
385-
} else {
386-
// we asked for a group but didn't find id... let's check to see if we wanted a numeric group
387-
addGroup, err := strconv.Atoi(additionalGroup)
388-
if err != nil {
389-
// not numeric - we have to bail
390-
return -1, fmt.Errorf("Unable to find group %v", additionalGroup)
391-
}
392373

393-
// Ensure gid is inside gid range.
394-
if addGroup < minId || addGroup > maxId {
395-
return -1, ErrRange
374+
gidMap := make(map[int]struct{})
375+
for _, ag := range additionalGroups {
376+
var found bool
377+
for _, g := range groups {
378+
// if we found a matched group either by name or gid, take the
379+
// first matched as correct
380+
if g.Name == ag || strconv.Itoa(g.Gid) == ag {
381+
if _, ok := gidMap[g.Gid]; !ok {
382+
gidMap[g.Gid] = struct{}{}
383+
found = true
384+
break
385+
}
386+
}
387+
}
388+
// we asked for a group but didn't find it. let's check to see
389+
// if we wanted a numeric group
390+
if !found {
391+
gid, err := strconv.Atoi(ag)
392+
if err != nil {
393+
return nil, fmt.Errorf("Unable to find group %s", ag)
394+
}
395+
// Ensure gid is inside gid range.
396+
if gid < minId || gid > maxId {
397+
return nil, ErrRange
398+
}
399+
gidMap[gid] = struct{}{}
396400
}
397-
return addGroup, nil
398401
}
402+
gids := []int{}
403+
for gid := range gidMap {
404+
gids = append(gids, gid)
405+
}
406+
return gids, nil
399407
}

user/user_test.go

+91
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
package user
22

33
import (
4+
"fmt"
45
"io"
6+
"io/ioutil"
57
"reflect"
8+
"sort"
9+
"strconv"
610
"strings"
711
"testing"
812
)
@@ -350,3 +354,90 @@ this is just some garbage data
350354
}
351355
}
352356
}
357+
358+
func TestGetAdditionalGroupsPath(t *testing.T) {
359+
const groupContent = `
360+
root:x:0:root
361+
adm:x:43:
362+
grp:x:1234:root,adm
363+
adm:x:4343:root,adm-duplicate
364+
this is just some garbage data
365+
`
366+
tests := []struct {
367+
groups []string
368+
expected []int
369+
hasError bool
370+
}{
371+
{
372+
// empty group
373+
groups: []string{},
374+
expected: []int{},
375+
},
376+
{
377+
// single group
378+
groups: []string{"adm"},
379+
expected: []int{43},
380+
},
381+
{
382+
// multiple groups
383+
groups: []string{"adm", "grp"},
384+
expected: []int{43, 1234},
385+
},
386+
{
387+
// invalid group
388+
groups: []string{"adm", "grp", "not-exist"},
389+
expected: nil,
390+
hasError: true,
391+
},
392+
{
393+
// group with numeric id
394+
groups: []string{"43"},
395+
expected: []int{43},
396+
},
397+
{
398+
// group with unknown numeric id
399+
groups: []string{"adm", "10001"},
400+
expected: []int{43, 10001},
401+
},
402+
{
403+
// groups specified twice with numeric and name
404+
groups: []string{"adm", "43"},
405+
expected: []int{43},
406+
},
407+
{
408+
// groups with too small id
409+
groups: []string{"-1"},
410+
expected: nil,
411+
hasError: true,
412+
},
413+
{
414+
// groups with too large id
415+
groups: []string{strconv.Itoa(1 << 31)},
416+
expected: nil,
417+
hasError: true,
418+
},
419+
}
420+
421+
for _, test := range tests {
422+
tmpFile, err := ioutil.TempFile("", "get-additional-groups-path")
423+
if err != nil {
424+
t.Error(err)
425+
}
426+
fmt.Fprint(tmpFile, groupContent)
427+
tmpFile.Close()
428+
429+
gids, err := GetAdditionalGroupsPath(test.groups, tmpFile.Name())
430+
if test.hasError && err == nil {
431+
t.Errorf("Parse(%#v) expects error but has none", test)
432+
continue
433+
}
434+
if !test.hasError && err != nil {
435+
t.Errorf("Parse(%#v) has error %v", test, err)
436+
continue
437+
}
438+
sort.Sort(sort.IntSlice(gids))
439+
if !reflect.DeepEqual(gids, test.expected) {
440+
t.Errorf("Gids(%v), expect %v from groups %v", gids, test.expected, test.groups)
441+
}
442+
}
443+
}

0 commit comments

Comments
 (0)