diff --git a/Android/app/src/go/intra/split/retrier.go b/Android/app/src/go/intra/split/retrier.go index 07e9a208..c0bfc618 100644 --- a/Android/app/src/go/intra/split/retrier.go +++ b/Android/app/src/go/intra/split/retrier.go @@ -23,6 +23,7 @@ import ( "math/rand" "net" "sync" + "sync/atomic" "time" "github.com/Jigsaw-Code/getsni" @@ -62,8 +63,8 @@ type retrier struct { // Flag indicating when retry is finished or unnecessary. retryCompleteFlag chan struct{} // Flags indicating whether the caller has called CloseRead and CloseWrite. - readCloseFlag chan struct{} - writeCloseFlag chan struct{} + readCloseFlag atomic.Bool + writeCloseFlag atomic.Bool stats *RetryStats } @@ -87,11 +88,11 @@ func closed(c chan struct{}) bool { } func (r *retrier) readClosed() bool { - return closed(r.readCloseFlag) + return r.readCloseFlag.Load() } func (r *retrier) writeClosed() bool { - return closed(r.writeCloseFlag) + return r.writeCloseFlag.Load() } func (r *retrier) retryCompleted() bool { @@ -142,8 +143,6 @@ func DialWithSplitRetry(ctx context.Context, dialer *net.Dialer, addr *net.TCPAd conn: conn.(*net.TCPConn), timeout: timeout(before, after), retryCompleteFlag: make(chan struct{}), - readCloseFlag: make(chan struct{}), - writeCloseFlag: make(chan struct{}), stats: stats, } @@ -211,9 +210,7 @@ func (r *retrier) retry(buf []byte) (n int, err error) { } func (r *retrier) CloseRead() error { - if !r.readClosed() { - close(r.readCloseFlag) - } + r.readCloseFlag.Store(true) r.mutex.Lock() defer r.mutex.Unlock() return r.conn.CloseRead() @@ -363,9 +360,7 @@ func (r *retrier) ReadFrom(reader io.Reader) (bytes int64, err error) { } func (r *retrier) CloseWrite() error { - if !r.writeClosed() { - close(r.writeCloseFlag) - } + r.writeCloseFlag.Store(true) r.mutex.Lock() defer r.mutex.Unlock() return r.conn.CloseWrite()