diff --git a/cli/cli_connect_test.go b/cli/cli_connect_test.go index cef54461..9f98e068 100644 --- a/cli/cli_connect_test.go +++ b/cli/cli_connect_test.go @@ -7,7 +7,6 @@ import ( "io" "os" "strings" - "sync" "testing" "github.com/NordSecurity/nordvpn-linux/client/config" @@ -31,21 +30,16 @@ func captureOutput(f func()) (string, error) { os.Stdout = stdout os.Stderr = stderr }() + os.Stdout = writer os.Stderr = writer - out := make(chan string) - wg := new(sync.WaitGroup) - wg.Add(1) - go func() { - var buf bytes.Buffer - wg.Done() - io.Copy(&buf, reader) - out <- buf.String() - }() - wg.Wait() + f() - writer.Close() - return strings.TrimSuffix(<-out, "\n"), nil + + writer.Close() // close to unblock io.Copy(&buf, reader) + var buf bytes.Buffer + io.Copy(&buf, reader) + return strings.TrimSuffix(buf.String(), "\n"), nil } type mockDaemonClient struct { diff --git a/cli/cli_countries_test.go b/cli/cli_countries_test.go index 33c38dfe..f2bbfca7 100644 --- a/cli/cli_countries_test.go +++ b/cli/cli_countries_test.go @@ -29,7 +29,7 @@ func TestCountriesList(t *testing.T) { expectedError: formatError(fmt.Errorf(MsgListIsEmpty, "countries")), }, { - name: "counties list", + name: "countries list", expected: "France, Germany", countries: []string{"France", "Germany"}, }, diff --git a/cli/cli_groups_test.go b/cli/cli_groups_test.go new file mode 100644 index 00000000..c604de62 --- /dev/null +++ b/cli/cli_groups_test.go @@ -0,0 +1,53 @@ +package cli + +import ( + "context" + "flag" + "fmt" + "testing" + + "github.com/NordSecurity/nordvpn-linux/client/config" + "github.com/NordSecurity/nordvpn-linux/test/category" + "github.com/stretchr/testify/assert" + "github.com/urfave/cli/v2" +) + +func TestGroupsList(t *testing.T) { + category.Set(t, category.Unit) + mockClient := mockDaemonClient{} + c := cmd{&mockClient, nil, nil, "", nil, config.Config{}, nil} + + tests := []struct { + name string + groups []string + expected string + input string + expectedError error + }{ + { + name: "error response", + expectedError: formatError(fmt.Errorf(MsgListIsEmpty, "server groups")), + }, + { + name: "groups list", + expected: "group1, group2", + groups: []string{"group1", "group2"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + app := cli.NewApp() + set := flag.NewFlagSet("test", 0) + mockClient.groups = test.groups + ctx := cli.NewContext(app, set, &cli.Context{Context: context.Background()}) + + result, err := captureOutput(func() { + err := c.Groups(ctx) + assert.Equal(t, test.expectedError, err) + }) + assert.Nil(t, err) + assert.Equal(t, test.expected, result) + }) + } +} diff --git a/daemon/rpc_countries.go b/daemon/rpc_countries.go index b4fac889..1a6c9c16 100644 --- a/daemon/rpc_countries.go +++ b/daemon/rpc_countries.go @@ -81,20 +81,20 @@ func (r *RPC) Countries(ctx context.Context, in *pb.CountriesRequest) (*pb.Paylo }, nil } - if countries, ok := r.dm.GetAppData().CountryNames[in.GetObfuscate()][cfg.AutoConnectData.Protocol]; ok { - var countryNames []string - for country := range countries.Iter() { - countryNames = append(countryNames, country.(string)) - } - sort.Strings(countryNames) + countries, ok := r.dm.GetAppData().CountryNames[in.GetObfuscate()][cfg.AutoConnectData.Protocol] + if !ok { return &pb.Payload{ - Type: internal.CodeSuccess, - Data: countryNames, + Type: internal.CodeEmptyPayloadError, }, nil } - + var countryNames []string + for country := range countries.Iter() { + countryNames = append(countryNames, country.(string)) + } + sort.Strings(countryNames) return &pb.Payload{ - Type: internal.CodeEmptyPayloadError, + Type: internal.CodeSuccess, + Data: countryNames, }, nil }