From 6988c2c71931739bba5f5bb53fcfda47f1a704a7 Mon Sep 17 00:00:00 2001 From: Daniel Mickens Date: Wed, 2 Mar 2022 09:04:58 -0500 Subject: [PATCH] Restore DNS Round Robin Behavior (#134) --- connection.go | 42 ++++++++++++++++++++++++++++++++---------- driver_test.go | 1 + 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/connection.go b/connection.go index a721d82..8ad2484 100644 --- a/connection.go +++ b/connection.go @@ -40,6 +40,7 @@ import ( "database/sql/driver" "encoding/binary" "fmt" + "math/rand" "net" "net/url" "os" @@ -277,18 +278,39 @@ func (v *connection) establishSocketConnection() (net.Conn, error) { // Failover: loop to try all hosts in the list err_msg := "" for i := 0; i < len(v.connHostsList); i++ { - // net.Dial will resolve the host to multiple IP addresses, - // and try each IP address in order until one succeeds. - conn, err := net.Dial("tcp", v.connHostsList[i]) + host, port, err := net.SplitHostPort(v.connHostsList[i]) if err != nil { - err_msg += fmt.Sprintf("\n '%s': %s", v.connHostsList[i], err.Error()) - } else { - if len(err_msg) != 0 { - connectionLogger.Debug("Failed to establish a connection to %s", err_msg) + // no host-port pair identified + err_msg += fmt.Sprintf("\n '%s': %s", host, err.Error()) + continue + } + ips, err := net.LookupIP(host) + if err != nil { + // failed to resolve any IPs from host + err_msg += fmt.Sprintf("\n '%s': %s", host, err.Error()) + continue + } + r := rand.New(rand.NewSource(time.Now().Unix())) + for _, j := range r.Perm(len(ips)) { + // j comes from random permutation of indexes - ips[j] will access a random resolved ip + ip := net.IP.String(ips[j]) + if strings.HasPrefix(ip, "::") { + //handle IPV6 shorthand + ip = "[" + ip + "]" } - connectionLogger.Debug("Established socket connection to %s", v.connHostsList[i]) - v.connHostsList = v.connHostsList[i:] - return conn, err + addrString := ip + ":" + string(port) + conn, err := net.Dial("tcp", addrString) + + if err != nil { + err_msg += fmt.Sprintf("\n '%s': %s", v.connHostsList[i], err.Error()) + } else { + if len(err_msg) != 0 { + connectionLogger.Debug("Failed to establish a connection to %s", err_msg) + } + connectionLogger.Debug("Established socket connection to %s", addrString) + v.connHostsList = v.connHostsList[i:] + return conn, err + } } } // All of the hosts failed diff --git a/driver_test.go b/driver_test.go index ad1f062..59fc0cf 100644 --- a/driver_test.go +++ b/driver_test.go @@ -60,6 +60,7 @@ var ( ctx context.Context ) +// The following assert functions compensate for Go having no native assertions func assertTrue(t *testing.T, v bool) { t.Helper()