diff --git a/transport/cluster.go b/transport/cluster.go index e95511ff..9bd26aed 100644 --- a/transport/cluster.go +++ b/transport/cluster.go @@ -8,12 +8,11 @@ import ( "sort" "strconv" "strings" + "sync/atomic" "time" "github.com/scylladb/scylla-go-driver/frame" . "github.com/scylladb/scylla-go-driver/frame/response" - - "go.uber.org/atomic" ) type ( @@ -26,7 +25,7 @@ type ( ) type Cluster struct { - topology atomic.Value // *topology + topology atomic.Pointer[topology] control *Conn cfg ConnConfig handledEvents []frame.EventType // This will probably be moved to config. @@ -121,7 +120,7 @@ func (c *Cluster) NewTokenAwareQueryInfo(t Token, ks string) (QueryInfo, error) // TODO overflow and negative modulo. func (c *Cluster) generateOffset() uint64 { - return c.queryInfoCounter.Inc() - 1 + return c.queryInfoCounter.Add(1) - 1 } // NewCluster also creates control connection and starts handling events and refreshing topology. @@ -443,7 +442,7 @@ func parseTokensFromRow(n *Node, r frame.Row, ring *Ring) error { } func (c *Cluster) Topology() *topology { - return c.topology.Load().(*topology) + return c.topology.Load() } func (c *Cluster) setTopology(t *topology) { diff --git a/transport/cluster_integration_test.go b/transport/cluster_integration_test.go index d8f29c13..3a752e2a 100644 --- a/transport/cluster_integration_test.go +++ b/transport/cluster_integration_test.go @@ -80,7 +80,7 @@ func TestClusterIntegration(t *testing.T) { } // There should be at least system keyspaces present. - if len(c.topology.Load().(*topology).keyspaces) == 0 { + if len(c.Topology().keyspaces) == 0 { t.Fatalf("Keyspaces failed to load") } diff --git a/transport/export_test.go b/transport/export_test.go index 84a4a911..29bfc8ff 100644 --- a/transport/export_test.go +++ b/transport/export_test.go @@ -3,7 +3,7 @@ package transport func (p *ConnPool) AllConns() []*Conn { var conns = make([]*Conn, len(p.conns)) for i, v := range p.conns { - conns[i], _ = v.Load().(*Conn) + conns[i] = v.Load() } return conns } diff --git a/transport/pool.go b/transport/pool.go index b41a21d4..ffa1ada8 100644 --- a/transport/pool.go +++ b/transport/pool.go @@ -6,11 +6,10 @@ import ( "log" "math" "net" + "sync/atomic" "time" . "github.com/scylladb/scylla-go-driver/frame/response" - - "go.uber.org/atomic" ) const poolCloseShard = -1 @@ -19,7 +18,7 @@ type ConnPool struct { host string nrShards int msbIgnore uint8 - conns []atomic.Value + conns []atomic.Pointer[Conn] connClosedCh chan int // notification channel for when connection is closed connObs ConnObserver } @@ -99,13 +98,11 @@ func (p *ConnPool) storeConn(conn *Conn) { } func (p *ConnPool) loadConn(shard int) *Conn { - conn, _ := p.conns[shard].Load().(*Conn) - return conn + return p.conns[shard].Load() } func (p *ConnPool) clearConn(shard int) bool { - conn, _ := p.conns[shard].Swap((*Conn)(nil)).(*Conn) - return conn != nil + return p.conns[shard].Swap(nil) != nil } func (p *ConnPool) Close() { @@ -115,7 +112,7 @@ func (p *ConnPool) Close() { // closeAll is called by PoolRefiller. func (p *ConnPool) closeAll() { for i := range p.conns { - if conn, ok := p.conns[i].Swap((*Conn)(nil)).(*Conn); ok { + if conn := p.conns[i].Swap(nil); conn != nil { conn.Close() } } @@ -168,7 +165,7 @@ func (r *PoolRefiller) init(ctx context.Context, host string) error { host: host, nrShards: int(ss.NrShards), msbIgnore: ss.MsbIgnore, - conns: make([]atomic.Value, int(ss.NrShards)), + conns: make([]atomic.Pointer[Conn], int(ss.NrShards)), connClosedCh: make(chan int, int(ss.NrShards)+1), connObs: r.cfg.ConnObserver, }