diff --git a/internal/client/connutils.go b/internal/client/connutils.go index ac70c762..c3f901e6 100644 --- a/internal/client/connutils.go +++ b/internal/client/connutils.go @@ -66,6 +66,7 @@ type connConfig struct { waitUntilAvailable time.Duration tlsCAData []byte tlsSecurity string + tlsServerName string serverSettings *snc.ServerSettings secretKey string } @@ -88,6 +89,7 @@ func (c *connConfig) tlsConfig() (*tls.Config, error) { tlsConfig := &tls.Config{ RootCAs: roots, NextProtos: []string{"edgedb-binary"}, + ServerName: c.tlsServerName, } switch c.tlsSecurity { @@ -128,11 +130,11 @@ type configResolver struct { host cfgVal // string port cfgVal // int database cfgVal // string - branch cfgVal // string user cfgVal // string password cfgVal // OptionalStr tlsCAData cfgVal // []byte tlsSecurity cfgVal // string + tlsServerName cfgVal // string waitUntilAvailable cfgVal // time.Duration serverSettings *snc.ServerSettings secretKey cfgVal // string @@ -217,17 +219,6 @@ func (r *configResolver) setDatabase(val, source string) error { return nil } -func (r *configResolver) setBranch(val, source string) error { - if r.branch.val != nil { - return nil - } - if val == "" { - return errors.New(`invalid branch name: ""`) - } - r.branch = cfgVal{val: val, source: source} - return nil -} - func (r *configResolver) setUser(val, source string) error { if r.user.val != nil { return nil @@ -279,6 +270,15 @@ func (r *configResolver) setTLSSecurity(val string, source string) error { return nil } +func (r *configResolver) setTLSServerName(val string, source string) error { + if r.tlsServerName.val != nil { + return nil + } + + r.tlsServerName = cfgVal{val: val, source: source} + return nil +} + func (r *configResolver) setWaitUntilAvailable( val time.Duration, source string, @@ -354,7 +354,7 @@ func (r *configResolver) resolveOptions( } if opts.Branch != "" { - if e := r.setBranch(opts.Branch, "Branch options"); e != nil { + if e := r.setDatabase(opts.Branch, "Branch options"); e != nil { return e } } @@ -424,6 +424,14 @@ func (r *configResolver) resolveOptions( "TLSOptions.SecurityMode option") } + if opts.TLSOptions.ServerName != "" { + secSources = append(secSources, "TLSOptions.ServerName") + err = r.setTLSServerName( + opts.TLSOptions.ServerName, + "TLSOptions.ServerName options", + ) + } + if len(secSources) > 1 { return fmt.Errorf( "mutually exclusive options set in Options: %v", @@ -502,50 +510,24 @@ func (r *configResolver) resolveDSN( "cannot be present at the same time") } - if r.database.val != nil { - return fmt.Errorf( - "`branch` in DSN and %s are mutually exclusive options", - r.database.source, - ) - } - - val, err = popDSNValue(query, db, "branch", r.branch.val == nil) + val, err = popDSNValue(query, db, "branch", r.database.val == nil) if err != nil { return err } else if val.val != nil { br := strings.TrimPrefix(val.val.(string), "/") - if e := r.setBranch(br, source+val.source); e != nil { + if e := r.setDatabase(br, source+val.source); e != nil { return e } } } else { - if r.branch.val != nil { - if queryContains("database", query) { - return fmt.Errorf( - "`database` in DSN and %s are mutually exclusive options", - r.branch.source, - ) - } - - val, err = popDSNValue(query, db, "branch", r.branch.val == nil) - if err != nil { - return err - } else if val.val != nil { - br := strings.TrimPrefix(val.val.(string), "/") - if e := r.setBranch(br, source+val.source); e != nil { - return e - } - } - } else { - val, err = popDSNValue( - query, db, "database", r.database.val == nil) - if err != nil { - return err - } else if val.val != nil { - db := strings.TrimPrefix(val.val.(string), "/") - if e := r.setDatabase(db, source+val.source); e != nil { - return e - } + val, err = popDSNValue( + query, db, "database", r.database.val == nil) + if err != nil { + return err + } else if val.val != nil { + db := strings.TrimPrefix(val.val.(string), "/") + if e := r.setDatabase(db, source+val.source); e != nil { + return e } } } @@ -614,6 +596,22 @@ func (r *configResolver) resolveDSN( } } + val, err = popDSNValue( + query, + "", + "tls_server_name", + r.tlsServerName.val == nil, + ) + if err != nil { + return err + } + if val.val != nil { + err = r.setTLSServerName(val.val.(string), source+val.source) + if err != nil { + return err + } + } + val, err = popDSNValue( query, "", @@ -707,7 +705,7 @@ func (r *configResolver) applyCredentials( } if br, ok := creds.branch.Get(); ok && br != "" { - if e := r.setBranch(br, source); e != nil { + if e := r.setDatabase(br, source); e != nil { return e } } @@ -734,15 +732,21 @@ func (r *configResolver) applyCredentials( } func (r *configResolver) resolveEnvVars(paths *cfgPaths) (bool, error) { - if db, ok := os.LookupEnv("EDGEDB_DATABASE"); ok { + db, dbOk := os.LookupEnv("EDGEDB_DATABASE") + if dbOk { err := r.setDatabase(db, "EDGEDB_DATABASE environment variable") if err != nil { return false, err } } - if db, ok := os.LookupEnv("EDGEDB_BRANCH"); ok { - err := r.setBranch(db, "EDGEDB_BRANCH environment variable") + if branch, ok := os.LookupEnv("EDGEDB_BRANCH"); ok { + if dbOk { + return false, errors.New( + "mutually exclusive options EDGEDB_DATABASE and " + + "EDGEDB_BRANCH environment variables are set") + } + err := r.setDatabase(branch, "EDGEDB_BRANCH environment variable") if err != nil { return false, err } @@ -784,6 +788,16 @@ func (r *configResolver) resolveEnvVars(paths *cfgPaths) (bool, error) { } } + if val, ok := os.LookupEnv("EDGEDB_TLS_SERVER_NAME"); ok { + e := r.setTLSServerName( + val, + "EDGEDB_TLS_SERVER_NAME environment variable", + ) + if e != nil { + return false, e + } + } + if len(tlsCaSources) > 1 { return false, fmt.Errorf( "mutually exclusive environment variables set: %v", @@ -946,18 +960,8 @@ func (r *configResolver) config(opts *Options) (*connConfig, error) { database := "edgedb" branch := "__default__" if r.database.val != nil { - if r.branch.val != nil { - return nil, fmt.Errorf( - "%s and %s are mutually exclusive options", - r.database.source, - r.branch.source, - ) - } database = r.database.val.(string) branch = database - } else if r.branch.val != nil { - branch = r.branch.val.(string) - database = branch } user := "edgedb" @@ -980,6 +984,11 @@ func (r *configResolver) config(opts *Options) (*connConfig, error) { tlsSecurity = r.tlsSecurity.val.(string) } + tlsServerName := "" + if r.tlsServerName.val != nil { + tlsServerName = r.tlsServerName.val.(string) + } + secretKey := "" if r.secretKey.val != nil { secretKey = r.secretKey.val.(string) @@ -1033,6 +1042,7 @@ func (r *configResolver) config(opts *Options) (*connConfig, error) { serverSettings: r.serverSettings, tlsCAData: certData, tlsSecurity: tlsSecurity, + tlsServerName: tlsServerName, secretKey: secretKey, }, nil } @@ -1268,6 +1278,11 @@ var dsnKeyLookup = map[string][]string{ "password": {"password", "password_env", "password_file"}, "tls_ca_file": {"tls_ca_file", "tls_ca_file_env"}, "tls_security": {"tls_security", "tls_security_env", "tls_security_file"}, + "tls_server_name": { + "tls_server_name", + "tls_server_name_env", + "tls_server_name_file", + }, "tls_verify_hostname": { "tls_verify_hostname", "tls_verify_hostname_env", diff --git a/internal/client/connutils_test.go b/internal/client/connutils_test.go index 4cf20f79..89cffaf3 100644 --- a/internal/client/connutils_test.go +++ b/internal/client/connutils_test.go @@ -626,6 +626,8 @@ func TestConnectionParameterResolution(t *testing.T) { options.TLSOptions.CA = getBytes(t, opts, "tlsCA") options.TLSOptions.SecurityMode = TLSSecurityMode( getStr(t, opts, "tlsSecurity")) + options.TLSOptions.ServerName = getStr( + t, opts, "tlsServerName") if opts["serverSettings"] != nil { ss := opts["serverSettings"].(map[string]interface{}) options.ServerSettings = make(map[string][]byte, len(ss)) @@ -673,6 +675,10 @@ func TestConnectionParameterResolution(t *testing.T) { expectedResult.secretKey = key.(string) } + if key := res["tlsServerName"]; key != nil { + expectedResult.tlsServerName = key.(string) + } + ss := res["serverSettings"].(map[string]interface{}) for k, v := range ss { expectedResult.serverSettings.Set(k, []byte(v.(string))) diff --git a/internal/client/credentials.go b/internal/client/credentials.go index b3b72f50..e83a42b1 100644 --- a/internal/client/credentials.go +++ b/internal/client/credentials.go @@ -106,9 +106,11 @@ func validateCredentials(data map[string]interface{}) (*credentials, error) { result.host.Set(h) } - if inMap("database", data) && inMap("branch", data) { + if inMap("database", data) && + inMap("branch", data) && + data["database"] != data["branch"] { return nil, errors.New( - "`database` and `branch` are mutually exclusive") + "`database` and `branch` are both set but do not match") } if database, ok := data["database"]; ok { diff --git a/internal/client/options.go b/internal/client/options.go index 4c82d86b..9b35bb46 100644 --- a/internal/client/options.go +++ b/internal/client/options.go @@ -127,6 +127,8 @@ type TLSOptions struct { CAFile string // Determines how strict we are with TLS checks SecurityMode TLSSecurityMode + // Used to verify the hostname on the returned certificates + ServerName string } // TLSSecurityMode specifies how strict TLS validation is. diff --git a/shared-client-testcases b/shared-client-testcases index 4f45667e..94099c29 160000 --- a/shared-client-testcases +++ b/shared-client-testcases @@ -1 +1 @@ -Subproject commit 4f45667e5fe25bed3b1fe17ee3d93cdf52c1fa77 +Subproject commit 94099c29e0811b0fb1f662b5089fbeaaaa3d4e9f