diff --git a/atomic/chan.go b/atomic/chan.go index 75539cf..957b221 100644 --- a/atomic/chan.go +++ b/atomic/chan.go @@ -32,27 +32,26 @@ type Chan[T any] struct { } // Load atomically loads and returns the value stored in c. -func (c *Chan[T]) Load() chan T { return ptr2Ch[T](atomic.LoadPointer(&c.v)) } +func (c *Chan[T]) Load() (ch chan T) { + *(*unsafe.Pointer)(unsafe.Pointer(&ch)) = atomic.LoadPointer(&c.v) + + return +} // Store atomically stores ch into c. -func (c *Chan[T]) Store(ch chan T) { atomic.StorePointer(&c.v, ch2Ptr(ch)) } +func (c *Chan[T]) Store(ch chan T) { + atomic.StorePointer(&c.v, *(*unsafe.Pointer)(unsafe.Pointer(&ch))) +} // Swap atomically stores new into c and returns the previous value. func (c *Chan[T]) Swap(new chan T) (old chan T) { - return ptr2Ch[T](atomic.SwapPointer(&c.v, ch2Ptr(new))) + *(*unsafe.Pointer)(unsafe.Pointer(&old)) = atomic.SwapPointer(&c.v, *(*unsafe.Pointer)(unsafe.Pointer(&new))) + + return } // CompareAndSwap executes the compare-and-swap operation for c. func (c *Chan[T]) CompareAndSwap(old, new chan T) (swapped bool) { - return atomic.CompareAndSwapPointer(&c.v, ch2Ptr(old), ch2Ptr(new)) -} - -// ch2Ptr casts from a channel to a pointer. -func ch2Ptr[T any](ch chan T) unsafe.Pointer { - return *(*unsafe.Pointer)(unsafe.Pointer(&ch)) -} - -// ptr2Ch casts from a pointer to a channel. -func ptr2Ch[T any](ptr unsafe.Pointer) chan T { - return *(*chan T)(unsafe.Pointer(&ptr)) + return atomic.CompareAndSwapPointer(&c.v, + *(*unsafe.Pointer)(unsafe.Pointer(&old)), *(*unsafe.Pointer)(unsafe.Pointer(&new))) } diff --git a/lazy.go b/lazy.go index 0ca16a7..19023d3 100644 --- a/lazy.go +++ b/lazy.go @@ -27,8 +27,8 @@ type Lazy struct { // Close closes the done channel. You shouldn't close the channel twice. func (l *Lazy) Close() { - if ch := l.done.Swap(closedChan); ch != nil && ch != closedChan { - close(ch) + if done := l.done.Swap(closedChan); done != nil && done != closedChan { + close(done) } } @@ -48,15 +48,22 @@ func (l *Lazy) Done() <-chan struct{} { // Closed returns true if the done channel is closed. func (l *Lazy) Closed() bool { - if done := l.done.Load(); done != nil { + done := l.done.Load() + switch done { + case nil: + return false + + case closedChan: + return true + + default: select { case <-done: return true default: + return false } } - - return false } func (l *Lazy) String() string { diff --git a/lazy_test.go b/lazy_test.go index a8350d3..862206f 100644 --- a/lazy_test.go +++ b/lazy_test.go @@ -24,7 +24,7 @@ import ( "fillmore-labs.com/lazydone" ) -func TestUnsafeDone(t *testing.T) { +func TestDone(t *testing.T) { t.Parallel() for i := 0; i < 1_000; i++ { @@ -45,7 +45,7 @@ func TestUnsafeDone(t *testing.T) { } } -func TestUnsafeClosed(t *testing.T) { +func TestClosed(t *testing.T) { t.Parallel() var lazy lazydone.Lazy if lazy.Closed() { @@ -68,3 +68,30 @@ func TestUnsafeClosed(t *testing.T) { t.Error("Expected lazy to be closed after Close()") } } + +func TestClosedConcurrency(t *testing.T) { + t.Parallel() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + var lazy lazydone.Lazy + + wg.Add(3) + go func() { + <-lazy.Done() + wg.Done() + }() + go func() { + for !lazy.Closed() { //nolint:revive + // Spin, we want to hit the “select on closed channel” branch + } + wg.Done() + }() + go func() { + lazy.Close() + wg.Done() + }() + } + + wg.Wait() +} diff --git a/pointer_test.go b/pointer_test.go index 58eca50..8df8175 100644 --- a/pointer_test.go +++ b/pointer_test.go @@ -24,7 +24,7 @@ import ( "fillmore-labs.com/lazydone" ) -func TestDone(t *testing.T) { +func TestSafeDone(t *testing.T) { t.Parallel() for i := 0; i < 1_000; i++ { @@ -45,7 +45,7 @@ func TestDone(t *testing.T) { } } -func TestClosed(t *testing.T) { +func TestSafeClosed(t *testing.T) { t.Parallel() var lazy lazydone.SafeLazy if lazy.Closed() {