From f7da578b0bb7c98f9dfbf6b5f6e1277fed12c0a7 Mon Sep 17 00:00:00 2001 From: Hector Sanjuan Date: Thu, 21 Nov 2024 15:14:01 +0100 Subject: [PATCH] ProviderQueryManager: support "max" param on FindProvidersAsync This aligns the ProviderQueryManager with the routing.ContentDiscovery interface. --- bitswap/client/internal/session/session.go | 6 +- .../client/internal/session/session_test.go | 2 +- .../providerquerymanager.go | 47 ++++++++++-- .../providerquerymanager_test.go | 74 ++++++++++++------- 4 files changed, 96 insertions(+), 33 deletions(-) diff --git a/bitswap/client/internal/session/session.go b/bitswap/client/internal/session/session.go index 392fcd3716..942a6c2a5d 100644 --- a/bitswap/client/internal/session/session.go +++ b/bitswap/client/internal/session/session.go @@ -75,7 +75,7 @@ type SessionPeerManager interface { // ProviderFinder is used to find providers for a given key type ProviderFinder interface { // FindProvidersAsync searches for peers that provide the given CID - FindProvidersAsync(ctx context.Context, k cid.Cid) <-chan peer.AddrInfo + FindProvidersAsync(ctx context.Context, k cid.Cid, max int) <-chan peer.AddrInfo } // opType is the kind of operation that is being processed by the event loop @@ -410,7 +410,9 @@ func (s *Session) findMorePeers(ctx context.Context, c cid.Cid) { go func(k cid.Cid) { ctx, span := internal.StartSpan(ctx, "Session.FindMorePeers") defer span.End() - for p := range s.providerFinder.FindProvidersAsync(ctx, k) { + // Max is set to -1. This means "use the default limit" in the + // provider query manager. + for p := range s.providerFinder.FindProvidersAsync(ctx, k, -1) { // When a provider indicates that it has a cid, it's equivalent to // the providing peer sending a HAVE span.AddEvent("FoundPeer") diff --git a/bitswap/client/internal/session/session_test.go b/bitswap/client/internal/session/session_test.go index 84a7f14ffb..061e298e59 100644 --- a/bitswap/client/internal/session/session_test.go +++ b/bitswap/client/internal/session/session_test.go @@ -116,7 +116,7 @@ func newFakeProviderFinder() *fakeProviderFinder { } } -func (fpf *fakeProviderFinder) FindProvidersAsync(ctx context.Context, k cid.Cid) <-chan peer.AddrInfo { +func (fpf *fakeProviderFinder) FindProvidersAsync(ctx context.Context, k cid.Cid, max int) <-chan peer.AddrInfo { go func() { select { case fpf.findMorePeersRequested <- k: diff --git a/routing/providerquerymanager/providerquerymanager.go b/routing/providerquerymanager/providerquerymanager.go index a90b714088..2d66a153d3 100644 --- a/routing/providerquerymanager/providerquerymanager.go +++ b/routing/providerquerymanager/providerquerymanager.go @@ -166,8 +166,16 @@ func (pqm *ProviderQueryManager) setFindProviderTimeout(findProviderTimeout time pqm.timeoutMutex.Unlock() } -// FindProvidersAsync finds providers for the given block. -func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, k cid.Cid) <-chan peer.AddrInfo { +// FindProvidersAsync finds providers for the given block. The max parameter +// controls how many will be returned at most. For a provider to be returned, +// we must have successfully connected to it. Setting max to -1 will use the +// configured MaxProviders. Setting max to 0 will return an unbounded number +// of providers. +func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, k cid.Cid, max int) <-chan peer.AddrInfo { + if max < 0 { + max = pqm.maxProviders + } + inProgressRequestChan := make(chan inProgressRequest) var span trace.Span @@ -203,10 +211,10 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, case receivedInProgressRequest = <-inProgressRequestChan: } - return pqm.receiveProviders(sessionCtx, k, receivedInProgressRequest, func() { span.End() }) + return pqm.receiveProviders(sessionCtx, k, max, receivedInProgressRequest, func() { span.End() }) } -func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k cid.Cid, receivedInProgressRequest inProgressRequest, onCloseFn func()) <-chan peer.AddrInfo { +func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k cid.Cid, max int, receivedInProgressRequest inProgressRequest, onCloseFn func()) <-chan peer.AddrInfo { // maintains an unbuffered queue for incoming providers for given request for a given session // essentially, as a provider comes in, for a given CID, we want to immediately broadcast to all // sessions that queried that CID, without worrying about whether the client code is actually @@ -216,6 +224,9 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k receivedProviders := append([]peer.AddrInfo(nil), receivedInProgressRequest.providersSoFar[0:]...) incomingProviders := receivedInProgressRequest.incoming + // count how many providers we received from our workers etc. + // these providers should be peers we managed to connect to. + total := len(receivedProviders) go func() { defer close(returnedProviders) defer onCloseFn() @@ -231,6 +242,21 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k } return receivedProviders[0] } + + stopWhenMaxReached := func() { + if max > 0 && total >= max { + if incomingProviders != nil { + // drains incomingProviders. + pqm.cancelProviderRequest(sessionCtx, k, incomingProviders) + incomingProviders = nil + } + } + } + + // Handle the case when providersSoFar already is more than we + // need. + stopWhenMaxReached() + for len(receivedProviders) > 0 || incomingProviders != nil { select { case <-pqm.ctx.Done(): @@ -245,6 +271,13 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k incomingProviders = nil } else { receivedProviders = append(receivedProviders, provider) + total++ + stopWhenMaxReached() + // we do not return, we will loop on + // the case below until + // len(receivedProviders) == 0, which + // means they have all been sent out + // via returnedProviders } case outgoingProviders() <- nextProvider(): receivedProviders = receivedProviders[1:] @@ -293,7 +326,11 @@ func (pqm *ProviderQueryManager) findProviderWorker() { pqm.timeoutMutex.RUnlock() span := trace.SpanFromContext(findProviderCtx) span.AddEvent("StartFindProvidersAsync") - providers := pqm.router.FindProvidersAsync(findProviderCtx, k, pqm.maxProviders) + // We set count == 0. We will cancel the query + // manually once we have enough. This assumes the + // ContentDiscovery implementation does that, which a + // requirement per the libp2p/core/routing interface. + providers := pqm.router.FindProvidersAsync(findProviderCtx, k, 0) wg := &sync.WaitGroup{} for p := range providers { wg.Add(1) diff --git a/routing/providerquerymanager/providerquerymanager_test.go b/routing/providerquerymanager/providerquerymanager_test.go index 7254c64fcf..9d8ac8ed48 100644 --- a/routing/providerquerymanager/providerquerymanager_test.go +++ b/routing/providerquerymanager/providerquerymanager_test.go @@ -18,7 +18,7 @@ type fakeProviderDialer struct { connectDelay time.Duration } -type fakeProviderNetwork struct { +type fakeProviderDiscovery struct { peersFound []peer.ID delay time.Duration queriesMadeMutex sync.RWMutex @@ -31,7 +31,7 @@ func (fpd *fakeProviderDialer) Connect(context.Context, peer.AddrInfo) error { return fpd.connectError } -func (fpn *fakeProviderNetwork) FindProvidersAsync(ctx context.Context, k cid.Cid, max int) <-chan peer.AddrInfo { +func (fpn *fakeProviderDiscovery) FindProvidersAsync(ctx context.Context, k cid.Cid, max int) <-chan peer.AddrInfo { fpn.queriesMadeMutex.Lock() fpn.queriesMade++ fpn.liveQueries++ @@ -70,7 +70,7 @@ func mustNotErr[T any](out T, err error) T { func TestNormalSimultaneousFetch(t *testing.T) { peers := random.Peers(10) fpd := &fakeProviderDialer{} - fpn := &fakeProviderNetwork{ + fpn := &fakeProviderDiscovery{ peersFound: peers, delay: 1 * time.Millisecond, } @@ -81,8 +81,8 @@ func TestNormalSimultaneousFetch(t *testing.T) { sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0]) - secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[1]) + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0) + secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[1], 0) var firstPeersReceived []peer.AddrInfo for p := range firstRequestChan { @@ -108,7 +108,7 @@ func TestNormalSimultaneousFetch(t *testing.T) { func TestDedupingProviderRequests(t *testing.T) { peers := random.Peers(10) fpd := &fakeProviderDialer{} - fpn := &fakeProviderNetwork{ + fpn := &fakeProviderDiscovery{ peersFound: peers, delay: 1 * time.Millisecond, } @@ -119,8 +119,8 @@ func TestDedupingProviderRequests(t *testing.T) { sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key) - secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key) + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0) + secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0) var firstPeersReceived []peer.AddrInfo for p := range firstRequestChan { @@ -149,7 +149,7 @@ func TestDedupingProviderRequests(t *testing.T) { func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) { peers := random.Peers(10) fpd := &fakeProviderDialer{} - fpn := &fakeProviderNetwork{ + fpn := &fakeProviderDiscovery{ peersFound: peers, delay: 1 * time.Millisecond, } @@ -162,10 +162,10 @@ func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) { // first session will cancel before done firstSessionCtx, firstCancel := context.WithTimeout(ctx, 3*time.Millisecond) defer firstCancel() - firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key) + firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key, 0) secondSessionCtx, secondCancel := context.WithTimeout(ctx, 5*time.Second) defer secondCancel() - secondRequestChan := providerQueryManager.FindProvidersAsync(secondSessionCtx, key) + secondRequestChan := providerQueryManager.FindProvidersAsync(secondSessionCtx, key, 0) var firstPeersReceived []peer.AddrInfo for p := range firstRequestChan { @@ -194,7 +194,7 @@ func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) { func TestCancelManagerExitsGracefully(t *testing.T) { peers := random.Peers(10) fpd := &fakeProviderDialer{} - fpn := &fakeProviderNetwork{ + fpn := &fakeProviderDiscovery{ peersFound: peers, delay: 1 * time.Millisecond, } @@ -208,8 +208,8 @@ func TestCancelManagerExitsGracefully(t *testing.T) { sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond) defer cancel() - firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key) - secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key) + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0) + secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0) var firstPeersReceived []peer.AddrInfo for p := range firstRequestChan { @@ -232,7 +232,7 @@ func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) { fpd := &fakeProviderDialer{ connectError: errors.New("not able to connect"), } - fpn := &fakeProviderNetwork{ + fpn := &fakeProviderDiscovery{ peersFound: peers, delay: 1 * time.Millisecond, } @@ -244,8 +244,8 @@ func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) { sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond) defer cancel() - firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key) - secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key) + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0) + secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0) var firstPeersReceived []peer.AddrInfo for p := range firstRequestChan { @@ -265,7 +265,7 @@ func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) { func TestRateLimitingRequests(t *testing.T) { peers := random.Peers(10) fpd := &fakeProviderDialer{} - fpn := &fakeProviderNetwork{ + fpn := &fakeProviderDiscovery{ peersFound: peers, delay: 5 * time.Millisecond, } @@ -280,7 +280,7 @@ func TestRateLimitingRequests(t *testing.T) { defer cancel() var requestChannels []<-chan peer.AddrInfo for i := 0; i < providerQueryManager.maxInProcessRequests+1; i++ { - requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i])) + requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i], 0)) } time.Sleep(20 * time.Millisecond) fpn.queriesMadeMutex.Lock() @@ -305,7 +305,7 @@ func TestRateLimitingRequests(t *testing.T) { func TestFindProviderTimeout(t *testing.T) { peers := random.Peers(10) fpd := &fakeProviderDialer{} - fpn := &fakeProviderNetwork{ + fpn := &fakeProviderDiscovery{ peersFound: peers, delay: 10 * time.Millisecond, } @@ -317,7 +317,7 @@ func TestFindProviderTimeout(t *testing.T) { sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0]) + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0) var firstPeersReceived []peer.AddrInfo for p := range firstRequestChan { firstPeersReceived = append(firstPeersReceived, p) @@ -330,7 +330,7 @@ func TestFindProviderTimeout(t *testing.T) { func TestFindProviderPreCanceled(t *testing.T) { peers := random.Peers(10) fpd := &fakeProviderDialer{} - fpn := &fakeProviderNetwork{ + fpn := &fakeProviderDiscovery{ peersFound: peers, delay: 1 * time.Millisecond, } @@ -342,7 +342,7 @@ func TestFindProviderPreCanceled(t *testing.T) { sessionCtx, cancel := context.WithCancel(ctx) cancel() - firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0]) + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0) if firstRequestChan == nil { t.Fatal("expected non-nil channel") } @@ -356,7 +356,7 @@ func TestFindProviderPreCanceled(t *testing.T) { func TestCancelFindProvidersAfterCompletion(t *testing.T) { peers := random.Peers(2) fpd := &fakeProviderDialer{} - fpn := &fakeProviderNetwork{ + fpn := &fakeProviderDiscovery{ peersFound: peers, delay: 1 * time.Millisecond, } @@ -367,7 +367,7 @@ func TestCancelFindProvidersAfterCompletion(t *testing.T) { keys := random.Cids(1) sessionCtx, cancel := context.WithCancel(ctx) - firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0]) + firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0) <-firstRequestChan // wait for everything to start. time.Sleep(10 * time.Millisecond) // wait for the incoming providres to stop. cancel() // cancel the context. @@ -385,3 +385,27 @@ func TestCancelFindProvidersAfterCompletion(t *testing.T) { } } } + +func TestLimitedProviders(t *testing.T) { + max := 5 + peers := random.Peers(10) + fpd := &fakeProviderDialer{} + fpn := &fakeProviderDiscovery{ + peersFound: peers, + delay: 1 * time.Millisecond, + } + ctx := context.Background() + providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxProviders(max))) + providerQueryManager.Startup() + providerQueryManager.setFindProviderTimeout(100 * time.Millisecond) + keys := random.Cids(1) + + providersChan := providerQueryManager.FindProvidersAsync(ctx, keys[0], -1) + total := 0 + for range providersChan { + total++ + } + if total != max { + t.Fatal("returned more providers than requested") + } +}