diff --git a/Makefile b/Makefile index 59e05561..ae51c262 100644 --- a/Makefile +++ b/Makefile @@ -18,14 +18,14 @@ cli: go build -o eduvpn-common-cli ./cmd/cli test: - go test -race ./... + go test -tags=cgotesting -race ./... clean: rm -rf lib go clean coverage: - go test -v -coverpkg=./... -coverprofile=common.cov ./... + go test -tags=cgotesting -v -coverpkg=./... -coverprofile=common.cov ./... go tool cover -func common.cov sloc: diff --git a/client/client.go b/client/client.go index e5a39c07..7096421e 100644 --- a/client/client.go +++ b/client/client.go @@ -528,6 +528,9 @@ func (c *Client) SetProfileID(pID string) error { if err != nil { return i18nerr.WrapInternalf(err, "Failed to set the profile ID: '%s'", pID) } + if _, ok := srv.Profiles.Map[pID]; !ok { + return i18nerr.WrapInternalf(err, "Failed to set the profile ID as it does not exist: '%s'", pID) + } srv.Profiles.Current = pID c.TrySave() return nil diff --git a/client/fsm.go b/client/fsm.go index c8858f91..7f16dce5 100644 --- a/client/fsm.go +++ b/client/fsm.go @@ -115,6 +115,7 @@ func newFSM( Transitions: []FSMTransition{ {To: StateMain, Description: "Authorized"}, {To: StateDisconnected, Description: "Cancel, was disconnected"}, + {To: StateGotConfig, Description: "Cancel, was got config"}, }, }, StateGettingConfig: FSMState{ @@ -140,6 +141,7 @@ func newFSM( Transitions: []FSMTransition{ {To: StateGettingConfig, Description: "Get a VPN config again"}, {To: StateConnecting, Description: "VPN is connecting"}, + {To: StateOAuthStarted, Description: "Renew"}, }, }, StateConnecting: FSMState{ diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 8eb71c15..6f074fbe 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -13,6 +13,7 @@ import ( "github.com/eduvpn/eduvpn-common/internal/version" "github.com/eduvpn/eduvpn-common/types/cookie" srvtypes "github.com/eduvpn/eduvpn-common/types/server" + "github.com/eduvpn/eduvpn-common/util" "github.com/pkg/browser" ) @@ -34,56 +35,6 @@ func openBrowser(data interface{}) { }() } -// GetLanguageMatched uses a map from language tags to strings to extract the right language given the tag -// It implements it according to https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY.md#language-matching -func GetLanguageMatched(langMap map[string]string, langTag string) string { - // If no map is given, return the empty string - if len(langMap) == 0 { - return "" - } - // Try to find the exact match - if val, ok := langMap[langTag]; ok { - return val - } - // Try to find a key that starts with the OS language setting - for k := range langMap { - if strings.HasPrefix(k, langTag) { - return langMap[k] - } - } - // Try to find a key that starts with the first part of the OS language (e.g. de-) - pts := strings.Split(langTag, "-") - // We have a "-" - if len(pts) > 1 { - for k := range langMap { - if strings.HasPrefix(k, pts[0]+"-") { - return langMap[k] - } - } - } - // search for just the language (e.g. de) - for k := range langMap { - if k == pts[0] { - return langMap[k] - } - } - - // Pick one that is deemed best, e.g. en-US or en, but note that not all languages are always available! - // We force an entry that is english exactly or with an english prefix - for k := range langMap { - if k == "en" || strings.HasPrefix(k, "en-") { - return langMap[k] - } - } - - // Otherwise just return one - for k := range langMap { - return langMap[k] - } - - return "" -} - // Ask for a profile in the command line. func sendProfile(state *client.Client, data interface{}) { fmt.Printf("Multiple VPN profiles found. Please select a profile by entering e.g. 1") @@ -102,7 +53,7 @@ func sendProfile(state *client.Client, data interface{}) { var options []string i := 0 for k, v := range sps.Map { - ps += fmt.Sprintf("\n%d - %s", i+1, GetLanguageMatched(v.DisplayName, "en")) + ps += fmt.Sprintf("\n%d - %s", i+1, util.GetLanguageMatched(v.DisplayName, "en")) options = append(options, k) i++ } diff --git a/exports/exports.go b/exports/exports.go index f651c7f7..511a1f98 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -14,47 +14,7 @@ package main /* -#include -#include - -typedef long long int (*ReadRxBytes)(); - -typedef int (*StateCB)(int oldstate, int newstate, void* data); - -typedef void (*RefreshList)(); -typedef void (*TokenGetter)(const char* server_id, int server_type, char* out, size_t len); -typedef void (*TokenSetter)(const char* server_id, int server_type, const char* tokens); -typedef void (*ProxySetup)(int fd, const char* peer_ips); -typedef void (*ProxyReady)(); - -static long long int get_read_rx_bytes(ReadRxBytes read) -{ - return read(); -} -static int call_callback(StateCB callback, int oldstate, int newstate, void* data) -{ - return callback(oldstate, newstate, data); -} -static void call_refresh_list(RefreshList refresh) -{ - refresh(); -} -static void call_token_getter(TokenGetter getter, const char* server_id, int server_type, char* out, size_t len) -{ - getter(server_id, server_type, out, len); -} -static void call_token_setter(TokenSetter setter, const char* server_id, int server_type, const char* tokens) -{ - setter(server_id, server_type, tokens); -} -static void call_proxy_setup(ProxySetup proxysetup, int fd, const char* peer_ips) -{ - proxysetup(fd, peer_ips); -} -static void call_proxy_ready(ProxyReady ready) -{ - ready(); -} +#include "exports.h" */ import "C" diff --git a/exports/exports.h b/exports/exports.h new file mode 100644 index 00000000..a31920be --- /dev/null +++ b/exports/exports.h @@ -0,0 +1,46 @@ +#ifndef EXPORTS_H +#define EXPORTS_H + +#include +#include + +typedef long long int (*ReadRxBytes)(); + +typedef int (*StateCB)(int oldstate, int newstate, void* data); + +typedef void (*RefreshList)(); +typedef void (*TokenGetter)(const char* server_id, int server_type, char* out, size_t len); +typedef void (*TokenSetter)(const char* server_id, int server_type, const char* tokens); +typedef void (*ProxySetup)(int fd, const char* peer_ips); +typedef void (*ProxyReady)(); + +static long long int get_read_rx_bytes(ReadRxBytes read) +{ + return read(); +} +static int call_callback(StateCB callback, int oldstate, int newstate, void* data) +{ + return callback(oldstate, newstate, data); +} +static void call_refresh_list(RefreshList refresh) +{ + refresh(); +} +static void call_token_getter(TokenGetter getter, const char* server_id, int server_type, char* out, size_t len) +{ + getter(server_id, server_type, out, len); +} +static void call_token_setter(TokenSetter setter, const char* server_id, int server_type, const char* tokens) +{ + setter(server_id, server_type, tokens); +} +static void call_proxy_setup(ProxySetup proxysetup, int fd, const char* peer_ips) +{ + proxysetup(fd, peer_ips); +} +static void call_proxy_ready(ProxyReady ready) +{ + ready(); +} + +#endif /* EXPORTS_H */ diff --git a/exports/exports_test.go b/exports/exports_test.go new file mode 100644 index 00000000..25d108cf --- /dev/null +++ b/exports/exports_test.go @@ -0,0 +1,21 @@ +//go:build cgotesting + +package main + +import "testing" + +func TestRegister(t *testing.T) { + testRegister(t) +} + +func TestServerList(t *testing.T) { + testServerList(t) +} + +func TestGetConfig(t *testing.T) { + testGetConfig(t) +} + +func TestLetsConnectDiscovery(t *testing.T) { + testLetsConnectDiscovery(t) +} diff --git a/exports/exports_test_wrapper.go b/exports/exports_test_wrapper.go new file mode 100644 index 00000000..6a56672f --- /dev/null +++ b/exports/exports_test_wrapper.go @@ -0,0 +1,524 @@ +//go:build cgotesting + +package main + +/* +#include "exports.h" + +extern int test_state_callback(int old, int new, char* data); +*/ +import "C" + +import ( + "encoding/json" + "fmt" + "net" + "net/http" + "net/url" + "os" + "regexp" + "strings" + "testing" + "time" + + "github.com/eduvpn/eduvpn-common/internal/test" + "github.com/eduvpn/eduvpn-common/types/error" + "github.com/eduvpn/eduvpn-common/util" +) + +func getString(in *C.char) string { + if in == nil { + return "" + } + defer FreeString(in) + return C.GoString(in) +} + +func getError(t *testing.T, gerr *C.char) string { + jsonErr := getString(gerr) + var transl err.Error + + if jsonErr == "" { + return "" + } + + jerr := json.Unmarshal([]byte(jsonErr), &transl) + if jerr != nil { + t.Fatalf("failed getting error JSON, val: %v, err: %v", jsonErr, jerr) + } + + return util.GetLanguageMatched(transl.Message, "en") +} + +// ClonedAskTransition is a clone of the struct types/server.go RequiredAskTransition +// It is cloned here to ensure that when the types API changes, the tests have to be changed as well +type ClonedAskTransition struct { + Cookie int `json:"cookie"` + Data interface{} `json:"data"` +} + +//export test_state_callback +func test_state_callback(_ C.int, new C.int, data *C.char) int32 { + // OAUTH_STARTED + // We use hardcoded values here instead of constants + // to ensure that a change in the API needs to be changed here too + if int(new) == 3 { + fakeBrowserAuth(C.GoString(data)) //nolint:errcheck + return 1 + } + // ASK_PROFILE + if int(new) == 6 { + dataS := C.GoString(data) + var tr ClonedAskTransition + jsonErr := json.Unmarshal([]byte(dataS), &tr) + if jsonErr != nil { + panic(jsonErr) + } + prS := C.CString("employees") + defer FreeString(prS) + CookieReply(C.uint64_t(tr.Cookie), prS) + return 1 + } + + return 0 +} + +func testDoRegister(t *testing.T) string { + nameS := C.CString("org.letsconnect-vpn.app.linux") + defer FreeString(nameS) + versionS := C.CString("0.0.1") + defer FreeString(versionS) + dir, err := os.MkdirTemp(os.TempDir(), "eduvpn-common-test-cgo") + if err != nil { + t.Fatalf("failed creating temp dir for state file: %v", err) + } + defer os.RemoveAll(dir) + + dirS := C.CString(dir) + defer FreeString(dirS) + + return getError(t, Register(nameS, versionS, dirS, C.StateCB(C.test_state_callback), 0)) +} + +func mustRegister(t *testing.T) { + err := testDoRegister(t) + if err != "" { + t.Fatalf("got register error: %v", err) + } +} + +func testRegister(t *testing.T) { + mustRegister(t) + defer Deregister() + err := testDoRegister(t) + if err == "" { + t.Fatalf("got no register error after double registering: %v", err) + } +} + +func fakeBrowserAuth(str string) (string, error) { + go func() { + u, err := url.Parse(str) + if err != nil { + panic(err) + } + ru, err := url.Parse(u.Query().Get("redirect_uri")) + if err != nil { + panic(err) + } + oq := u.Query() + q := ru.Query() + q.Set("state", oq.Get("state")) + q.Set("code", "fakeauthcode") + ru.RawQuery = q.Encode() + + c := http.Client{} + req, err := http.NewRequest("GET", ru.String(), nil) + if err != nil { + panic(err) + } + _, err = c.Do(req) + if err != nil { + panic(err) + } + }() + return "", nil +} + +func testServer(t *testing.T) *test.Server { + // TODO: duplicate code between this and internal/api/api_test.go + listen, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to setup listener for test server: %v", err) + } + + hps := []test.HandlerPath{ + { + Method: http.MethodGet, + Path: "/.well-known/vpn-user-portal", + Response: fmt.Sprintf(` +{ + "api": { + "http://eduvpn.org/api#3": { + "api_endpoint": "https://%[1]s/test-api-endpoint", + "authorization_endpoint": "https://%[1]s/test-authorization-endpoint", + "token_endpoint": "https://%[1]s/test-token-endpoint" + } + }, + "v": "0.0.1" +}`, listen.Addr().String()), + }, + { + Method: http.MethodPost, + Path: "/test-token-endpoint", + Response: ` +{ + "access_token": "validaccess", + "refresh_token": "validrefresh", + "expires_in": 3600 +}`, + }, + { + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: ` + +{ + "info": { + "profile_list": [ + { + "default_gateway": true, + "display_name": "Employees", + "profile_id": "employees", + "vpn_proto_list": [ + "openvpn", + "wireguard" + ] + }, + { + "default_gateway": true, + "display_name": "Other", + "profile_id": "other", + "vpn_proto_list": [ + "openvpn", + "wireguard" + ] + } + ] + } +}`, + }, + { + Method: http.MethodPost, + Path: "/test-api-endpoint/disconnect", + }, + { + Path: "/test-api-endpoint/connect", + ResponseHandler: func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + w.Header().Add("expires", time.Now().Add(4*time.Hour).Format(http.TimeFormat)) + w.Header().Add("Content-Type", "application/x-wireguard-profile") + w.WriteHeader(200) + // example from https://docs.eduvpn.org/server/v3/api.html#response_1 + resp := ` +Expires: Fri, 06 Aug 2021 03:59:59 GMT +Content-Type: application/x-wireguard-profile + +[Interface] +Address = 10.43.43.2/24, fd43::2/64 +DNS = 9.9.9.9, 2620:fe::fe + +[Peer] +PublicKey = iWAHXts9w9fQVEbA5pVriPlAYMwwEPD5XcVCZDZn1AE= +AllowedIPs = 0.0.0.0/0, ::/0 +Endpoint = vpn.example:51820` + _, err := w.Write([]byte(resp)) + if err != nil { + panic(err) + } + }, + }, + } + return test.NewServerWithHandles(hps, listen) +} + +func testServerList(t *testing.T) { + mustRegister(t) + defer Deregister() + serv := testServer(t) + defer serv.Close() + + ck := CookieNew() + defer CookieDelete(ck) + defer CookieCancel(ck) + + list := fmt.Sprintf("https://%s", serv.Listener.Addr().String()) + listS := C.CString(list) + defer FreeString(listS) + + sclient, err := serv.Client() + if err != nil { + t.Fatalf("failed to obtain server client: %v", err) + } + + // TODO: can we do this better + http.DefaultTransport = sclient.Client.Transport + + gerr := getError(t, AddServer(ck, 3, listS, nil)) + if gerr != "" { + t.Fatalf("error adding server: %v", gerr) + } + + glist, glistErr := ServerList() + glistErrS := getError(t, glistErr) + if glistErrS != "" { + t.Fatalf("error getting server list: %v", glistErrS) + } + + srvlistS := getString(glist) + want := fmt.Sprintf(`{"custom_servers":[{"display_name":{"en":"127.0.0.1"},"identifier":"%s/","profiles":{"current":""}}]}`, list) + if srvlistS != want { + t.Fatalf("server list not equal, want: %v, got: %v", want, srvlistS) + } + + remErr := getError(t, RemoveServer(3, listS)) + if remErr != "" { + t.Fatalf("got error removing server: %v", remErr) + } + remErr = getError(t, RemoveServer(3, listS)) + if remErr == "" { + t.Fatalf("got no error removing server again") + } + + glist, glistErr = ServerList() + glistErrS = getError(t, glistErr) + if glistErrS != "" { + t.Fatalf("error getting server list: %v", glistErrS) + } + + srvlistS = getString(glist) + want = "{}" + if srvlistS != want { + t.Fatalf("server list not equal, want: %v, got: %v", want, srvlistS) + } +} + +// ClonedExpiryTimes is a copy of types/server Expiry +// to ensure that when the public API is changed, this should be changed too +type ClonedExpiryTimes struct { + // StartTime is the start time of the VPN in Unix + StartTime int64 `json:"start_time"` + // EndTime is the end time of the VPN in Unix. + EndTime int64 `json:"end_time"` + // ButtonTime is the Unix time at which to start showing the renew button in the UI + ButtonTime int64 `json:"button_time"` + // CountdownTime is the Unix time at which to start showing more detailed countdown timer. + // E.g. first start with days (7 days left), and when the current time is after this time, show e.g. 9 minutes and 59 seconds + CountdownTime int64 `json:"countdown_time"` + // NotificationTimes is the slice/list of times at which to show a notification that the VPN is about to expire + NotificationTimes []int64 `json:"notification_times"` +} + +func testExpiryTimes(t *testing.T) { + exp, expErr := ExpiryTimes() + expErrS := getError(t, expErr) + if expErrS != "" { + t.Fatalf("expiry times error is not empty: %v", expErrS) + } + + expS := getString(exp) + + var et ClonedExpiryTimes + + jErr := json.Unmarshal([]byte(expS), &et) + if jErr != nil { + t.Fatalf("failed parsing expiry times as JSON: %v", jErr) + } + etu := time.Unix(et.EndTime, 0) + stu := time.Unix(et.StartTime, 0) + + between := func(label string, cand time.Time, equalS bool, equalE bool) { + if !cand.After(stu) && (!equalS || !cand.Equal(stu)) { + t.Fatalf("%s: %v, is not after start time: %v", label, cand, stu) + } + if !cand.Before(etu) && (!equalE || !cand.Equal(etu)) { + t.Fatalf("%s: %v, is after end time: %v", label, cand, etu) + } + } + + now := time.Now() + between("now", now, false, false) + btu := time.Unix(et.ButtonTime, 0) + between("button time", btu, false, false) + ctu := time.Unix(et.CountdownTime, 0) + between("countdown time", ctu, true, false) + + first := true + for _, v := range et.NotificationTimes { + curr := time.Unix(v, 0) + between("notification time", curr, false, first) + first = false + } +} + +func testSetProfileID(t *testing.T) { + prfS := C.CString("idontexist") + defer FreeString(prfS) + pErr := getError(t, SetProfileID(prfS)) + if pErr == "" { + t.Fatal("got empty error for non-existent profile") + } + prfS2 := C.CString("employees") + defer FreeString(prfS2) + pErr = getError(t, SetProfileID(prfS2)) + if pErr != "" { + t.Fatal("got error setting existent profile") + } +} + +func testRenewSession(t *testing.T) { + ck := CookieNew() + rErr := getError(t, RenewSession(ck)) + if rErr != "" { + t.Fatalf("failed renewing session: %v", rErr) + } +} + +func testCleanup(t *testing.T) { + ck := CookieNew() + defer CookieDelete(ck) + cErr := getError(t, Cleanup(ck)) + if cErr != "" { + t.Fatalf("failed cleaning up connection: %v", cErr) + } +} + +func testGetConfig(t *testing.T) { + mustRegister(t) + defer Deregister() + serv := testServer(t) + defer serv.Close() + + ck := CookieNew() + defer CookieDelete(ck) + + list := fmt.Sprintf("https://%s", serv.Listener.Addr().String()) + listS := C.CString(list) + defer FreeString(listS) + + sclient, err := serv.Client() + if err != nil { + t.Fatalf("failed to obtain server client: %v", err) + } + + // TODO: can we do this better + http.DefaultTransport = sclient.Client.Transport + + _, cfgErr := GetConfig(ck, 3, listS, 0, 0) + cfgErrS := getError(t, cfgErr) + if !strings.HasSuffix(cfgErrS, "server does not exist.") { + t.Fatalf("error does not end with 'server does not exist.': %v", cfgErrS) + } + + // add the server + addErr := getError(t, AddServer(ck, 3, listS, nil)) + if addErr != "" { + t.Fatalf("failed to add server: %v", addErr) + } + + cfg, cfgErr := GetConfig(ck, 3, listS, 0, 0) + cfgErrS = getError(t, cfgErr) + if cfgErrS != "" { + t.Fatalf("failed to get config for server: %v", cfgErrS) + } + cfgS := getString(cfg) + + // match the config with the private key in the middle + bRe := `{"config":"[Interface]\nAddress = 10.43.43.2/24, fd43::2/64\nDNS = 9.9.9.9, 2620:fe::fe\nPrivateKey = ` + aRe := `\n[Peer]\nPublicKey = iWAHXts9w9fQVEbA5pVriPlAYMwwEPD5XcVCZDZn1AE=\nAllowedIPs = 0.0.0.0/0, ::/0\nEndpoint = vpn.example:51820\n","protocol":2,"default_gateway":true,"should_failover":true}` + + // simple regex to match the key, see https://lists.zx2c4.com/pipermail/wireguard/2020-December/006222.html + re := fmt.Sprintf("%s[A-Za-z0-9+/]{42}[AEIMQUYcgkosw480]=%s", regexp.QuoteMeta(bRe), regexp.QuoteMeta(aRe)) + ok, rErr := regexp.MatchString(re, cfgS) + if rErr != nil { + t.Fatalf("failed matching regexp: %v", rErr) + } + if !ok { + t.Fatalf("VPN config does not match regex: %v", cfgS) + } + + // 7 = GotConfig + stateIn, statErr := InState(7) + statErrS := getError(t, statErr) + if statErrS != "" { + t.Fatalf("got a state error when checking if client is in state: %v", statErr) + } + if stateIn == 0 { + t.Fatal("client is not in State 7: GotConfig") + } + setState := func(in C.int) { + // set state connecting + statErr := getError(t, SetState(in)) + if statErr != "" { + t.Fatalf("failed to set state: %v, err: %v", in, statErr) + } + } + + // set connecting -> connected -> disconnecting -> disconnected + setState(8) + setState(9) + setState(10) + setState(11) + + testExpiryTimes(t) + testSetProfileID(t) + testRenewSession(t) + testCleanup(t) +} + +func testLetsConnectDiscovery(t *testing.T) { + // this registers a let's connect! client + mustRegister(t) + defer Deregister() + serv := testServer(t) + defer serv.Close() + + ck := CookieNew() + defer CookieDelete(ck) + + list := fmt.Sprintf("https://%s", serv.Listener.Addr().String()) + listS := C.CString(list) + defer FreeString(listS) + + sclient, err := serv.Client() + if err != nil { + t.Fatalf("failed to obtain server client: %v", err) + } + + // TODO: can we do this better + http.DefaultTransport = sclient.Client.Transport + + // try to add an institute access server + exptErr := fmt.Sprintf("An internal error occurred. The cause of the error is: Adding a non-custom server when the client does not use discovery is not supported, identifier: %s, type: 1.", list) + addErr := getError(t, AddServer(ck, 1, listS, nil)) + if addErr != exptErr { + t.Fatalf("failed to add server got a different error: %v, want: %v", addErr, exptErr) + } + + _, servErr := DiscoServers(ck, nil) + servErrS := getError(t, servErr) + exptErr = "An internal error occurred. The cause of the error is: Server discovery with this client ID is not supported." + if servErrS != exptErr { + t.Fatalf("discovery servers got a different error: %v, want: %v", servErrS, exptErr) + } + + _, orgErr := DiscoOrganizations(ck, nil) + orgErrS := getError(t, orgErr) + exptErr = "An internal error occurred. The cause of the error is: Organization discovery with this client ID is not supported." + if orgErrS != exptErr { + t.Fatalf("discovery organizations got a different error: %v, want: %v", orgErrS, exptErr) + } +} diff --git a/internal/api/api.go b/internal/api/api.go index e9904bdf..fe258624 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -301,6 +301,9 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto // Parse expiry expH := h.Get("expires") + if expH == "" { + return nil, errors.New("the server did not give an expires header") + } expT, err := http.ParseTime(expH) if err != nil { return nil, fmt.Errorf("failed parsing expiry time: %w", err) diff --git a/internal/api/api_test.go b/internal/api/api_test.go index fcf02e9f..2d17e961 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -177,7 +177,6 @@ func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.Ha "v": "0.0.0" } `, listen.Addr().String()), - ResponseCode: 200, }, { Path: "/test-token-endpoint", @@ -284,7 +283,6 @@ func TestAPIInfo(t *testing.T) { } } `, - ResponseCode: 200, }, info: &profiles.Info{ Info: profiles.ListInfo{ @@ -318,7 +316,6 @@ func TestAPIInfo(t *testing.T) { } } `, - ResponseCode: 200, }, info: &profiles.Info{ Info: profiles.ListInfo{ @@ -386,20 +383,18 @@ func TestAPIConnect(t *testing.T) { }{ { hp: test.HandlerPath{ - Method: http.MethodPost, - Path: "/test-api-endpoint/connect", - Response: ``, - ResponseCode: 200, + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, }, cd: nil, err: ErrNoProtocols, }, { hp: test.HandlerPath{ - Method: http.MethodPost, - Path: "/test-api-endpoint/connect", - Response: ``, - ResponseCode: 200, + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, }, cd: nil, protos: []protocol.Protocol{protocol.Unknown}, @@ -407,10 +402,9 @@ func TestAPIConnect(t *testing.T) { }, { hp: test.HandlerPath{ - Method: http.MethodPost, - Path: "/test-api-endpoint/connect", - Response: ``, - ResponseCode: 200, + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, }, cd: nil, protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard, protocol.Unknown}, diff --git a/internal/failover/monitor.go b/internal/failover/monitor.go index 5d81f220..6f3e551f 100644 --- a/internal/failover/monitor.go +++ b/internal/failover/monitor.go @@ -2,13 +2,17 @@ package failover import ( "context" - "errors" "fmt" "time" "github.com/eduvpn/eduvpn-common/internal/log" ) +type sender interface { + Read(deadline time.Time) error + Send(seq int) error +} + // The DroppedConMon is a connection monitor that checks for an increase in rx bytes in certain intervals type DroppedConMon struct { // pInterval means how the interval in which to send pings @@ -18,6 +22,9 @@ type DroppedConMon struct { // The function that reads Rx bytes // If this function returns an error, the monitor exits readRxBytes func() (int64, error) + // newPinger creates a new pinger + // This gets used in the tests to mock the Ping sender interface + newPinger func(gateway string, mtu int) (sender, error) } // NewDroppedMonitor creates a new failover monitor @@ -25,7 +32,9 @@ type DroppedConMon struct { // `pDropped` is how many pings we need to send before we deem it is dropped // `readRxBytes` is a function that gets the rx bytes from the client func NewDroppedMonitor(pingInterval time.Duration, pDropped int, readRxBytes func() (int64, error)) *DroppedConMon { - return &DroppedConMon{pInterval: pingInterval, pDropped: pDropped, readRxBytes: readRxBytes} + return &DroppedConMon{pInterval: pingInterval, pDropped: pDropped, readRxBytes: readRxBytes, newPinger: func(gateway string, mtu int) (sender, error) { + return NewPinger(gateway, mtu) + }} } // Dropped checks whether or not the connection is 'dropped' @@ -43,12 +52,12 @@ func (m *DroppedConMon) dropped(startBytes int64) (bool, error) { // This does not check Rx bytes every tick, but rather when pAlive or pDropped is reached // It returns an error if there was an invalid input or a ping was failed to be sent func (m *DroppedConMon) Start(ctx context.Context, gateway string, mtuSize int) (bool, error) { - if mtuSize <= 0 { - return false, errors.New("invalid mtu size given") + if mtuSize < mtuOverhead { + return false, fmt.Errorf("invalid MTU size given, MTU has to be at least: %v bytes", mtuOverhead) } // Create a ping struct with our mtu size - p, err := NewPinger(gateway, mtuSize) + p, err := m.newPinger(gateway, mtuSize) if err != nil { return false, err } diff --git a/internal/failover/monitor_test.go b/internal/failover/monitor_test.go new file mode 100644 index 00000000..87fb3cde --- /dev/null +++ b/internal/failover/monitor_test.go @@ -0,0 +1,106 @@ +package failover + +import ( + "context" + "errors" + "io" + "testing" + "time" + + "github.com/eduvpn/eduvpn-common/internal/test" +) + +// mockedPinger is a ping sender that always returns nil for sending +// but returns EOF for reading +type mockedPinger struct{} + +func (mp *mockedPinger) Read(_ time.Time) error { + return io.EOF +} + +func (mp *mockedPinger) Send(_ int) error { + return nil +} + +func TestMonitor(t *testing.T) { + cases := []struct { + interval time.Duration + pDropped int + readRxBytes func() (int64, error) + gateway string + mtuSize int + disableDefaults bool + mockedPinger func(gateway string, mtu int) (sender, error) + wantDropped bool + wantErr string + }{ + { + mtuSize: 1, + wantDropped: false, + wantErr: "invalid MTU size given, MTU has to be at least: 28 bytes", + }, + { + readRxBytes: func() (int64, error) { + return 0, errors.New("error test") + }, + wantDropped: false, + wantErr: "error test", + }, + // default case, not dropped + {}, + // readRxBytes always returns 0 + // still we do not want a drop because we get a pong from 127.0.0.1 + { + readRxBytes: func() (int64, error) { + return 0, nil + }, + wantDropped: false, + }, + // readRxBytes always returns 0 + // we want dropped as the mock pinger does nothing + { + readRxBytes: func() (int64, error) { + return 0, nil + }, + gateway: "127.0.0.1", + mockedPinger: func(_ string, _ int) (sender, error) { + return &mockedPinger{}, nil + }, + wantDropped: true, + }, + } + + for _, c := range cases { + var counter int64 + // some defaults + if c.interval == 0 { + c.interval = 2 * time.Second + } + if c.pDropped == 0 { + c.pDropped = 5 + } + if c.gateway == "" { + c.gateway = "127.0.0.1" + } + if c.mtuSize == 0 { + c.mtuSize = 28 + } + if c.readRxBytes == nil { + c.readRxBytes = func() (int64, error) { + defer func() { + counter++ + }() + return counter, nil + } + } + dcm := NewDroppedMonitor(c.interval, c.pDropped, c.readRxBytes) + if c.mockedPinger != nil { + dcm.newPinger = c.mockedPinger + } + dropped, err := dcm.Start(context.Background(), c.gateway, c.mtuSize) + if dropped != c.wantDropped { + t.Fatalf("dropped is not equal to want dropped, got: %v, want: %v", dropped, c.wantDropped) + } + test.AssertError(t, err, c.wantErr) + } +} diff --git a/internal/test/server.go b/internal/test/server.go index b6e03afb..2f01c06b 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -60,6 +60,9 @@ func NewServerWithHandles(hps []HandlerPath, listener net.Listener) *Server { mux := http.NewServeMux() for _, hp := range hps { hp := hp + if hp.ResponseCode == 0 { + hp.ResponseCode = 200 + } mux.HandleFunc(hp.Path, hp.HandlerFunc()) } return NewServer(mux, listener) diff --git a/util/util.go b/util/util.go index c7816df3..3e07e0f8 100644 --- a/util/util.go +++ b/util/util.go @@ -4,6 +4,7 @@ package util import ( "net" + "strings" "github.com/eduvpn/eduvpn-common/i18nerr" ) @@ -31,3 +32,53 @@ func CalculateGateway(cidr string) (string, error) { return ret.String(), nil } + +// GetLanguageMatched uses a map from language tags to strings to extract the right language given the tag +// It implements it according to https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY.md#language-matching +func GetLanguageMatched(langMap map[string]string, langTag string) string { + // If no map is given, return the empty string + if len(langMap) == 0 { + return "" + } + // Try to find the exact match + if val, ok := langMap[langTag]; ok { + return val + } + // Try to find a key that starts with the OS language setting + for k := range langMap { + if strings.HasPrefix(k, langTag) { + return langMap[k] + } + } + // Try to find a key that starts with the first part of the OS language (e.g. de-) + pts := strings.Split(langTag, "-") + // We have a "-" + if len(pts) > 1 { + for k := range langMap { + if strings.HasPrefix(k, pts[0]+"-") { + return langMap[k] + } + } + } + // search for just the language (e.g. de) + for k := range langMap { + if k == pts[0] { + return langMap[k] + } + } + + // Pick one that is deemed best, e.g. en-US or en, but note that not all languages are always available! + // We force an entry that is english exactly or with an english prefix + for k := range langMap { + if k == "en" || strings.HasPrefix(k, "en-") { + return langMap[k] + } + } + + // Otherwise just return one + for k := range langMap { + return langMap[k] + } + + return "" +} diff --git a/util/util_test.go b/util/util_test.go index 0f4888d1..c2e2a75d 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -80,3 +80,47 @@ func TestCalculateGateway(t *testing.T) { } } } + +func TestGetLanguageMatched(t *testing.T) { + // exact match + returned := GetLanguageMatched(map[string]string{"en": "test", "de": "test2"}, "en") + if returned != "test" { + t.Fatalf("Got: %s, want: %s", returned, "test") + } + + // starts with language tag + returned = GetLanguageMatched(map[string]string{"en-US-test": "test", "de": "test2"}, "en-US") + if returned != "test" { + t.Fatalf("Got: %s, want: %s", returned, "test") + } + + // starts with en- + returned = GetLanguageMatched(map[string]string{"en-UK": "test", "en": "test2"}, "en-US") + if returned != "test" { + t.Fatalf("Got: %s, want: %s", returned, "test") + } + + // exact match for en + returned = GetLanguageMatched(map[string]string{"de": "test", "en": "test2"}, "en-US") + if returned != "test2" { + t.Fatalf("Got: %s, want: %s", returned, "test2") + } + + // We default to english + returned = GetLanguageMatched(map[string]string{"es": "test", "en": "test2"}, "nl-NL") + if returned != "test2" { + t.Fatalf("Got: %s, want: %s", returned, "test2") + } + + // We default to english with a - as well + returned = GetLanguageMatched(map[string]string{"est": "test", "en-": "test2"}, "en-US") + if returned != "test2" { + t.Fatalf("Got: %s, want: %s", returned, "test2") + } + + // None found just return one + returned = GetLanguageMatched(map[string]string{"es": "test"}, "en-US") + if returned != "test" { + t.Fatalf("Got: %s, want: %s", returned, "test") + } +}