Skip to content

Commit

Permalink
Improve ipv6 support (temporalio#4766)
Browse files Browse the repository at this point in the history
Temporal works with ipv6, but there are some clarifications in documentation and additional settings that could make support cleaner. Specifically:

* specify pprof host to override localhost default
* use net.JoinHostPort instead of fmt.Sprintf
* lookup `localhost` in various code paths
* add a test for ringpop using ipv6
* fixes comments to remove claim of ipv4 only

tested via added unit tests and local verification
  • Loading branch information
underrun authored Sep 7, 2023
1 parent 0d343ce commit e62c22f
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 28 deletions.
13 changes: 9 additions & 4 deletions common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ type (
PProf struct {
// Port is the port on which the PProf will bind to
Port int `yaml:"port"`
// Host defaults to `localhost` but can be overriden
// for instance in the case of dual stack IPv4/IPv6
Host string `yaml:"host"`
}

// RPC contains the rpc config items
Expand All @@ -90,9 +93,11 @@ type (
// Port used for membership listener
MembershipPort int `yaml:"membershipPort"`
// BindOnLocalHost is true if localhost is the bind address
// if neither BindOnLocalHost nor BindOnIP are set then an
// an attempt to discover an address is made
BindOnLocalHost bool `yaml:"bindOnLocalHost"`
// BindOnIP can be used to bind service on specific ip (eg. `0.0.0.0`) -
// check net.ParseIP for supported syntax, only IPv4 is supported,
// BindOnIP can be used to bind service on specific ip (eg. `0.0.0.0` or `::`)
// check net.ParseIP for supported syntax
// mutually exclusive with `BindOnLocalHost` option
BindOnIP string `yaml:"bindOnIP"`
// HTTPPort is the port on which HTTP will listen. If unset/0, HTTP will be
Expand Down Expand Up @@ -225,8 +230,8 @@ type (
// MaxJoinDuration is the max wait time to join the gossip ring
MaxJoinDuration time.Duration `yaml:"maxJoinDuration"`
// BroadcastAddress is used as the address that is communicated to remote nodes to connect on.
// This is generally used when BindOnIP would be the same across several nodes (ie: 0.0.0.0)
// and for nat traversal scenarios. Check net.ParseIP for supported syntax, only IPv4 is supported.
// This is generally used when BindOnIP would be the same across several nodes (ie: `0.0.0.0` or `::`)
// and for nat traversal scenarios. Check net.ParseIP for supported syntax
BroadcastAddress string `yaml:"broadcastAddress"`
}

Expand Down
5 changes: 5 additions & 0 deletions common/log/tag/tags.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,11 @@ func IgnoredValue(v interface{}) ZapTag {
return NewAnyTag("ignored-value", v)
}

// Host returns tag for Host
func Host(h string) ZapTag {
return NewStringTag("host", h)
}

// Port returns tag for Port
func Port(p int) ZapTag {
return NewInt("port", p)
Expand Down
4 changes: 3 additions & 1 deletion common/membership/ringpop/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/primitives"
"go.temporal.io/server/common/rpc/encryption"
"go.temporal.io/server/environment"
)

const (
Expand Down Expand Up @@ -217,7 +218,7 @@ func (factory *factory) getListenIP() net.IP {
}

if factory.RPCConfig.BindOnLocalHost {
return net.IPv4(127, 0, 0, 1)
return net.ParseIP(environment.GetLocalhostIP())
}

if len(factory.RPCConfig.BindOnIP) > 0 {
Expand All @@ -229,6 +230,7 @@ func (factory *factory) getListenIP() net.IP {
factory.Logger.Fatal("ListenIP failed, unable to parse bindOnIP value", tag.Address(factory.RPCConfig.BindOnIP))
return nil
}

ip, err := config.ListenIP()
if err != nil {
factory.Logger.Fatal("ListenIP failed", tag.Error(err))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func NewTestBaseWithSQL(options *TestBaseOptions) TestBase {
case postgresql.PluginName:
options.DBHost = environment.GetPostgreSQLAddress()
case sqlite.PluginName:
options.DBHost = environment.Localhost
options.DBHost = environment.GetLocalhostIP()
default:
panic(fmt.Sprintf("unknown sql store driver: %v", options.SQLDBPluginName))
}
Expand Down
4 changes: 2 additions & 2 deletions common/persistence/persistence-tests/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func GetSQLiteFileTestClusterOption() *TestBaseOptions {
SQLDBPluginName: sqlite.PluginName,
DBUsername: testSQLiteUser,
DBPassword: testSQLitePassword,
DBHost: environment.Localhost,
DBHost: environment.GetLocalhostIP(),
DBPort: 0,
SchemaDir: testSQLiteSchemaDir,
StoreType: config.StoreTypeSQL,
Expand All @@ -122,7 +122,7 @@ func GetSQLiteMemoryTestClusterOption() *TestBaseOptions {
SQLDBPluginName: sqlite.PluginName,
DBUsername: testSQLiteUser,
DBPassword: testSQLitePassword,
DBHost: environment.Localhost,
DBHost: environment.GetLocalhostIP(),
DBPort: 0,
SchemaDir: "",
StoreType: config.StoreTypeSQL,
Expand Down
4 changes: 2 additions & 2 deletions common/persistence/tests/sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func NewSQLiteMemoryConfig() *config.SQL {
return &config.SQL{
User: "",
Password: "",
ConnectAddr: environment.Localhost,
ConnectAddr: environment.GetLocalhostIP(),
ConnectProtocol: "tcp",
PluginName: "sqlite",
DatabaseName: "default",
Expand All @@ -71,7 +71,7 @@ func NewSQLiteFileConfig() *config.SQL {
return &config.SQL{
User: "",
Password: "",
ConnectAddr: environment.Localhost,
ConnectAddr: environment.GetLocalhostIP(),
ConnectProtocol: "tcp",
PluginName: "sqlite",
DatabaseName: "test_" + persistencetests.GenerateRandomDBName(3),
Expand Down
13 changes: 11 additions & 2 deletions common/pprof/pprof.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ package pprof

import (
"fmt"
"net"
"net/http"
_ "net/http/pprof" // DO NOT REMOVE THE LINE
"sync/atomic"
Expand Down Expand Up @@ -67,11 +68,19 @@ func (initializer *PProfInitializerImpl) Start() error {
initializer.Logger.Info("PProf not started due to port not set")
return nil
}
host := initializer.PProf.Host
if host == "" {
// default to localhost which will favor ipv4 on dual stack
// environments - configure host as `::1` to bind on ipv6 localhost
host = "localhost"
}

hostPort := net.JoinHostPort(host, fmt.Sprint(port))

if atomic.CompareAndSwapInt32(&pprofStatus, pprofNotInitialized, pprofInitialized) {
go func() {
initializer.Logger.Info("PProf listen on ", tag.Port(port))
err := http.ListenAndServe(fmt.Sprintf("localhost:%d", port), nil)
initializer.Logger.Info("PProf listen on ", tag.Host(host), tag.Port(port))
err := http.ListenAndServe(hostPort, nil)
if err != nil {
initializer.Logger.Error("listen and serve err", tag.Error(err))
}
Expand Down
3 changes: 2 additions & 1 deletion common/rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"go.temporal.io/server/common/log/tag"
"go.temporal.io/server/common/primitives"
"go.temporal.io/server/common/rpc/encryption"
"go.temporal.io/server/environment"
)

var _ common.RPCFactory = (*RPCFactory)(nil)
Expand Down Expand Up @@ -162,7 +163,7 @@ func getListenIP(cfg *config.RPC, logger log.Logger) net.IP {
}

if cfg.BindOnLocalHost {
return net.IPv4(127, 0, 0, 1)
return net.ParseIP(environment.GetLocalhostIP())
}

if len(cfg.BindOnIP) > 0 {
Expand Down
58 changes: 49 additions & 9 deletions environment/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@ package environment

import (
"fmt"
"net"
"os"
"strconv"
)

const (
// Localhost default localhost
Localhost = "127.0.0.1"
// LocalhostIP default localhost
LocalhostIP = "LOCALHOST_IP"
// Localhost default hostname
LocalhostIPDefault = "127.0.0.1"

// CassandraSeeds env
CassandraSeeds = "CASSANDRA_SEEDS"
Expand Down Expand Up @@ -69,8 +72,15 @@ const (

// SetupEnv setup the necessary env
func SetupEnv() {
if os.Getenv(LocalhostIP) == "" {
err := os.Setenv(LocalhostIP, lookupLocalhostIP("localhost"))
if err != nil {
panic(fmt.Sprintf("error setting env %v", LocalhostIP))
}
}

if os.Getenv(CassandraSeeds) == "" {
err := os.Setenv(CassandraSeeds, Localhost)
err := os.Setenv(CassandraSeeds, LocalhostIP)
if err != nil {
panic(fmt.Sprintf("error setting env %v", CassandraSeeds))
}
Expand All @@ -84,7 +94,7 @@ func SetupEnv() {
}

if os.Getenv(MySQLSeeds) == "" {
err := os.Setenv(MySQLSeeds, Localhost)
err := os.Setenv(MySQLSeeds, LocalhostIP)
if err != nil {
panic(fmt.Sprintf("error setting env %v", MySQLSeeds))
}
Expand All @@ -98,7 +108,7 @@ func SetupEnv() {
}

if os.Getenv(PostgresSeeds) == "" {
err := os.Setenv(PostgresSeeds, Localhost)
err := os.Setenv(PostgresSeeds, LocalhostIP)
if err != nil {
panic(fmt.Sprintf("error setting env %v", PostgresSeeds))
}
Expand All @@ -112,7 +122,7 @@ func SetupEnv() {
}

if os.Getenv(ESSeeds) == "" {
err := os.Setenv(ESSeeds, Localhost)
err := os.Setenv(ESSeeds, LocalhostIP)
if err != nil {
panic(fmt.Sprintf("error setting env %v", ESSeeds))
}
Expand All @@ -133,11 +143,41 @@ func SetupEnv() {
}
}

func lookupLocalhostIP(domain string) string {
// lookup localhost and favor the first ipv4 address
// unless there are only ipv6 addresses available
ips, err := net.LookupIP(domain)
if err != nil || len(ips) == 0 {
// fallback to default instead of error
return LocalhostIPDefault
}
var listenIp net.IP
for _, ip := range ips {
listenIp = ip
if listenIp.To4() != nil {
break
}
}
return listenIp.String()
}

// GetLocalhostIP returns the ip address of the localhost domain
func GetLocalhostIP() string {
localhostIP := os.Getenv(LocalhostIP)
ip := net.ParseIP(localhostIP)
if ip != nil {
// if localhost is an ip return it
return ip.String()
}
// otherwise, ignore the value and lookup `localhost`
return lookupLocalhostIP("localhost")
}

// GetCassandraAddress return the cassandra address
func GetCassandraAddress() string {
addr := os.Getenv(CassandraSeeds)
if addr == "" {
addr = Localhost
addr = GetLocalhostIP()
}
return addr
}
Expand All @@ -159,7 +199,7 @@ func GetCassandraPort() int {
func GetMySQLAddress() string {
addr := os.Getenv(MySQLSeeds)
if addr == "" {
addr = Localhost
addr = GetLocalhostIP()
}
return addr
}
Expand All @@ -181,7 +221,7 @@ func GetMySQLPort() int {
func GetPostgreSQLAddress() string {
addr := os.Getenv(PostgresSeeds)
if addr == "" {
addr = Localhost
addr = GetLocalhostIP()
}
return addr
}
Expand Down
61 changes: 61 additions & 0 deletions environment/env_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// The MIT License
//
// Copyright (c) 2023 Temporal Technologies Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package environment

import (
"net"
"testing"

"github.com/stretchr/testify/assert"
)

func TestLookupLocalhostIPSuccess(t *testing.T) {
t.Parallel()
a := assert.New(t)
ipString := lookupLocalhostIP("localhost")
ip := net.ParseIP(ipString)
// localhost needs to resolve to a loopback address
// whether it's ipv4 or ipv6 - the result depends on
// the system running this test
a.True(ip.IsLoopback())
}

func TestLookupLocalhostIPMissingHostname(t *testing.T) {
t.Parallel()
a := assert.New(t)
ipString := lookupLocalhostIP("")
ip := net.ParseIP(ipString)
a.True(ip.IsLoopback())
// if host can't be found, use ipv4 loopback
a.Equal(ip.String(), LocalhostIPDefault)
}

func TestLookupLocalhostIPWithIPv6(t *testing.T) {
t.Parallel()
a := assert.New(t)
ipString := lookupLocalhostIP("::1")
ip := net.ParseIP(ipString)
a.True(ip.IsLoopback())
// return ipv6 if only ipv6 is available
a.Equal(ip, net.ParseIP("::1"))
}
14 changes: 8 additions & 6 deletions tests/testutils/certificate.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ func generateSelfSignedX509CA(commonName string, extUsage []x509.ExtKeyUsage, ke
x509.KeyUsageDigitalSignature,
}

if ip := net.ParseIP(commonName).To4(); ip != nil {
template.IPAddresses = []net.IP{ip}

if ip := net.ParseIP(commonName); ip != nil {
if ip.IsLoopback() {
template.IPAddresses = []net.IP{net.IPv6loopback, net.IPv4(127, 0, 0, 1)}
template.DNSNames = []string{"localhost"}
} else {
template.IPAddresses = []net.IP{ip}
}
} else {
template.DNSNames = []string{commonName}
Expand Down Expand Up @@ -111,11 +112,12 @@ func generateServerX509UsingCAAndSerialNumber(commonName string, serialNumber in
KeyUsage: x509.KeyUsageDigitalSignature,
}

if ip := net.ParseIP(commonName).To4(); ip != nil {
template.IPAddresses = []net.IP{ip}

if ip := net.ParseIP(commonName); ip != nil {
if ip.IsLoopback() {
template.IPAddresses = []net.IP{net.IPv6loopback, net.IPv4(127, 0, 0, 1)}
template.DNSNames = []string{"localhost"}
} else {
template.IPAddresses = []net.IP{ip}
}
} else {
template.DNSNames = []string{commonName}
Expand Down

0 comments on commit e62c22f

Please sign in to comment.