From 2bc17348ac16eaf82806b5a31c3e73d404a03f95 Mon Sep 17 00:00:00 2001 From: FZambia Date: Thu, 11 Jul 2024 21:24:18 +0300 Subject: [PATCH] add test --- client_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/client_test.go b/client_test.go index e6465fd6..4c313cd8 100644 --- a/client_test.go +++ b/client_test.go @@ -3804,11 +3804,20 @@ func asyncSubscribeClient(t testing.TB, client *Client, ch string) { require.NoError(t, err) } +// Not looking at unsubscribe result - just execute subscribe command. +func asyncUnsubscribeClient(t testing.TB, client *Client, ch string) { + rwWrapper := testReplyWriterWrapper() + err := client.handleUnsubscribe(&protocol.UnsubscribeRequest{ + Channel: ch, + }, &protocol.Command{Id: 1}, time.Now(), rwWrapper.rw) + require.NoError(t, err) +} + func TestClientUnsubscribeDuringSubscribe(t *testing.T) { t.Parallel() node := defaultNodeNoHandlers() - subscribedCh := make(chan struct{}, 2) - unsubscribedCh := make(chan struct{}, 2) + subscribedCh := make(chan struct{}, 1) + unsubscribedCh := make(chan struct{}, 1) node.OnConnect(func(client *Client) { client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) { go func() { @@ -3828,16 +3837,12 @@ func TestClientUnsubscribeDuringSubscribe(t *testing.T) { client := newTestClient(t, node, "42") connectClientV2(t, client) asyncSubscribeClient(t, client, "test") - client.Unsubscribe("test") + asyncUnsubscribeClient(t, client, "test") client.mu.Lock() _, ok := client.channels["test"] client.mu.Unlock() require.False(t, ok) waitWithTimeout(t, unsubscribedCh) - asyncSubscribeClient(t, client, "test") - err := client.close(DisconnectForceNoReconnect) - waitWithTimeout(t, unsubscribedCh) - require.NoError(t, err) } func TestClientUnsubscribeDuringSubscribeWithError(t *testing.T) { @@ -3862,7 +3867,7 @@ func TestClientUnsubscribeDuringSubscribeWithError(t *testing.T) { client := newTestClient(t, node, "42") connectClientV2(t, client) asyncSubscribeClient(t, client, "test") - client.Unsubscribe("test") + asyncUnsubscribeClient(t, client, "test") client.mu.Lock() _, ok := client.channels["test"] client.mu.Unlock() @@ -3871,3 +3876,33 @@ func TestClientUnsubscribeDuringSubscribeWithError(t *testing.T) { err := client.close(DisconnectForceNoReconnect) require.NoError(t, err) } + +func TestClientUnsubscribeDuringSubscribeCorrectChannels(t *testing.T) { + t.Parallel() + node := defaultNodeNoHandlers() + subscribedCh := make(chan struct{}) + node.OnConnect(func(client *Client) { + client.OnSubscribe(func(e SubscribeEvent, cb SubscribeCallback) { + go func() { + time.Sleep(1000 * time.Millisecond) + cb(SubscribeReply{}, nil) + close(subscribedCh) + }() + }) + client.OnUnsubscribe(func(e UnsubscribeEvent) { + }) + }) + defer func() { _ = node.Shutdown(context.Background()) }() + client := newTestClient(t, node, "42") + connectClientV2(t, client) + asyncSubscribeClient(t, client, "test") + asyncUnsubscribeClient(t, client, "test") + client.mu.Lock() + _, ok := client.channels["test"] + client.mu.Unlock() + require.False(t, ok) + <-subscribedCh + require.Equal(t, 0, node.Hub().NumChannels()) + err := client.close(DisconnectForceNoReconnect) + require.NoError(t, err) +}