From ff0477162b7ad2ee1d165f88c61edee57d0aaa8a Mon Sep 17 00:00:00 2001 From: Ain Ghazal Date: Tue, 11 Jun 2024 19:04:21 +0200 Subject: [PATCH] feat(openvpn): implement richer input This commit: 1. modifies `./internal/registry` and its `openvpn.go` file such that `openvpn` has its own private target loader; 2. modifies `./internal/experiment/openvpn` to use the richer input targets to merge the options for the openvpn experiment. 3. removes cache from session after fetching openvpn config --- internal/engine/session.go | 6 - internal/experiment/openvpn/endpoint.go | 70 ++---- internal/experiment/openvpn/endpoint_test.go | 78 ++---- internal/experiment/openvpn/openvpn.go | 189 ++++----------- internal/experiment/openvpn/openvpn_test.go | 222 ++---------------- internal/experiment/openvpn/richerinput.go | 146 ++++++++++++ .../experiment/openvpn/richerinput_test.go | 175 ++++++++++++++ internal/registry/openvpn.go | 5 +- internal/targetloading/targetloading.go | 6 +- 9 files changed, 415 insertions(+), 482 deletions(-) create mode 100644 internal/experiment/openvpn/richerinput.go create mode 100644 internal/experiment/openvpn/richerinput_test.go diff --git a/internal/engine/session.go b/internal/engine/session.go index 1a6bc5443b..0f41f5f61a 100644 --- a/internal/engine/session.go +++ b/internal/engine/session.go @@ -66,7 +66,6 @@ type Session struct { softwareName string softwareVersion string tempDir string - vpnConfig map[string]model.OOAPIVPNProviderConfig // closeOnce allows us to call Close just once. closeOnce sync.Once @@ -178,7 +177,6 @@ func NewSession(ctx context.Context, config SessionConfig) (*Session, error) { torArgs: config.TorArgs, torBinary: config.TorBinary, tunnelDir: config.TunnelDir, - vpnConfig: make(map[string]model.OOAPIVPNProviderConfig), } proxyURL := config.ProxyURL if proxyURL != nil { @@ -381,9 +379,6 @@ func (s *Session) FetchTorTargets( // internal cache. We do this to avoid hitting the API for every input. func (s *Session) FetchOpenVPNConfig( ctx context.Context, provider, cc string) (*model.OOAPIVPNProviderConfig, error) { - if config, ok := s.vpnConfig[provider]; ok { - return &config, nil - } clnt, err := s.newOrchestraClient(ctx) if err != nil { return nil, err @@ -397,7 +392,6 @@ func (s *Session) FetchOpenVPNConfig( if err != nil { return nil, err } - s.vpnConfig[provider] = config return &config, nil } diff --git a/internal/experiment/openvpn/endpoint.go b/internal/experiment/openvpn/endpoint.go index 6289d75cda..abd6c53459 100644 --- a/internal/experiment/openvpn/endpoint.go +++ b/internal/experiment/openvpn/endpoint.go @@ -1,12 +1,12 @@ package openvpn import ( - "encoding/base64" "errors" "fmt" "math/rand" "net" "net/url" + "slices" "strings" vpnconfig "github.com/ooni/minivpn/pkg/config" @@ -178,16 +178,6 @@ func (e endpointList) Shuffle() endpointList { return e } -// defaultOptionsByProvider is a map containing base config for -// all the known providers. We extend this base config with credentials coming -// from the OONI API. -var defaultOptionsByProvider = map[string]*vpnconfig.OpenVPNOptions{ - "riseupvpn": { - Auth: "SHA512", - Cipher: "AES-256-GCM", - }, -} - // APIEnabledProviders is the list of providers that the stable API Endpoint knows about. // This array will be a subset of the keys in defaultOptionsByProvider, but it might make sense // to still register info about more providers that the API officially knows about. @@ -196,40 +186,25 @@ var APIEnabledProviders = []string{ "riseupvpn", } -// isValidProvider returns true if the provider is found as key in the registry of defaultOptionsByProvider. -// TODO(ainghazal): consolidate with list of enabled providers from the API viewpoint. +// isValidProvider returns true if the provider is found as key in the array of APIEnabledProviders func isValidProvider(provider string) bool { - _, ok := defaultOptionsByProvider[provider] - return ok + return slices.Contains(APIEnabledProviders, provider) } -// getOpenVPNConfig gets a properly configured [*vpnconfig.Config] object for the given endpoint. -// To obtain that, we merge the endpoint specific configuration with base options. -// Base options are hardcoded for the moment, for comparability among different providers. -// We can add them to the OONI API and as extra cli options if ever needed. -func getOpenVPNConfig( +// mergeOpenVPNConfig gets a properly configured [*vpnconfig.Config] object for the given endpoint. +// To obtain that, we merge the endpoint specific configuration with the options passed as richer input targets. +func mergeOpenVPNConfig( tracer *vpntracex.Tracer, logger model.Logger, endpoint *endpoint, - creds *vpnconfig.OpenVPNOptions) (*vpnconfig.Config, error) { + config *Config) (*vpnconfig.Config, error) { + // TODO(ainghazal): use merge ability in vpnconfig.OpenVPNOptions merge (pending PR) provider := endpoint.Provider if !isValidProvider(provider) { return nil, fmt.Errorf("%w: unknown provider: %s", ErrInvalidInput, provider) } - baseOptions := defaultOptionsByProvider[provider] - - if baseOptions == nil { - return nil, fmt.Errorf("empty baseOptions for provider: %s", provider) - } - if baseOptions.Cipher == "" { - return nil, fmt.Errorf("empty cipher for provider: %s", provider) - } - if baseOptions.Auth == "" { - return nil, fmt.Errorf("empty auth for provider: %s", provider) - } - cfg := vpnconfig.NewConfig( vpnconfig.WithLogger(logger), vpnconfig.WithOpenVPNOptions( @@ -239,14 +214,13 @@ func getOpenVPNConfig( Port: endpoint.Port, Proto: vpnconfig.Proto(endpoint.Transport), - // options coming from the default known values. - Cipher: baseOptions.Cipher, - Auth: baseOptions.Auth, - - // auth coming from passed credentials. - CA: creds.CA, - Cert: creds.Cert, - Key: creds.Key, + // options and credentials come from the experiment + // richer input targets. + Cipher: config.Cipher, + Auth: config.Auth, + CA: []byte(config.SafeCA), + Cert: []byte(config.SafeCert), + Key: []byte(config.SafeKey), }, ), vpnconfig.WithHandshakeTracer(tracer), @@ -255,20 +229,6 @@ func getOpenVPNConfig( return cfg, nil } -// maybeExtractBase64Blob is used to pass credentials as command-line options. -func maybeExtractBase64Blob(val string) (string, error) { - s := strings.TrimPrefix(val, "base64:") - if len(s) == len(val) { - // no prefix, so we'll treat this as a pem-encoded credential. - return s, nil - } - dec, err := base64.URLEncoding.DecodeString(strings.TrimSpace(s)) - if err != nil { - return "", fmt.Errorf("%w: %s", ErrBadBase64Blob, err) - } - return string(dec), nil -} - func isValidProtocol(s string) bool { if strings.HasPrefix(s, "openvpn://") { return true diff --git a/internal/experiment/openvpn/endpoint_test.go b/internal/experiment/openvpn/endpoint_test.go index bc33301419..57243eeae0 100644 --- a/internal/experiment/openvpn/endpoint_test.go +++ b/internal/experiment/openvpn/endpoint_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - vpnconfig "github.com/ooni/minivpn/pkg/config" vpntracex "github.com/ooni/minivpn/pkg/tracex" ) @@ -272,7 +271,7 @@ func Test_isValidProvider(t *testing.T) { } } -func Test_getVPNConfig(t *testing.T) { +func Test_mergeVPNConfig(t *testing.T) { tracer := vpntracex.NewTracer(time.Now()) e := &endpoint{ Provider: "riseupvpn", @@ -280,13 +279,16 @@ func Test_getVPNConfig(t *testing.T) { Port: "443", Transport: "udp", } - creds := &vpnconfig.OpenVPNOptions{ - CA: []byte("ca"), - Cert: []byte("cert"), - Key: []byte("key"), + + config := &Config{ + Auth: "SHA512", + Cipher: "AES-256-GCM", + SafeCA: "ca", + SafeCert: "cert", + SafeKey: "key", } - cfg, err := getOpenVPNConfig(tracer, nil, e, creds) + cfg, err := mergeOpenVPNConfig(tracer, nil, e, config) if err != nil { t.Fatalf("did not expect error, got: %v", err) } @@ -311,18 +313,18 @@ func Test_getVPNConfig(t *testing.T) { if transport := cfg.OpenVPNOptions().Proto; string(transport) != e.Transport { t.Errorf("expected transport %s, got %s", e.Transport, transport) } - if diff := cmp.Diff(cfg.OpenVPNOptions().CA, creds.CA); diff != "" { + if diff := cmp.Diff(cfg.OpenVPNOptions().CA, []byte(config.SafeCA)); diff != "" { t.Error(diff) } - if diff := cmp.Diff(cfg.OpenVPNOptions().Cert, creds.Cert); diff != "" { + if diff := cmp.Diff(cfg.OpenVPNOptions().Cert, []byte(config.SafeCert)); diff != "" { t.Error(diff) } - if diff := cmp.Diff(cfg.OpenVPNOptions().Key, creds.Key); diff != "" { + if diff := cmp.Diff(cfg.OpenVPNOptions().Key, []byte(config.SafeKey)); diff != "" { t.Error(diff) } } -func Test_getVPNConfig_with_unknown_provider(t *testing.T) { +func Test_mergeOpenVPNConfig_with_unknown_provider(t *testing.T) { tracer := vpntracex.NewTracer(time.Now()) e := &endpoint{ Provider: "nsa", @@ -330,62 +332,18 @@ func Test_getVPNConfig_with_unknown_provider(t *testing.T) { Port: "443", Transport: "udp", } - creds := &vpnconfig.OpenVPNOptions{ - CA: []byte("ca"), - Cert: []byte("cert"), - Key: []byte("key"), + cfg := &Config{ + SafeCA: "ca", + SafeCert: "cert", + SafeKey: "key", } - _, err := getOpenVPNConfig(tracer, nil, e, creds) + _, err := mergeOpenVPNConfig(tracer, nil, e, cfg) if !errors.Is(err, ErrInvalidInput) { t.Fatalf("expected invalid input error, got: %v", err) } } -func Test_extractBase64Blob(t *testing.T) { - t.Run("decode good blob", func(t *testing.T) { - blob := "base64:dGhlIGJsdWUgb2N0b3B1cyBpcyB3YXRjaGluZw==" - decoded, err := maybeExtractBase64Blob(blob) - if decoded != "the blue octopus is watching" { - t.Fatal("could not decoded blob correctly") - } - if err != nil { - t.Fatal("should not fail with first blob") - } - }) - t.Run("try decode without prefix", func(t *testing.T) { - blob := "dGhlIGJsdWUgb2N0b3B1cyBpcyB3YXRjaGluZw==" - dec, err := maybeExtractBase64Blob(blob) - if err != nil { - t.Fatal("should fail without prefix") - } - if dec != blob { - t.Fatal("decoded should be the same") - } - }) - t.Run("bad base64 blob should fail", func(t *testing.T) { - blob := "base64:dGhlIGJsdWUgb2N0b3B1cyBpcyB3YXRjaGluZw" - _, err := maybeExtractBase64Blob(blob) - if !errors.Is(err, ErrBadBase64Blob) { - t.Fatal("bad blob should fail without prefix") - } - }) - t.Run("decode empty blob", func(t *testing.T) { - blob := "base64:" - _, err := maybeExtractBase64Blob(blob) - if err != nil { - t.Fatal("empty blob should not fail") - } - }) - t.Run("illegal base64 data should fail", func(t *testing.T) { - blob := "base64:==" - _, err := maybeExtractBase64Blob(blob) - if !errors.Is(err, ErrBadBase64Blob) { - t.Fatal("bad base64 data should fail") - } - }) -} - func Test_IsValidProtocol(t *testing.T) { t.Run("openvpn is valid", func(t *testing.T) { if !isValidProtocol("openvpn://foobar.bar") { diff --git a/internal/experiment/openvpn/openvpn.go b/internal/experiment/openvpn/openvpn.go index 2d2af0e01b..db87cd558e 100644 --- a/internal/experiment/openvpn/openvpn.go +++ b/internal/experiment/openvpn/openvpn.go @@ -4,13 +4,12 @@ package openvpn import ( "context" "errors" - "fmt" "strconv" - "strings" "time" "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/targetloading" vpnconfig "github.com/ooni/minivpn/pkg/config" vpntracex "github.com/ooni/minivpn/pkg/tracex" @@ -18,8 +17,19 @@ import ( ) const ( - testVersion = "0.1.2" - openVPNProcol = "openvpn" + testName = "openvpn" + testVersion = "0.1.3" + openVPNProtocol = "openvpn" +) + +// errors are in addition to any other errors returned by the low level packages +// that are used by this experiment to implement its functionality. +var ( + // ErrInputRequired is returned when the experiment is not passed any input. + ErrInputRequired = targetloading.ErrInputRequired + + // ErrInvalidInput is returned if we failed to parse the input to obtain an endpoint we can measure. + ErrInvalidInput = errors.New("invalid input") ) // Config contains the experiment config. @@ -93,18 +103,16 @@ func (tk *TestKeys) AllConnectionsSuccessful() bool { // Measurer performs the measurement. type Measurer struct { - config Config - testName string } // NewExperimentMeasurer creates a new ExperimentMeasurer. -func NewExperimentMeasurer(config Config, testName string) model.ExperimentMeasurer { - return Measurer{config: config, testName: testName} +func NewExperimentMeasurer() model.ExperimentMeasurer { + return Measurer{} } // ExperimentName implements model.ExperimentMeasurer.ExperimentName. func (m Measurer) ExperimentName() string { - return m.testName + return testName } // ExperimentVersion implements model.ExperimentMeasurer.ExperimentVersion. @@ -112,19 +120,14 @@ func (m Measurer) ExperimentVersion() string { return testVersion } -var ( - // ErrInvalidInput is returned if we failed to parse the input to obtain an endpoint we can measure. - ErrInvalidInput = errors.New("invalid input") -) - -func parseEndpoint(m *model.Measurement) (*endpoint, error) { - if m.Input != "" { - if ok := isValidProtocol(string(m.Input)); !ok { - return nil, ErrInvalidInput - } - return newEndpointFromInputString(string(m.Input)) +func parseEndpoint(input string) (*endpoint, error) { + if input == "" { + return nil, ErrInputRequired + } + if ok := isValidProtocol(input); !ok { + return nil, ErrInvalidInput } - return nil, fmt.Errorf("%w: %s", ErrInvalidInput, "input is mandatory") + return newEndpointFromInputString(input) } // AuthMethod is the authentication method used by a provider. @@ -138,130 +141,6 @@ var ( AuthUserPass = AuthMethod("userpass") ) -var providerAuthentication = map[string]AuthMethod{ - "riseup": AuthCertificate, - "tunnelbear": AuthUserPass, - "surfshark": AuthUserPass, -} - -func hasCredentialsInOptions(cfg Config, method AuthMethod) bool { - switch method { - case AuthCertificate: - ok := cfg.SafeCA != "" && cfg.SafeCert != "" && cfg.SafeKey != "" - return ok - default: - return false - } -} - -// MaybeGetCredentialsFromOptions overrides authentication info with what user provided in options. -// Each certificate/key can be encoded in base64 so that a single option can be safely represented as command line options. -// This function returns no error if there are no credentials in the passed options, only if failing to parse them. -func MaybeGetCredentialsFromOptions(cfg Config, opts *vpnconfig.OpenVPNOptions, method AuthMethod) (bool, error) { - if ok := hasCredentialsInOptions(cfg, method); !ok { - return false, nil - } - ca, err := maybeExtractBase64Blob(cfg.SafeCA) - if err != nil { - return false, err - } - opts.CA = []byte(ca) - - key, err := maybeExtractBase64Blob(cfg.SafeKey) - if err != nil { - return false, err - } - opts.Key = []byte(key) - - cert, err := maybeExtractBase64Blob(cfg.SafeCert) - if err != nil { - return false, err - } - opts.Cert = []byte(cert) - return true, nil -} - -func (m *Measurer) getCredentialsFromAPI( - ctx context.Context, - sess model.ExperimentSession, - provider string, - opts *vpnconfig.OpenVPNOptions) error { - // We expect the credentials from the API response to be encoded as the direct PEM serialization. - apiCreds, err := m.FetchProviderCredentials(ctx, sess, provider) - // TODO(ainghazal): validate credentials have the info we expect, certs are not expired etc. - if err != nil { - sess.Logger().Warnf("Error fetching credentials from API: %s", err.Error()) - return err - } - sess.Logger().Infof("Got credentials from provider: %s", provider) - - opts.CA = []byte(apiCreds.Config.CA) - opts.Cert = []byte(apiCreds.Config.Cert) - opts.Key = []byte(apiCreds.Config.Key) - return nil -} - -// GetCredentialsFromOptionsOrAPI attempts to find valid credentials for the given provider, either -// from the passed Options (cli, oonirun), or from a remote call to the OONI API endpoint. -func (m *Measurer) GetCredentialsFromOptionsOrAPI( - ctx context.Context, - sess model.ExperimentSession, - provider string) (*vpnconfig.OpenVPNOptions, error) { - - method, ok := providerAuthentication[strings.TrimSuffix(provider, "vpn")] - if !ok { - return nil, fmt.Errorf("%w: provider auth unknown: %s", ErrInvalidInput, provider) - } - - // Empty options object to fill with credentials. - creds := &vpnconfig.OpenVPNOptions{} - - switch method { - case AuthCertificate: - ok, err := MaybeGetCredentialsFromOptions(m.config, creds, method) - if err != nil { - return nil, err - } - if ok { - return creds, nil - } - // No options passed, so let's get the credentials that inputbuilder should have cached - // for us after hitting the OONI API. - if err := m.getCredentialsFromAPI(ctx, sess, provider, creds); err != nil { - return nil, err - } - return creds, nil - - default: - return nil, fmt.Errorf("%w: method not implemented (%s)", ErrInvalidInput, method) - } - -} - -// mergeOpenVPNConfig attempts to get credentials from Options or API, and then -// constructs a [*vpnconfig.Config] instance after merging the credentials passed by options or API response. -// It also returns an error if the operation fails. -func (m *Measurer) mergeOpenVPNConfig( - ctx context.Context, - sess model.ExperimentSession, - endpoint *endpoint, - tracer *vpntracex.Tracer) (*vpnconfig.Config, error) { - - logger := sess.Logger() - - credentials, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, endpoint.Provider) - if err != nil { - return nil, err - } - - openvpnConfig, err := getOpenVPNConfig(tracer, logger, endpoint, credentials) - if err != nil { - return nil, err - } - // TODO(ainghazal): sanity check (Remote, Port, Proto etc + missing certs) - return openvpnConfig, nil -} - // connectAndHandshake dials a connection and attempts an OpenVPN handshake using that dialer. func (m *Measurer) connectAndHandshake( ctx context.Context, @@ -344,9 +223,9 @@ func (m Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { measurement := args.Measurement sess := args.Session - endpoint, err := parseEndpoint(measurement) - if err != nil { - return err + // 0. obtain the richer input target, config, and input or panic + if args.Target == nil { + return ErrInputRequired } tk := NewTestKeys() @@ -355,10 +234,21 @@ func (m Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { idx := int64(1) handshakeTracer := vpntracex.NewTracerWithTransactionID(zeroTime, idx) - openvpnConfig, err := m.mergeOpenVPNConfig(ctx, sess, endpoint, handshakeTracer) + // build the input + target := args.Target.(*Target) + config, input := target.Options, target.URL + sess.Logger().Infof("openvpn: using richer input: %+v", input) + + endpoint, err := parseEndpoint(input) if err != nil { return err } + + openvpnConfig, err := mergeOpenVPNConfig(handshakeTracer, sess.Logger(), endpoint, config) + if err != nil { + return err + } + sess.Logger().Infof("Probing endpoint %s", endpoint.String()) connResult := m.connectAndHandshake(ctx, zeroTime, idx, sess.Logger(), endpoint, openvpnConfig, handshakeTracer) @@ -366,6 +256,7 @@ func (m Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { tk.Success = tk.AllConnectionsSuccessful() callbacks.OnProgress(1.0, "All endpoints probed") + measurement.TestKeys = tk // TODO(ainghazal): validate we have valid config for each endpoint. diff --git a/internal/experiment/openvpn/openvpn_test.go b/internal/experiment/openvpn/openvpn_test.go index ed0f2b9adb..50c084665f 100644 --- a/internal/experiment/openvpn/openvpn_test.go +++ b/internal/experiment/openvpn/openvpn_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - vpnconfig "github.com/ooni/minivpn/pkg/config" vpntracex "github.com/ooni/minivpn/pkg/tracex" "github.com/ooni/probe-cli/v3/internal/experiment/openvpn" "github.com/ooni/probe-cli/v3/internal/mocks" @@ -35,11 +34,11 @@ func makeMockSession() *mocks.Session { } func TestNewExperimentMeasurer(t *testing.T) { - m := openvpn.NewExperimentMeasurer(openvpn.Config{}, "openvpn") + m := openvpn.NewExperimentMeasurer() if m.ExperimentName() != "openvpn" { t.Fatal("invalid ExperimentName") } - if m.ExperimentVersion() != "0.1.2" { + if m.ExperimentVersion() != "0.1.3" { t.Fatal("invalid ExperimentVersion") } } @@ -60,194 +59,6 @@ func TestNewTestKeys(t *testing.T) { } } -func TestMaybeGetCredentialsFromOptions(t *testing.T) { - t.Run("cert auth returns false if cert, key and ca are not all provided", func(t *testing.T) { - cfg := openvpn.Config{ - SafeCA: "base64:Zm9v", - SafeCert: "base64:Zm9v", - } - ok, err := openvpn.MaybeGetCredentialsFromOptions(cfg, &vpnconfig.OpenVPNOptions{}, openvpn.AuthCertificate) - if err != nil { - t.Fatal("should not raise error") - } - if ok { - t.Fatal("expected false") - } - }) - t.Run("cert auth returns ok if cert, key and ca are all provided", func(t *testing.T) { - cfg := openvpn.Config{ - SafeCA: "base64:Zm9v", - SafeCert: "base64:Zm9v", - SafeKey: "base64:Zm9v", - } - opts := &vpnconfig.OpenVPNOptions{} - ok, err := openvpn.MaybeGetCredentialsFromOptions(cfg, opts, openvpn.AuthCertificate) - if err != nil { - t.Fatalf("expected err=nil, got %v", err) - } - if !ok { - t.Fatal("expected true") - } - if diff := cmp.Diff(opts.CA, []byte("foo")); diff != "" { - t.Fatal(diff) - } - if diff := cmp.Diff(opts.Cert, []byte("foo")); diff != "" { - t.Fatal(diff) - } - if diff := cmp.Diff(opts.Key, []byte("foo")); diff != "" { - t.Fatal(diff) - } - }) - t.Run("cert auth returns false and error if CA base64 is bad blob", func(t *testing.T) { - cfg := openvpn.Config{ - SafeCA: "base64:Zm9vaaa", - SafeCert: "base64:Zm9v", - SafeKey: "base64:Zm9v", - } - opts := &vpnconfig.OpenVPNOptions{} - ok, err := openvpn.MaybeGetCredentialsFromOptions(cfg, opts, openvpn.AuthCertificate) - if ok { - t.Fatal("expected false") - } - if !errors.Is(err, openvpn.ErrBadBase64Blob) { - t.Fatalf("expected err=ErrBase64Blob, got %v", err) - } - }) - t.Run("cert auth returns false and error if key base64 is bad blob", func(t *testing.T) { - cfg := openvpn.Config{ - SafeCA: "base64:Zm9v", - SafeCert: "base64:Zm9v", - SafeKey: "base64:Zm9vaaa", - } - opts := &vpnconfig.OpenVPNOptions{} - ok, err := openvpn.MaybeGetCredentialsFromOptions(cfg, opts, openvpn.AuthCertificate) - if ok { - t.Fatal("expected false") - } - if !errors.Is(err, openvpn.ErrBadBase64Blob) { - t.Fatalf("expected err=ErrBase64Blob, got %v", err) - } - }) - t.Run("cert auth returns false and error if cert base64 is bad blob", func(t *testing.T) { - cfg := openvpn.Config{ - SafeCA: "base64:Zm9v", - SafeCert: "base64:Zm9vaaa", - SafeKey: "base64:Zm9v", - } - opts := &vpnconfig.OpenVPNOptions{} - ok, err := openvpn.MaybeGetCredentialsFromOptions(cfg, opts, openvpn.AuthCertificate) - if ok { - t.Fatal("expected false") - } - if !errors.Is(err, openvpn.ErrBadBase64Blob) { - t.Fatalf("expected err=ErrBase64Blob, got %v", err) - } - }) - t.Run("userpass auth returns error, not yet implemented", func(t *testing.T) { - cfg := openvpn.Config{} - ok, err := openvpn.MaybeGetCredentialsFromOptions(cfg, &vpnconfig.OpenVPNOptions{}, openvpn.AuthUserPass) - if ok { - t.Fatal("expected false") - } - if err != nil { - t.Fatalf("expected err=nil, got %v", err) - } - }) -} - -func TestGetCredentialsFromOptionsOrAPI(t *testing.T) { - t.Run("non-registered provider raises error", func(t *testing.T) { - m := openvpn.NewExperimentMeasurer(openvpn.Config{}, "openvpn").(openvpn.Measurer) - ctx := context.Background() - sess := makeMockSession() - opts, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, "nsa") - if !errors.Is(err, openvpn.ErrInvalidInput) { - t.Fatalf("expected err=ErrInvalidInput, got %v", err) - } - if opts != nil { - t.Fatal("expected opts=nil") - } - }) - t.Run("providers with userpass auth method raise error, not yet implemented", func(t *testing.T) { - m := openvpn.NewExperimentMeasurer(openvpn.Config{}, "openvpn").(openvpn.Measurer) - ctx := context.Background() - sess := makeMockSession() - opts, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, "tunnelbear") - if !errors.Is(err, openvpn.ErrInvalidInput) { - t.Fatalf("expected err=ErrInvalidInput, got %v", err) - } - if opts != nil { - t.Fatal("expected opts=nil") - } - }) - t.Run("known cert auth provider and creds in options is ok", func(t *testing.T) { - config := openvpn.Config{ - SafeCA: "base64:Zm9v", - SafeCert: "base64:Zm9v", - SafeKey: "base64:Zm9v", - } - m := openvpn.NewExperimentMeasurer(config, "openvpn").(openvpn.Measurer) - ctx := context.Background() - sess := makeMockSession() - opts, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, "riseup") - if err != nil { - t.Fatalf("expected err=nil, got %v", err) - } - if opts == nil { - t.Fatal("expected non-nil options") - } - }) - t.Run("known cert auth provider and bad creds in options returns error", func(t *testing.T) { - config := openvpn.Config{ - SafeCA: "base64:Zm9v", - SafeCert: "base64:Zm9v", - SafeKey: "base64:Zm9vaaa", - } - m := openvpn.NewExperimentMeasurer(config, "openvpn").(openvpn.Measurer) - ctx := context.Background() - sess := makeMockSession() - opts, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, "riseup") - if !errors.Is(err, openvpn.ErrBadBase64Blob) { - t.Fatalf("expected err=ErrBadBase64, got %v", err) - } - if opts != nil { - t.Fatal("expected nil opts") - } - }) - t.Run("known cert auth provider with null options hits the api", func(t *testing.T) { - config := openvpn.Config{} - m := openvpn.NewExperimentMeasurer(config, "openvpn").(openvpn.Measurer) - ctx := context.Background() - sess := makeMockSession() - opts, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, "riseup") - if err != nil { - t.Fatalf("expected err=nil, got %v", err) - } - if opts == nil { - t.Fatalf("expected not-nil options, got %v", opts) - } - }) - t.Run("known cert auth provider with null options hits the api and raises error if api fails", func(t *testing.T) { - config := openvpn.Config{} - m := openvpn.NewExperimentMeasurer(config, "openvpn").(openvpn.Measurer) - ctx := context.Background() - - someError := errors.New("some error") - sess := makeMockSession() - sess.MockFetchOpenVPNConfig = func(context.Context, string, string) (*model.OOAPIVPNProviderConfig, error) { - return nil, someError - } - - opts, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, "riseup") - if !errors.Is(err, someError) { - t.Fatalf("expected err=someError, got %v", err) - } - if opts != nil { - t.Fatalf("expected nil options, got %v", opts) - } - }) -} - func TestAddConnectionTestKeys(t *testing.T) { t.Run("append tcp connection result to empty keys", func(t *testing.T) { tk := openvpn.NewTestKeys() @@ -364,17 +175,20 @@ func TestAllConnectionsSuccessful(t *testing.T) { }) } -func TestBadInputFailure(t *testing.T) { - m := openvpn.NewExperimentMeasurer(openvpn.Config{}, "openvpn") +func TestBadTargetURLFailure(t *testing.T) { + m := openvpn.NewExperimentMeasurer() ctx := context.Background() sess := makeMockSession() callbacks := model.NewPrinterCallbacks(sess.Logger()) measurement := new(model.Measurement) - measurement.Input = "openvpn://badprovider/?address=aa" args := &model.ExperimentArgs{ Callbacks: callbacks, Measurement: measurement, Session: sess, + Target: &openvpn.Target{ + URL: "openvpn://badprovider/?address=aa", + Options: &openvpn.Config{}, + }, } err := m.Run(ctx, args) if !errors.Is(err, openvpn.ErrInvalidInput) { @@ -391,9 +205,7 @@ func TestVPNInput(t *testing.T) { func TestMeasurer_FetchProviderCredentials(t *testing.T) { t.Run("Measurer.FetchProviderCredentials calls method in session", func(t *testing.T) { - m := openvpn.NewExperimentMeasurer( - openvpn.Config{}, - "openvpn").(openvpn.Measurer) + m := openvpn.NewExperimentMeasurer().(openvpn.Measurer) sess := makeMockSession() _, err := m.FetchProviderCredentials( @@ -406,9 +218,7 @@ func TestMeasurer_FetchProviderCredentials(t *testing.T) { t.Run("Measurer.FetchProviderCredentials raises error if API calls fail", func(t *testing.T) { someError := errors.New("unexpected") - m := openvpn.NewExperimentMeasurer( - openvpn.Config{}, - "openvpn").(openvpn.Measurer) + m := openvpn.NewExperimentMeasurer().(openvpn.Measurer) sess := makeMockSession() sess.MockFetchOpenVPNConfig = func(context.Context, string, string) (*model.OOAPIVPNProviderConfig, error) { @@ -424,21 +234,19 @@ func TestMeasurer_FetchProviderCredentials(t *testing.T) { } func TestSuccess(t *testing.T) { - m := openvpn.NewExperimentMeasurer(openvpn.Config{ - Provider: "riseup", - SafeCA: "base64:Zm9v", - SafeKey: "base64:Zm9v", - SafeCert: "base64:Zm9v", - }, "openvpn") + m := openvpn.NewExperimentMeasurer() ctx := context.Background() sess := makeMockSession() callbacks := model.NewPrinterCallbacks(sess.Logger()) measurement := new(model.Measurement) - measurement.Input = "openvpn://riseupvpn.corp/?address=127.0.0.1:9989&transport=tcp" args := &model.ExperimentArgs{ Callbacks: callbacks, Measurement: measurement, Session: sess, + Target: &openvpn.Target{ + URL: "openvpn://riseupvpn.corp/?address=127.0.0.1:9989&transport=tcp", + Options: &openvpn.Config{}, + }, } err := m.Run(ctx, args) if err != nil { diff --git a/internal/experiment/openvpn/richerinput.go b/internal/experiment/openvpn/richerinput.go new file mode 100644 index 0000000000..e3262e1126 --- /dev/null +++ b/internal/experiment/openvpn/richerinput.go @@ -0,0 +1,146 @@ +package openvpn + +import ( + "context" + "fmt" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/reflectx" + "github.com/ooni/probe-cli/v3/internal/targetloading" +) + +// defaultProvider is the provider we will request from API in case we got no provider set +// in the CLI options. +var defaultProvider = "riseupvpn" + +// providerAuthentication is a map so that we know which kind of credentials we +// need to fill in the openvpn options for each known provider. +var providerAuthentication = map[string]AuthMethod{ + "riseupvpn": AuthCertificate, + "tunnelbearvpn": AuthUserPass, + "surfsharkvpn": AuthUserPass, +} + +// Target is a richer-input target that this experiment should measure. +type Target struct { + // Options contains the configuration. + Options *Config + + // URL is the input URL. + URL string +} + +var _ model.ExperimentTarget = &Target{} + +// Category implements [model.ExperimentTarget]. +func (t *Target) Category() string { + return model.DefaultCategoryCode +} + +// Country implements [model.ExperimentTarget]. +func (t *Target) Country() string { + return model.DefaultCountryCode +} + +// Input implements [model.ExperimentTarget]. +func (t *Target) Input() string { + return t.URL +} + +// String implements [model.ExperimentTarget]. +func (t *Target) String() string { + return t.URL +} + +// NewLoader constructs a new [model.ExperimentTargerLoader] instance. +// +// This function PANICS if options is not an instance of [*openvpn.Config]. +func NewLoader(loader *targetloading.Loader, gopts any) model.ExperimentTargetLoader { + // Panic if we cannot convert the options to the expected type. + // + // We do not expect a panic here because the type is managed by the registry package. + options := gopts.(*Config) + + // Construct the proper loader instance. + return &targetLoader{ + loader: loader, + options: options, + session: loader.Session, + } +} + +// targetLoader loads targets for this experiment. +type targetLoader struct { + loader *targetloading.Loader + options *Config + session targetloading.Session +} + +// Load implements model.ExperimentTargetLoader. +func (tl *targetLoader) Load(ctx context.Context) ([]model.ExperimentTarget, error) { + // If inputs and files are all empty and there are no options, let's use the backend + if len(tl.loader.StaticInputs) <= 0 && len(tl.loader.SourceFiles) <= 0 && + reflectx.StructOrStructPtrIsZero(tl.options) { + return tl.loadFromBackend(ctx) + } + + // Otherwise, attempt to load the static inputs from CLI and files + inputs, err := targetloading.LoadStatic(tl.loader) + + // Handle the case where we couldn't load from CLI or files + if err != nil { + return nil, err + } + + // Build the list of targets that we should measure. + var targets []model.ExperimentTarget + for _, input := range inputs { + targets = append(targets, &Target{ + Options: tl.options, + URL: input, + }) + } + return targets, nil +} + +// TODO(ainghazal): we might want to get both the BaseURL and the HTTPClient from the session, +// and then deal with the openvpn-specific API calls ourselves within the boundaries of the experiment. +func (tl *targetLoader) loadFromBackend(_ context.Context) ([]model.ExperimentTarget, error) { + if tl.options.Provider == "" { + tl.options.Provider = defaultProvider + } + + targets := make([]model.ExperimentTarget, 0) + provider := tl.options.Provider + + // TODO(ainghazal): pass country code too (from session?) + apiConfig, err := tl.session.FetchOpenVPNConfig(context.Background(), provider, "XX") + if err != nil { + return nil, err + } + + auth, ok := providerAuthentication[provider] + if !ok { + return nil, fmt.Errorf("%w: unknown authentication for provider %s", ErrInvalidInput, provider) + } + + for _, input := range apiConfig.Inputs { + config := &Config{ + // Auth and Cipher are hardcoded for now. + Auth: "SHA512", + Cipher: "AES-256-GCM", + } + switch auth { + case AuthCertificate: + config.SafeCA = apiConfig.Config.CA + config.SafeCert = apiConfig.Config.Cert + config.SafeKey = apiConfig.Config.Key + } + targets = append(targets, &Target{ + URL: input, + Options: config, + }) + } + + return targets, nil +} diff --git a/internal/experiment/openvpn/richerinput_test.go b/internal/experiment/openvpn/richerinput_test.go new file mode 100644 index 0000000000..a9b74b1c84 --- /dev/null +++ b/internal/experiment/openvpn/richerinput_test.go @@ -0,0 +1,175 @@ +package openvpn + +import ( + "context" + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/mocks" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/targetloading" +) + +func TestTarget(t *testing.T) { + target := &Target{ + URL: "openvpn://unknown.corp?address=1.1.1.1%3A443&transport=udp", + Options: &Config{ + Auth: "SHA512", + Cipher: "AES-256-GCM", + Provider: "unknown", + SafeKey: "aa", + SafeCert: "bb", + SafeCA: "cc", + }, + } + + t.Run("Category", func(t *testing.T) { + if target.Category() != model.DefaultCategoryCode { + t.Fatal("invalid Category") + } + }) + + t.Run("Country", func(t *testing.T) { + if target.Country() != model.DefaultCountryCode { + t.Fatal("invalid Country") + } + }) + + t.Run("Input", func(t *testing.T) { + if target.Input() != "openvpn://unknown.corp?address=1.1.1.1%3A443&transport=udp" { + t.Fatal("invalid Input") + } + }) + + t.Run("String", func(t *testing.T) { + if target.String() != "openvpn://unknown.corp?address=1.1.1.1%3A443&transport=udp" { + t.Fatal("invalid String") + } + }) +} + +func TestNewLoader(t *testing.T) { + // create the pointers we expect to see + child := &targetloading.Loader{} + options := &Config{} + + // create the loader and cast it to its private type + loader := NewLoader(child, options).(*targetLoader) + + // make sure the loader is okay + if child != loader.loader { + t.Fatal("invalid loader pointer") + } + + // make sure the options are okay + if options != loader.options { + t.Fatal("invalid options pointer") + } +} + +func TestTargetLoaderLoad(t *testing.T) { + // testcase is a test case implemented by this function + type testcase struct { + // name is the test case name + name string + + // options contains the options to use + options *Config + + // loader is the loader to use + loader *targetloading.Loader + + // expectErr is the error we expect + expectErr error + + // expectResults contains the expected results + expectTargets []model.ExperimentTarget + } + + cases := []testcase{ + + { + name: "with options and inputs", + options: &Config{ + SafeCA: "aa", + SafeCert: "bb", + SafeKey: "cc", + Provider: "unknown", + }, + loader: &targetloading.Loader{ + ExperimentName: "openvpn", + InputPolicy: model.InputOrQueryBackend, + Logger: model.DiscardLogger, + Session: &mocks.Session{}, + StaticInputs: []string{ + "openvpn://unknown.corp/1.1.1.1", + }, + }, + expectErr: nil, + expectTargets: []model.ExperimentTarget{ + &Target{ + URL: "openvpn://unknown.corp/1.1.1.1", + Options: &Config{ + Provider: "unknown", + SafeCA: "aa", + SafeCert: "bb", + SafeKey: "cc", + }, + }, + }, + }, + { + name: "with just options", + options: &Config{ + Provider: "riseupvpn", + }, + loader: &targetloading.Loader{ + ExperimentName: "openvpn", + InputPolicy: model.InputOrQueryBackend, + Logger: model.DiscardLogger, + Session: &mocks.Session{}, + StaticInputs: []string{}, + SourceFiles: []string{}, + }, + expectErr: nil, + expectTargets: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // create a target loader using the given config + // + // note that we use a default test input for results predictability + // since the static list may change over time + tl := &targetLoader{ + loader: tc.loader, + options: tc.options, + } + + // load targets + targets, err := tl.Load(context.Background()) + + // make sure error is consistent + switch { + case err == nil && tc.expectErr == nil: + // fallthrough + + case err != nil && tc.expectErr != nil: + if !errors.Is(err, tc.expectErr) { + t.Fatal("unexpected error", err) + } + // fallthrough + + default: + t.Fatal("expected", tc.expectErr, "got", err) + } + + // make sure the targets are consistent + if diff := cmp.Diff(tc.expectTargets, targets); diff != "" { + t.Fatal(diff) + } + }) + } +} diff --git a/internal/registry/openvpn.go b/internal/registry/openvpn.go index 8bf107630c..ea6c6be765 100644 --- a/internal/registry/openvpn.go +++ b/internal/registry/openvpn.go @@ -14,15 +14,14 @@ func init() { AllExperiments[canonicalName] = func() *Factory { return &Factory{ build: func(config interface{}) model.ExperimentMeasurer { - return openvpn.NewExperimentMeasurer( - *config.(*openvpn.Config), "openvpn", - ) + return openvpn.NewExperimentMeasurer() }, canonicalName: canonicalName, config: &openvpn.Config{}, enabledByDefault: true, interruptible: true, inputPolicy: model.InputOrQueryBackend, + newLoader: openvpn.NewLoader, } } } diff --git a/internal/targetloading/targetloading.go b/internal/targetloading/targetloading.go index c6c91162c5..4bf9f31b27 100644 --- a/internal/targetloading/targetloading.go +++ b/internal/targetloading/targetloading.go @@ -10,7 +10,8 @@ import ( "net/url" "github.com/apex/log" - "github.com/ooni/probe-cli/v3/internal/experiment/openvpn" + // FIXME - move this to the experiment + // "github.com/ooni/probe-cli/v3/internal/experiment/openvpn" "github.com/ooni/probe-cli/v3/internal/experimentname" "github.com/ooni/probe-cli/v3/internal/fsx" "github.com/ooni/probe-cli/v3/internal/model" @@ -371,7 +372,8 @@ func (il *Loader) loadRemoteOpenVPN(ctx context.Context) ([]model.ExperimentTarg // The openvpn experiment contains an array of the providers that the API knows about. // We try to get all the remotes from the API for the list of enabled providers. - for _, provider := range openvpn.APIEnabledProviders { + providers := []string{} + for _, provider := range providers { // fetchOpenVPNConfig ultimately uses an internal cache in the session to avoid // hitting the API too many times. reply, err := il.fetchOpenVPNConfig(ctx, provider)