From 0da0e99556ce4426902c6a31ba0ea30c00ed320f Mon Sep 17 00:00:00 2001 From: Pavel Zbitskiy <65323360+algorandskiy@users.noreply.github.com> Date: Tue, 3 Sep 2024 15:16:15 -0400 Subject: [PATCH] network: fix outgoing HTTP rate limiting (#6118) --- network/limitcaller/rateLimitingTransport.go | 49 +++++++------ .../limitcaller/rateLimitingTransport_test.go | 72 +++++++++++++++++++ network/p2p/http.go | 6 +- network/p2pNetwork.go | 9 +-- network/p2pNetwork_test.go | 2 +- network/wsNetwork.go | 19 +++-- network/wsNetwork_test.go | 39 ++++++++++ 7 files changed, 155 insertions(+), 41 deletions(-) create mode 100644 network/limitcaller/rateLimitingTransport_test.go diff --git a/network/limitcaller/rateLimitingTransport.go b/network/limitcaller/rateLimitingTransport.go index de68c9b371..7877c879d1 100644 --- a/network/limitcaller/rateLimitingTransport.go +++ b/network/limitcaller/rateLimitingTransport.go @@ -22,7 +22,6 @@ import ( "time" "github.com/algorand/go-algorand/util" - "github.com/libp2p/go-libp2p/core/peer" ) // ConnectionTimeStore is a subset of the phonebook that is used to store the connection times. @@ -31,12 +30,12 @@ type ConnectionTimeStore interface { UpdateConnectionTime(addrOrPeerID string, provisionalTime time.Time) bool } -// RateLimitingTransport is the transport for execute a single HTTP transaction, obtaining the Response for a given Request. -type RateLimitingTransport struct { +// RateLimitingBoundTransport is the transport for execute a single HTTP transaction, obtaining the Response for a given Request. +type RateLimitingBoundTransport struct { phonebook ConnectionTimeStore innerTransport http.RoundTripper queueingTimeout time.Duration - targetAddr interface{} // target address for the p2p http request + addrOrPeerID string } // DefaultQueueingTimeout is the default timeout for queueing the request. @@ -46,9 +45,10 @@ const DefaultQueueingTimeout = 10 * time.Second // queueing the current request before the request attempt could be made. var ErrConnectionQueueingTimeout = errors.New("rateLimitingTransport: queueing timeout") -// MakeRateLimitingTransport creates a rate limiting http transport that would limit the requests rate -// according to the entries in the phonebook. -func MakeRateLimitingTransport(phonebook ConnectionTimeStore, queueingTimeout time.Duration, dialer *Dialer, maxIdleConnsPerHost int) RateLimitingTransport { +// MakeRateLimitingBoundTransport creates a rate limiting http transport that that: +// 1. would limit the requests rate according to the entries in the phonebook. +// 2. is bound to a specific target. +func MakeRateLimitingBoundTransport(phonebook ConnectionTimeStore, queueingTimeout time.Duration, dialer *Dialer, maxIdleConnsPerHost int, target string) RateLimitingBoundTransport { defaultTransport := http.DefaultTransport.(*http.Transport) innerTransport := &http.Transport{ Proxy: defaultTransport.Proxy, @@ -59,37 +59,36 @@ func MakeRateLimitingTransport(phonebook ConnectionTimeStore, queueingTimeout ti ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout, MaxIdleConnsPerHost: maxIdleConnsPerHost, } - return MakeRateLimitingTransportWithRoundTripper(phonebook, queueingTimeout, innerTransport, nil, maxIdleConnsPerHost) + return MakeRateLimitingBoundTransportWithRoundTripper(phonebook, queueingTimeout, innerTransport, target) } -// MakeRateLimitingTransportWithRoundTripper creates a rate limiting http transport that would limit the requests rate -// according to the entries in the phonebook. -func MakeRateLimitingTransportWithRoundTripper(phonebook ConnectionTimeStore, queueingTimeout time.Duration, rt http.RoundTripper, target interface{}, maxIdleConnsPerHost int) RateLimitingTransport { - return RateLimitingTransport{ +// MakeRateLimitingBoundTransportWithRoundTripper creates a rate limiting http transport that: +// 1. would limit the requests rate according to the entries in the phonebook. +// 2. is bound to a specific target. +func MakeRateLimitingBoundTransportWithRoundTripper(phonebook ConnectionTimeStore, queueingTimeout time.Duration, rt http.RoundTripper, target string) RateLimitingBoundTransport { + return RateLimitingBoundTransport{ phonebook: phonebook, innerTransport: rt, queueingTimeout: queueingTimeout, - targetAddr: target, + addrOrPeerID: target, } } // RoundTrip connects to the address on the named network using the provided context. // It waits if needed not to exceed connectionsRateLimitingCount. -func (r *RateLimitingTransport) RoundTrip(req *http.Request) (res *http.Response, err error) { +func (r *RateLimitingBoundTransport) RoundTrip(req *http.Request) (res *http.Response, err error) { var waitTime time.Duration var provisionalTime time.Time - queueingDeadline := time.Now().Add(r.queueingTimeout) - addrOrPeerID := req.Host - // p2p/http clients have per-connection transport and address info so use that - if len(req.Host) == 0 && req.URL != nil && len(req.URL.Host) == 0 { - addrInfo, ok := r.targetAddr.(*peer.AddrInfo) - if !ok { - return nil, errors.New("rateLimitingTransport: request without Host/URL and targetAddr is not a peer.AddrInfo") - } - addrOrPeerID = string(addrInfo.ID) + if r.addrOrPeerID == "" { + return nil, errors.New("rateLimitingTransport: target not set") + } + if req.URL != nil && req.URL.Host != "" && req.URL.Host != r.addrOrPeerID { + return nil, errors.New("rateLimitingTransport: request URL host does not match the target") } + + queueingDeadline := time.Now().Add(r.queueingTimeout) for { - _, waitTime, provisionalTime = r.phonebook.GetConnectionWaitTime(addrOrPeerID) + _, waitTime, provisionalTime = r.phonebook.GetConnectionWaitTime(r.addrOrPeerID) if waitTime == 0 { break // break out of the loop and proceed to the connection } @@ -101,6 +100,6 @@ func (r *RateLimitingTransport) RoundTrip(req *http.Request) (res *http.Response return nil, ErrConnectionQueueingTimeout } res, err = r.innerTransport.RoundTrip(req) - r.phonebook.UpdateConnectionTime(addrOrPeerID, provisionalTime) + r.phonebook.UpdateConnectionTime(r.addrOrPeerID, provisionalTime) return } diff --git a/network/limitcaller/rateLimitingTransport_test.go b/network/limitcaller/rateLimitingTransport_test.go new file mode 100644 index 0000000000..155ed8310f --- /dev/null +++ b/network/limitcaller/rateLimitingTransport_test.go @@ -0,0 +1,72 @@ +// Copyright (C) 2019-2024 Algorand, Inc. +// This file is part of go-algorand +// +// go-algorand is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// go-algorand is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with go-algorand. If not, see . + +package limitcaller + +import ( + "net/http" + "testing" + "time" + + "github.com/algorand/go-algorand/test/partitiontest" + "github.com/stretchr/testify/require" +) + +type ctStore struct { + t *testing.T + getCnt uint64 +} + +func (c *ctStore) GetConnectionWaitTime(addrOrPeerID string) (bool, time.Duration, time.Time) { + require.NotEmpty(c.t, addrOrPeerID) + c.getCnt++ + return false, 0, time.Time{} +} + +func (c *ctStore) UpdateConnectionTime(addrOrPeerID string, provisionalTime time.Time) bool { + require.NotEmpty(c.t, addrOrPeerID) + return false +} + +type emptyRoundTripper struct{} + +func (e *emptyRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { return nil, nil } + +func TestRoundTrip(t *testing.T) { + partitiontest.PartitionTest(t) + t.Parallel() + + ctStore := ctStore{t: t} + rtt := MakeRateLimitingBoundTransportWithRoundTripper(&ctStore, 0, &emptyRoundTripper{}, "") + req := &http.Request{} + _, err := rtt.RoundTrip(req) + require.ErrorContains(t, err, "target not set") + require.Equal(t, uint64(0), ctStore.getCnt) + + rtt = MakeRateLimitingBoundTransportWithRoundTripper(&ctStore, 0, &emptyRoundTripper{}, "mytarget") + req, err = http.NewRequest("GET", "https://example.com/test", nil) + require.NoError(t, err) + _, err = rtt.RoundTrip(req) + require.ErrorContains(t, err, "URL host does not match the target") + require.Equal(t, uint64(0), ctStore.getCnt) + + rtt = MakeRateLimitingBoundTransportWithRoundTripper(&ctStore, 0, &emptyRoundTripper{}, "mytarget") + req, err = http.NewRequest("GET", "/test", nil) + require.NoError(t, err) + _, err = rtt.RoundTrip(req) + require.NoError(t, err) + require.Equal(t, uint64(1), ctStore.getCnt) +} diff --git a/network/p2p/http.go b/network/p2p/http.go index 07f27afff1..633a13713d 100644 --- a/network/p2p/http.go +++ b/network/p2p/http.go @@ -88,13 +88,13 @@ func MakeHTTPClient(addrInfo *peer.AddrInfo) (*http.Client, error) { } // MakeHTTPClientWithRateLimit creates a http.Client that uses libp2p transport for a given protocol and peer address. -func MakeHTTPClientWithRateLimit(addrInfo *peer.AddrInfo, pstore limitcaller.ConnectionTimeStore, queueingTimeout time.Duration, maxIdleConnsPerHost int) (*http.Client, error) { +func MakeHTTPClientWithRateLimit(addrInfo *peer.AddrInfo, pstore limitcaller.ConnectionTimeStore, queueingTimeout time.Duration) (*http.Client, error) { cl, err := MakeHTTPClient(addrInfo) if err != nil { return nil, err } - rlrt := limitcaller.MakeRateLimitingTransportWithRoundTripper(pstore, queueingTimeout, cl.Transport, addrInfo, maxIdleConnsPerHost) - cl.Transport = &rlrt + rltr := limitcaller.MakeRateLimitingBoundTransportWithRoundTripper(pstore, queueingTimeout, cl.Transport, string(addrInfo.ID)) + cl.Transport = &rltr return cl, nil } diff --git a/network/p2pNetwork.go b/network/p2pNetwork.go index d3af60a223..f88660b653 100644 --- a/network/p2pNetwork.go +++ b/network/p2pNetwork.go @@ -613,8 +613,7 @@ func addrInfoToWsPeerCore(n *P2PNetwork, addrInfo *peer.AddrInfo) (wsPeerCore, b } addr := mas[0].String() - maxIdleConnsPerHost := int(n.config.ConnectionsRateLimitingCount) - client, err := p2p.MakeHTTPClientWithRateLimit(addrInfo, n.pstore, limitcaller.DefaultQueueingTimeout, maxIdleConnsPerHost) + client, err := p2p.MakeHTTPClientWithRateLimit(addrInfo, n.pstore, limitcaller.DefaultQueueingTimeout) if err != nil { n.log.Warnf("MakeHTTPClient failed: %v", err) return wsPeerCore{}, false @@ -720,8 +719,7 @@ func (n *P2PNetwork) GetHTTPClient(address string) (*http.Client, error) { if err != nil { return nil, err } - maxIdleConnsPerHost := int(n.config.ConnectionsRateLimitingCount) - return p2p.MakeHTTPClientWithRateLimit(addrInfo, n.pstore, limitcaller.DefaultQueueingTimeout, maxIdleConnsPerHost) + return p2p.MakeHTTPClientWithRateLimit(addrInfo, n.pstore, limitcaller.DefaultQueueingTimeout) } // OnNetworkAdvance notifies the network library that the agreement protocol was able to make a notable progress. @@ -774,8 +772,7 @@ func (n *P2PNetwork) wsStreamHandler(ctx context.Context, p2pPeer peer.ID, strea // create a wsPeer for this stream and added it to the peers map. addrInfo := &peer.AddrInfo{ID: p2pPeer, Addrs: []multiaddr.Multiaddr{ma}} - maxIdleConnsPerHost := int(n.config.ConnectionsRateLimitingCount) - client, err := p2p.MakeHTTPClientWithRateLimit(addrInfo, n.pstore, limitcaller.DefaultQueueingTimeout, maxIdleConnsPerHost) + client, err := p2p.MakeHTTPClientWithRateLimit(addrInfo, n.pstore, limitcaller.DefaultQueueingTimeout) if err != nil { n.log.Warnf("Cannot construct HTTP Client for %s: %v", p2pPeer, err) client = nil diff --git a/network/p2pNetwork_test.go b/network/p2pNetwork_test.go index e2c231f843..dcc641c350 100644 --- a/network/p2pNetwork_test.go +++ b/network/p2pNetwork_test.go @@ -783,7 +783,7 @@ func TestP2PHTTPHandler(t *testing.T) { pstore, err := peerstore.MakePhonebook(0, 10*time.Second) require.NoError(t, err) pstore.AddPersistentPeers([]*peer.AddrInfo{&peerInfoA}, "net", phonebook.PhoneBookEntryRelayRole) - httpClient, err = p2p.MakeHTTPClientWithRateLimit(&peerInfoA, pstore, 1*time.Second, 1) + httpClient, err = p2p.MakeHTTPClientWithRateLimit(&peerInfoA, pstore, 1*time.Second) require.NoError(t, err) _, err = httpClient.Get("/test") require.ErrorIs(t, err, limitcaller.ErrConnectionQueueingTimeout) diff --git a/network/wsNetwork.go b/network/wsNetwork.go index ecb636c8e2..c67200f01b 100644 --- a/network/wsNetwork.go +++ b/network/wsNetwork.go @@ -230,10 +230,9 @@ type WebsocketNetwork struct { // number of throttled outgoing connections "slots" needed to be populated. throttledOutgoingConnections atomic.Int32 - // transport and dialer are customized to limit the number of + // dialer is customized to limit the number of // connection in compliance with connectionsRateLimitingCount. - transport limitcaller.RateLimitingTransport - dialer limitcaller.Dialer + dialer limitcaller.Dialer // messagesOfInterest specifies the message types that this node // wants to receive. nil means default. non-nil causes this @@ -565,9 +564,7 @@ func (wn *WebsocketNetwork) setup() { if wn.nodeInfo == nil { wn.nodeInfo = &nopeNodeInfo{} } - maxIdleConnsPerHost := int(wn.config.ConnectionsRateLimitingCount) wn.dialer = limitcaller.MakeRateLimitingDialer(wn.phonebook, preferredResolver) - wn.transport = limitcaller.MakeRateLimitingTransport(wn.phonebook, limitcaller.DefaultQueueingTimeout, &wn.dialer, maxIdleConnsPerHost) wn.upgrader.ReadBufferSize = 4096 wn.upgrader.WriteBufferSize = 4096 @@ -1975,8 +1972,18 @@ func (wn *WebsocketNetwork) numOutgoingPending() int { // GetHTTPClient returns a http.Client with a suitable for the network Transport // that would also limit the number of outgoing connections. func (wn *WebsocketNetwork) GetHTTPClient(address string) (*http.Client, error) { + url, err := addr.ParseHostOrURL(address) + if err != nil { + return nil, err + } + + maxIdleConnsPerHost := int(wn.config.ConnectionsRateLimitingCount) + rltr := limitcaller.MakeRateLimitingBoundTransport(wn.phonebook, limitcaller.DefaultQueueingTimeout, &wn.dialer, maxIdleConnsPerHost, url.Host) return &http.Client{ - Transport: &HTTPPAddressBoundTransport{address, &wn.transport}, + Transport: &HTTPPAddressBoundTransport{ + address, + &rltr, + }, }, nil } diff --git a/network/wsNetwork_test.go b/network/wsNetwork_test.go index 0128c28fc2..91983dfa20 100644 --- a/network/wsNetwork_test.go +++ b/network/wsNetwork_test.go @@ -4601,3 +4601,42 @@ func TestHTTPPAddressBoundTransport(t *testing.T) { } } } + +// TestWebsocketNetworkHTTPClient checks ws net HTTP client can connect to another node +// with out unexpected errors +func TestWebsocketNetworkHTTPClient(t *testing.T) { + partitiontest.PartitionTest(t) + t.Parallel() + + netA := makeTestWebsocketNode(t) + err := netA.Start() + require.NoError(t, err) + defer netStop(t, netA, "A") + + netB := makeTestWebsocketNodeWithConfig(t, defaultConfig) + + addr, ok := netA.Address() + require.True(t, ok) + + c, err := netB.GetHTTPClient(addr) + require.NoError(t, err) + + netA.RegisterHTTPHandlerFunc("/handled", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + resp, err := c.Do(&http.Request{URL: &url.URL{Path: "/handled"}}) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + resp, err = c.Do(&http.Request{URL: &url.URL{Path: "/test"}}) + require.NoError(t, err) + require.Equal(t, http.StatusNotFound, resp.StatusCode) // no such handler + + resp, err = c.Do(&http.Request{URL: &url.URL{Path: "/v1/" + genesisID + "/gossip"}}) + require.NoError(t, err) + require.Equal(t, http.StatusPreconditionFailed, resp.StatusCode) // not enough ws peer headers + + _, err = netB.GetHTTPClient("invalid") + require.Error(t, err) +}