diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index 224ff1d6de..a6e28f9610 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -307,15 +307,11 @@ func (p *parser) parse(original string) error { } // If p.SRVMaxHosts is non-zero and is less than the number of hosts, randomly - // select SRVMaxHosts hosts from parsedHosts using the modern Fisher-Yates - // algorithm. - if p.SRVMaxHosts != 0 && p.SRVMaxHosts < len(parsedHosts) { - // TODO(GODRIVER-1876): Use rand#Shuffle after dropping Go 1.9 support. - n := len(parsedHosts) - for i := 0; i < n-1; i++ { - j := i + random.Intn(n-i) - parsedHosts[j], parsedHosts[i] = parsedHosts[i], parsedHosts[j] - } + // select SRVMaxHosts hosts from parsedHosts. + if p.SRVMaxHosts > 0 && p.SRVMaxHosts < len(parsedHosts) { + random.Shuffle(len(parsedHosts), func(i, j int) { + parsedHosts[i], parsedHosts[j] = parsedHosts[j], parsedHosts[i] + }) parsedHosts = parsedHosts[:p.SRVMaxHosts] } } diff --git a/x/mongo/driver/topology/polling_srv_records_test.go b/x/mongo/driver/topology/polling_srv_records_test.go index 5d1e934374..b35dfb60f9 100644 --- a/x/mongo/driver/topology/polling_srv_records_test.go +++ b/x/mongo/driver/topology/polling_srv_records_test.go @@ -371,9 +371,6 @@ func TestPollSRVRecordsMaxHosts(t *testing.T) { compareHosts(t, actualHosts, expectedHosts) }) t.Run("SRVMaxHosts is less than number of hosts", func(t *testing.T) { - // TODO: Enable with GODRIVER-2222. - t.Skipf("TODO: Enable with GODRIVER-2222") - recordsToAdd := []*net.SRV{{"localhost.test.build.10gen.cc.", 27019, 0, 0}, {"localhost.test.build.10gen.cc.", 27020, 0, 0}} recordsToRemove := []*net.SRV{{"localhost.test.build.10gen.cc.", 27018, 0, 0}} topo, disconnect := simulateSRVPoll(2, recordsToAdd, recordsToRemove) diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go index 23c1d2479e..0f3ccdfd3d 100644 --- a/x/mongo/driver/topology/topology.go +++ b/x/mongo/driver/topology/topology.go @@ -547,19 +547,6 @@ func (t *Topology) pollSRVRecords() { t.pollHeartbeatTime.Store(false) } - // If t.cfg.srvMaxHosts is non-zero and is less than the number of hosts, randomly - // select srvMaxHosts hosts from parsedHosts using the modern Fisher-Yates - // algorithm. - if t.cfg.srvMaxHosts != 0 && t.cfg.srvMaxHosts < len(parsedHosts) { - // TODO(GODRIVER-1876): Use rand#Shuffle after dropping Go 1.9 support. - n := len(parsedHosts) - for i := 0; i < n-1; i++ { - j := i + random.Intn(n-i) - parsedHosts[j], parsedHosts[i] = parsedHosts[i], parsedHosts[j] - } - parsedHosts = parsedHosts[:t.cfg.srvMaxHosts] - } - cont := t.processSRVResults(parsedHosts) if !cont { break @@ -598,11 +585,25 @@ func (t *Topology) processSRVResults(parsedHosts []string) bool { t.fsm.removeServerByAddr(addr) t.publishServerClosedEvent(s.address) } + + // Now that we've removed all the hosts that disappeared from the SRV record, we need to add any + // new hosts added to the SRV record. If adding all of the new hosts would increase the number + // of servers past srvMaxHosts, shuffle the list of added hosts. + if t.cfg.srvMaxHosts > 0 && len(t.servers)+len(diff.Added) > t.cfg.srvMaxHosts { + random.Shuffle(len(diff.Added), func(i, j int) { + diff.Added[i], diff.Added[j] = diff.Added[j], diff.Added[i] + }) + } + // Add all added hosts until the number of servers reaches srvMaxHosts. for _, a := range diff.Added { + if t.cfg.srvMaxHosts > 0 && len(t.servers) >= t.cfg.srvMaxHosts { + break + } addr := address.Address(a).Canonicalize() _ = t.addServer(addr) t.fsm.addServer(addr) } + //store new description newDesc := description.Topology{ Kind: t.fsm.Kind,