diff --git a/common/config/config.go b/common/config/config.go index 207bb35d0dc..85b3917cd18 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -81,6 +81,9 @@ type ( PProf struct { // Port is the port on which the PProf will bind to Port int `yaml:"port"` + // Host defaults to `localhost` but can be overriden + // for instance in the case of dual stack IPv4/IPv6 + Host string `yaml:"host"` } // RPC contains the rpc config items @@ -90,9 +93,11 @@ type ( // Port used for membership listener MembershipPort int `yaml:"membershipPort"` // BindOnLocalHost is true if localhost is the bind address + // if neither BindOnLocalHost nor BindOnIP are set then an + // an attempt to discover an address is made BindOnLocalHost bool `yaml:"bindOnLocalHost"` - // BindOnIP can be used to bind service on specific ip (eg. `0.0.0.0`) - - // check net.ParseIP for supported syntax, only IPv4 is supported, + // BindOnIP can be used to bind service on specific ip (eg. `0.0.0.0` or `::`) + // check net.ParseIP for supported syntax // mutually exclusive with `BindOnLocalHost` option BindOnIP string `yaml:"bindOnIP"` // HTTPPort is the port on which HTTP will listen. If unset/0, HTTP will be @@ -225,8 +230,8 @@ type ( // MaxJoinDuration is the max wait time to join the gossip ring MaxJoinDuration time.Duration `yaml:"maxJoinDuration"` // BroadcastAddress is used as the address that is communicated to remote nodes to connect on. - // This is generally used when BindOnIP would be the same across several nodes (ie: 0.0.0.0) - // and for nat traversal scenarios. Check net.ParseIP for supported syntax, only IPv4 is supported. + // This is generally used when BindOnIP would be the same across several nodes (ie: `0.0.0.0` or `::`) + // and for nat traversal scenarios. Check net.ParseIP for supported syntax BroadcastAddress string `yaml:"broadcastAddress"` } diff --git a/common/log/tag/tags.go b/common/log/tag/tags.go index 546487a1b39..95f2f5ca757 100644 --- a/common/log/tag/tags.go +++ b/common/log/tag/tags.go @@ -442,6 +442,11 @@ func IgnoredValue(v interface{}) ZapTag { return NewAnyTag("ignored-value", v) } +// Host returns tag for Host +func Host(h string) ZapTag { + return NewStringTag("host", h) +} + // Port returns tag for Port func Port(p int) ZapTag { return NewInt("port", p) diff --git a/common/membership/ringpop/factory.go b/common/membership/ringpop/factory.go index 7e1b3cdad02..3fef7bbd2dc 100644 --- a/common/membership/ringpop/factory.go +++ b/common/membership/ringpop/factory.go @@ -47,6 +47,7 @@ import ( "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/rpc/encryption" + "go.temporal.io/server/environment" ) const ( @@ -217,7 +218,7 @@ func (factory *factory) getListenIP() net.IP { } if factory.RPCConfig.BindOnLocalHost { - return net.IPv4(127, 0, 0, 1) + return net.ParseIP(environment.GetLocalhostIP()) } if len(factory.RPCConfig.BindOnIP) > 0 { @@ -229,6 +230,7 @@ func (factory *factory) getListenIP() net.IP { factory.Logger.Fatal("ListenIP failed, unable to parse bindOnIP value", tag.Address(factory.RPCConfig.BindOnIP)) return nil } + ip, err := config.ListenIP() if err != nil { factory.Logger.Fatal("ListenIP failed", tag.Error(err)) diff --git a/common/persistence/persistence-tests/persistence_test_base.go b/common/persistence/persistence-tests/persistence_test_base.go index bc60ee9ec96..68746155d44 100644 --- a/common/persistence/persistence-tests/persistence_test_base.go +++ b/common/persistence/persistence-tests/persistence_test_base.go @@ -155,7 +155,7 @@ func NewTestBaseWithSQL(options *TestBaseOptions) TestBase { case postgresql.PluginName: options.DBHost = environment.GetPostgreSQLAddress() case sqlite.PluginName: - options.DBHost = environment.Localhost + options.DBHost = environment.GetLocalhostIP() default: panic(fmt.Sprintf("unknown sql store driver: %v", options.SQLDBPluginName)) } diff --git a/common/persistence/persistence-tests/setup.go b/common/persistence/persistence-tests/setup.go index 960f3687448..9c14808417c 100644 --- a/common/persistence/persistence-tests/setup.go +++ b/common/persistence/persistence-tests/setup.go @@ -108,7 +108,7 @@ func GetSQLiteFileTestClusterOption() *TestBaseOptions { SQLDBPluginName: sqlite.PluginName, DBUsername: testSQLiteUser, DBPassword: testSQLitePassword, - DBHost: environment.Localhost, + DBHost: environment.GetLocalhostIP(), DBPort: 0, SchemaDir: testSQLiteSchemaDir, StoreType: config.StoreTypeSQL, @@ -122,7 +122,7 @@ func GetSQLiteMemoryTestClusterOption() *TestBaseOptions { SQLDBPluginName: sqlite.PluginName, DBUsername: testSQLiteUser, DBPassword: testSQLitePassword, - DBHost: environment.Localhost, + DBHost: environment.GetLocalhostIP(), DBPort: 0, SchemaDir: "", StoreType: config.StoreTypeSQL, diff --git a/common/persistence/tests/sqlite_test.go b/common/persistence/tests/sqlite_test.go index 9ce5ae6a3fe..acbca0817cb 100644 --- a/common/persistence/tests/sqlite_test.go +++ b/common/persistence/tests/sqlite_test.go @@ -58,7 +58,7 @@ func NewSQLiteMemoryConfig() *config.SQL { return &config.SQL{ User: "", Password: "", - ConnectAddr: environment.Localhost, + ConnectAddr: environment.GetLocalhostIP(), ConnectProtocol: "tcp", PluginName: "sqlite", DatabaseName: "default", @@ -71,7 +71,7 @@ func NewSQLiteFileConfig() *config.SQL { return &config.SQL{ User: "", Password: "", - ConnectAddr: environment.Localhost, + ConnectAddr: environment.GetLocalhostIP(), ConnectProtocol: "tcp", PluginName: "sqlite", DatabaseName: "test_" + persistencetests.GenerateRandomDBName(3), diff --git a/common/pprof/pprof.go b/common/pprof/pprof.go index 960692912d4..4d31ea6e49d 100644 --- a/common/pprof/pprof.go +++ b/common/pprof/pprof.go @@ -26,6 +26,7 @@ package pprof import ( "fmt" + "net" "net/http" _ "net/http/pprof" // DO NOT REMOVE THE LINE "sync/atomic" @@ -67,11 +68,19 @@ func (initializer *PProfInitializerImpl) Start() error { initializer.Logger.Info("PProf not started due to port not set") return nil } + host := initializer.PProf.Host + if host == "" { + // default to localhost which will favor ipv4 on dual stack + // environments - configure host as `::1` to bind on ipv6 localhost + host = "localhost" + } + + hostPort := net.JoinHostPort(host, fmt.Sprint(port)) if atomic.CompareAndSwapInt32(&pprofStatus, pprofNotInitialized, pprofInitialized) { go func() { - initializer.Logger.Info("PProf listen on ", tag.Port(port)) - err := http.ListenAndServe(fmt.Sprintf("localhost:%d", port), nil) + initializer.Logger.Info("PProf listen on ", tag.Host(host), tag.Port(port)) + err := http.ListenAndServe(hostPort, nil) if err != nil { initializer.Logger.Error("listen and serve err", tag.Error(err)) } diff --git a/common/rpc/rpc.go b/common/rpc/rpc.go index abb6b96b071..6bb7c6a1383 100644 --- a/common/rpc/rpc.go +++ b/common/rpc/rpc.go @@ -39,6 +39,7 @@ import ( "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/rpc/encryption" + "go.temporal.io/server/environment" ) var _ common.RPCFactory = (*RPCFactory)(nil) @@ -162,7 +163,7 @@ func getListenIP(cfg *config.RPC, logger log.Logger) net.IP { } if cfg.BindOnLocalHost { - return net.IPv4(127, 0, 0, 1) + return net.ParseIP(environment.GetLocalhostIP()) } if len(cfg.BindOnIP) > 0 { diff --git a/environment/env.go b/environment/env.go index 740dd9f2002..16cd26ce809 100644 --- a/environment/env.go +++ b/environment/env.go @@ -26,13 +26,16 @@ package environment import ( "fmt" + "net" "os" "strconv" ) const ( - // Localhost default localhost - Localhost = "127.0.0.1" + // LocalhostIP default localhost + LocalhostIP = "LOCALHOST_IP" + // Localhost default hostname + LocalhostIPDefault = "127.0.0.1" // CassandraSeeds env CassandraSeeds = "CASSANDRA_SEEDS" @@ -69,8 +72,15 @@ const ( // SetupEnv setup the necessary env func SetupEnv() { + if os.Getenv(LocalhostIP) == "" { + err := os.Setenv(LocalhostIP, lookupLocalhostIP("localhost")) + if err != nil { + panic(fmt.Sprintf("error setting env %v", LocalhostIP)) + } + } + if os.Getenv(CassandraSeeds) == "" { - err := os.Setenv(CassandraSeeds, Localhost) + err := os.Setenv(CassandraSeeds, LocalhostIP) if err != nil { panic(fmt.Sprintf("error setting env %v", CassandraSeeds)) } @@ -84,7 +94,7 @@ func SetupEnv() { } if os.Getenv(MySQLSeeds) == "" { - err := os.Setenv(MySQLSeeds, Localhost) + err := os.Setenv(MySQLSeeds, LocalhostIP) if err != nil { panic(fmt.Sprintf("error setting env %v", MySQLSeeds)) } @@ -98,7 +108,7 @@ func SetupEnv() { } if os.Getenv(PostgresSeeds) == "" { - err := os.Setenv(PostgresSeeds, Localhost) + err := os.Setenv(PostgresSeeds, LocalhostIP) if err != nil { panic(fmt.Sprintf("error setting env %v", PostgresSeeds)) } @@ -112,7 +122,7 @@ func SetupEnv() { } if os.Getenv(ESSeeds) == "" { - err := os.Setenv(ESSeeds, Localhost) + err := os.Setenv(ESSeeds, LocalhostIP) if err != nil { panic(fmt.Sprintf("error setting env %v", ESSeeds)) } @@ -133,11 +143,41 @@ func SetupEnv() { } } +func lookupLocalhostIP(domain string) string { + // lookup localhost and favor the first ipv4 address + // unless there are only ipv6 addresses available + ips, err := net.LookupIP(domain) + if err != nil || len(ips) == 0 { + // fallback to default instead of error + return LocalhostIPDefault + } + var listenIp net.IP + for _, ip := range ips { + listenIp = ip + if listenIp.To4() != nil { + break + } + } + return listenIp.String() +} + +// GetLocalhostIP returns the ip address of the localhost domain +func GetLocalhostIP() string { + localhostIP := os.Getenv(LocalhostIP) + ip := net.ParseIP(localhostIP) + if ip != nil { + // if localhost is an ip return it + return ip.String() + } + // otherwise, ignore the value and lookup `localhost` + return lookupLocalhostIP("localhost") +} + // GetCassandraAddress return the cassandra address func GetCassandraAddress() string { addr := os.Getenv(CassandraSeeds) if addr == "" { - addr = Localhost + addr = GetLocalhostIP() } return addr } @@ -159,7 +199,7 @@ func GetCassandraPort() int { func GetMySQLAddress() string { addr := os.Getenv(MySQLSeeds) if addr == "" { - addr = Localhost + addr = GetLocalhostIP() } return addr } @@ -181,7 +221,7 @@ func GetMySQLPort() int { func GetPostgreSQLAddress() string { addr := os.Getenv(PostgresSeeds) if addr == "" { - addr = Localhost + addr = GetLocalhostIP() } return addr } diff --git a/environment/env_test.go b/environment/env_test.go new file mode 100644 index 00000000000..81aa3b781a4 --- /dev/null +++ b/environment/env_test.go @@ -0,0 +1,61 @@ +// The MIT License +// +// Copyright (c) 2023 Temporal Technologies Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package environment + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLookupLocalhostIPSuccess(t *testing.T) { + t.Parallel() + a := assert.New(t) + ipString := lookupLocalhostIP("localhost") + ip := net.ParseIP(ipString) + // localhost needs to resolve to a loopback address + // whether it's ipv4 or ipv6 - the result depends on + // the system running this test + a.True(ip.IsLoopback()) +} + +func TestLookupLocalhostIPMissingHostname(t *testing.T) { + t.Parallel() + a := assert.New(t) + ipString := lookupLocalhostIP("") + ip := net.ParseIP(ipString) + a.True(ip.IsLoopback()) + // if host can't be found, use ipv4 loopback + a.Equal(ip.String(), LocalhostIPDefault) +} + +func TestLookupLocalhostIPWithIPv6(t *testing.T) { + t.Parallel() + a := assert.New(t) + ipString := lookupLocalhostIP("::1") + ip := net.ParseIP(ipString) + a.True(ip.IsLoopback()) + // return ipv6 if only ipv6 is available + a.Equal(ip, net.ParseIP("::1")) +} diff --git a/tests/testutils/certificate.go b/tests/testutils/certificate.go index 1208959721b..9c1d8bc9f66 100644 --- a/tests/testutils/certificate.go +++ b/tests/testutils/certificate.go @@ -57,11 +57,12 @@ func generateSelfSignedX509CA(commonName string, extUsage []x509.ExtKeyUsage, ke x509.KeyUsageDigitalSignature, } - if ip := net.ParseIP(commonName).To4(); ip != nil { - template.IPAddresses = []net.IP{ip} - + if ip := net.ParseIP(commonName); ip != nil { if ip.IsLoopback() { + template.IPAddresses = []net.IP{net.IPv6loopback, net.IPv4(127, 0, 0, 1)} template.DNSNames = []string{"localhost"} + } else { + template.IPAddresses = []net.IP{ip} } } else { template.DNSNames = []string{commonName} @@ -111,11 +112,12 @@ func generateServerX509UsingCAAndSerialNumber(commonName string, serialNumber in KeyUsage: x509.KeyUsageDigitalSignature, } - if ip := net.ParseIP(commonName).To4(); ip != nil { - template.IPAddresses = []net.IP{ip} - + if ip := net.ParseIP(commonName); ip != nil { if ip.IsLoopback() { + template.IPAddresses = []net.IP{net.IPv6loopback, net.IPv4(127, 0, 0, 1)} template.DNSNames = []string{"localhost"} + } else { + template.IPAddresses = []net.IP{ip} } } else { template.DNSNames = []string{commonName}