diff --git a/Makefile b/Makefile index 59e05561..e8d354be 100644 --- a/Makefile +++ b/Makefile @@ -25,7 +25,7 @@ clean: go clean coverage: - go test -v -coverpkg=./... -coverprofile=common.cov ./... + go test -tags=cgotesting -v -coverpkg=./... -coverprofile=common.cov ./... go tool cover -func common.cov sloc: diff --git a/exports/exports.go b/exports/exports.go index f651c7f7..511a1f98 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -14,47 +14,7 @@ package main /* -#include -#include - -typedef long long int (*ReadRxBytes)(); - -typedef int (*StateCB)(int oldstate, int newstate, void* data); - -typedef void (*RefreshList)(); -typedef void (*TokenGetter)(const char* server_id, int server_type, char* out, size_t len); -typedef void (*TokenSetter)(const char* server_id, int server_type, const char* tokens); -typedef void (*ProxySetup)(int fd, const char* peer_ips); -typedef void (*ProxyReady)(); - -static long long int get_read_rx_bytes(ReadRxBytes read) -{ - return read(); -} -static int call_callback(StateCB callback, int oldstate, int newstate, void* data) -{ - return callback(oldstate, newstate, data); -} -static void call_refresh_list(RefreshList refresh) -{ - refresh(); -} -static void call_token_getter(TokenGetter getter, const char* server_id, int server_type, char* out, size_t len) -{ - getter(server_id, server_type, out, len); -} -static void call_token_setter(TokenSetter setter, const char* server_id, int server_type, const char* tokens) -{ - setter(server_id, server_type, tokens); -} -static void call_proxy_setup(ProxySetup proxysetup, int fd, const char* peer_ips) -{ - proxysetup(fd, peer_ips); -} -static void call_proxy_ready(ProxyReady ready) -{ - ready(); -} +#include "exports.h" */ import "C" diff --git a/exports/exports.h b/exports/exports.h new file mode 100644 index 00000000..a31920be --- /dev/null +++ b/exports/exports.h @@ -0,0 +1,46 @@ +#ifndef EXPORTS_H +#define EXPORTS_H + +#include +#include + +typedef long long int (*ReadRxBytes)(); + +typedef int (*StateCB)(int oldstate, int newstate, void* data); + +typedef void (*RefreshList)(); +typedef void (*TokenGetter)(const char* server_id, int server_type, char* out, size_t len); +typedef void (*TokenSetter)(const char* server_id, int server_type, const char* tokens); +typedef void (*ProxySetup)(int fd, const char* peer_ips); +typedef void (*ProxyReady)(); + +static long long int get_read_rx_bytes(ReadRxBytes read) +{ + return read(); +} +static int call_callback(StateCB callback, int oldstate, int newstate, void* data) +{ + return callback(oldstate, newstate, data); +} +static void call_refresh_list(RefreshList refresh) +{ + refresh(); +} +static void call_token_getter(TokenGetter getter, const char* server_id, int server_type, char* out, size_t len) +{ + getter(server_id, server_type, out, len); +} +static void call_token_setter(TokenSetter setter, const char* server_id, int server_type, const char* tokens) +{ + setter(server_id, server_type, tokens); +} +static void call_proxy_setup(ProxySetup proxysetup, int fd, const char* peer_ips) +{ + proxysetup(fd, peer_ips); +} +static void call_proxy_ready(ProxyReady ready) +{ + ready(); +} + +#endif /* EXPORTS_H */ diff --git a/exports/exports_test.go b/exports/exports_test.go new file mode 100644 index 00000000..6077bbd9 --- /dev/null +++ b/exports/exports_test.go @@ -0,0 +1,13 @@ +//go:build cgotesting + +package main + +import "testing" + +func TestRegister(t *testing.T) { + testRegister(t) +} + +func TestServerList(t *testing.T) { + testServerList(t) +} diff --git a/exports/exports_wrapper.go b/exports/exports_wrapper.go new file mode 100644 index 00000000..fc05de62 --- /dev/null +++ b/exports/exports_wrapper.go @@ -0,0 +1,214 @@ +package main + +/* +#include "exports.h" + +extern int test_state_callback(int old, int new, char* data); +*/ +import "C" + +import ( + "encoding/json" + "fmt" + "net" + "net/http" + "net/url" + "os" + "testing" + + "github.com/eduvpn/eduvpn-common/internal/test" + "github.com/eduvpn/eduvpn-common/types/error" + "github.com/eduvpn/eduvpn-common/util" +) + +func getString(in *C.char) string { + if in == nil { + return "" + } + defer FreeString(in) + return C.GoString(in) +} + +func getError(t *testing.T, gerr *C.char) string { + jsonErr := getString(gerr) + var transl err.Error + + if jsonErr == "" { + return "" + } + + jerr := json.Unmarshal([]byte(jsonErr), &transl) + if jerr != nil { + t.Fatalf("failed getting error JSON, val: %v, err: %v", jsonErr, jerr) + } + + return util.GetLanguageMatched(transl.Message, "en") +} + +//export test_state_callback +func test_state_callback(_ C.int, new C.int, data *C.char) int32 { + if int(new) == 3 { + fakeBrowserAuth(C.GoString(data)) + return 1 + } + return 0 +} + +func testDoRegister(t *testing.T) string { + nameS := C.CString("org.letsconnect-vpn.app.linux") + defer FreeString(nameS) + versionS := C.CString("0.0.1") + defer FreeString(versionS) + dir, err := os.MkdirTemp(os.TempDir(), "eduvpn-common-test-cgo") + if err != nil { + t.Fatalf("failed creating temp dir for state file: %v", err) + } + defer os.RemoveAll(dir) + + dirS := C.CString(dir) + defer FreeString(dirS) + + return getError(t, Register(nameS, versionS, dirS, C.StateCB(C.test_state_callback), 0)) +} + +func mustRegister(t *testing.T) { + err := testDoRegister(t) + if err != "" { + t.Fatalf("got register error: %v", err) + } +} + +func testRegister(t *testing.T) { + mustRegister(t) + defer Deregister() + err := testDoRegister(t) + if err == "" { + t.Fatalf("got no register error after double registering: %v", err) + } +} + +func fakeBrowserAuth(str string) (string, error) { + go func() { + u, err := url.Parse(str) + if err != nil { + panic(err) + } + ru, err := url.Parse(u.Query().Get("redirect_uri")) + if err != nil { + panic(err) + } + oq := u.Query() + q := ru.Query() + q.Set("state", oq.Get("state")) + q.Set("code", "fakeauthcode") + ru.RawQuery = q.Encode() + + c := http.Client{} + req, err := http.NewRequest("GET", ru.String(), nil) + if err != nil { + panic(err) + } + _, err = c.Do(req) + if err != nil { + panic(err) + } + }() + return "", nil +} + +func testServer(t *testing.T) *test.Server { + // TODO: duplicate code between this and internal/api/api_test.go + listen, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to setup listener for test server: %v", err) + } + + hps := []test.HandlerPath{ + { + Method: http.MethodGet, + Path: "/.well-known/vpn-user-portal", + Response: fmt.Sprintf(` +{ + "api": { + "http://eduvpn.org/api#3": { + "api_endpoint": "https://%[1]s/test-api-endpoint", + "authorization_endpoint": "https://%[1]s/test-authorization-endpoint", + "token_endpoint": "https://%[1]s/test-token-endpoint" + } + }, + "v": "0.0.1" +}`, listen.Addr().String()), + }, + { + Method: http.MethodPost, + Path: "/test-token-endpoint", + Response: ` +{ + "access_token": "validaccess", + "refresh_token": "validrefresh", + "expires_in": 3600 +}`, + }, + } + return test.NewServerWithHandles(hps, listen) +} + +func testServerList(t *testing.T) { + mustRegister(t) + defer Deregister() + serv := testServer(t) + defer serv.Close() + + ck := CookieNew() + defer CookieDelete(ck) + + list := fmt.Sprintf("https://%s", serv.Listener.Addr().String()) + listS := C.CString(list) + defer FreeString(listS) + + sclient, err := serv.Client() + if err != nil { + t.Fatalf("failed to obtain server client: %v", err) + } + + // TODO: can we do this better + http.DefaultTransport = sclient.Client.Transport + + gerr := getError(t, AddServer(ck, 3, listS, nil)) + if gerr != "" { + t.Fatalf("error adding server: %v", gerr) + } + + glist, glistErr := ServerList() + glistErrS := getError(t, glistErr) + if glistErrS != "" { + t.Fatalf("error getting server list: %v", glistErrS) + } + + srvlistS := getString(glist) + want := fmt.Sprintf(`{"custom_servers":[{"display_name":{"en":"127.0.0.1"},"identifier":"%s/","profiles":{"current":""}}]}`, list) + if srvlistS != want { + t.Fatalf("server list not equal, want: %v, got: %v", want, srvlistS) + } + + remErr := getError(t, RemoveServer(3, listS)) + if remErr != "" { + t.Fatalf("got error removing server: %v", remErr) + } + remErr = getError(t, RemoveServer(3, listS)) + if remErr == "" { + t.Fatalf("got no error removing server again") + } + + glist, glistErr = ServerList() + glistErrS = getError(t, glistErr) + if glistErrS != "" { + t.Fatalf("error getting server list: %v", glistErrS) + } + + srvlistS = getString(glist) + want = "{}" + if srvlistS != want { + t.Fatalf("server list not equal, want: %v, got: %v", want, srvlistS) + } +}