Skip to content

Commit

Permalink
network: discard unrequested or stale block messages (#5431)
Browse files Browse the repository at this point in the history
  • Loading branch information
iansuvak authored Jun 20, 2023
1 parent ae02370 commit f83a656
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 4 deletions.
114 changes: 110 additions & 4 deletions network/wsNetwork_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3899,8 +3899,7 @@ func TestMaxHeaderSize(t *testing.T) {
netA.wsMaxHeaderBytes = wsMaxHeaderBytes
netA.wg.Add(1)
netA.tryConnect(addrB, gossipB)
time.Sleep(250 * time.Millisecond)
assert.Equal(t, 1, len(netA.peers))
require.Eventually(t, func() bool { return netA.NumPeers() == 1 }, 500*time.Millisecond, 25*time.Millisecond)

netA.removePeer(netA.peers[0], disconnectReasonNone)
assert.Zero(t, len(netA.peers))
Expand All @@ -3922,8 +3921,7 @@ func TestMaxHeaderSize(t *testing.T) {
netA.wsMaxHeaderBytes = 0
netA.wg.Add(1)
netA.tryConnect(addrB, gossipB)
time.Sleep(250 * time.Millisecond)
assert.Equal(t, 1, len(netA.peers))
require.Eventually(t, func() bool { return netA.NumPeers() == 1 }, 500*time.Millisecond, 25*time.Millisecond)
}

func TestTryConnectEarlyWrite(t *testing.T) {
Expand Down Expand Up @@ -3981,6 +3979,114 @@ func TestTryConnectEarlyWrite(t *testing.T) {
assert.Equal(t, uint64(1), netA.peers[0].miMessageCount)
}

// TestDiscardUnrequestedBlockResponse exercises the two ways a node handles a topic (TS)
// response that it is not waiting for:
//   - netB never issued a request, so an unsolicited response from netA must make netB drop
//     the connection and record it under the "unrequestedTS" drop reason.
//   - netC issued a request and then cancelled it, so the late response from netA must be
//     discarded and logged as stale.
func TestDiscardUnrequestedBlockResponse(t *testing.T) {
	partitiontest.PartitionTest(t)

	netA := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netA"})
	netA.config.GossipFanout = 1

	netB := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netB"})
	netB.config.GossipFanout = 1

	netC := makeTestWebsocketNode(t, testWebsocketLogNameOption{"netC"})
	netC.config.GossipFanout = 1

	netA.Start()
	defer netA.Stop()
	netB.Start()
	defer netB.Stop()

	addrB, ok := netB.Address()
	require.True(t, ok)
	gossipB, err := netB.addrToGossipAddr(addrB)
	require.NoError(t, err)

	netA.wg.Add(1)
	netA.tryConnect(addrB, gossipB)
	require.Eventually(t, func() bool { return netA.NumPeers() == 1 }, 500*time.Millisecond, 25*time.Millisecond)

	// The drop counter is process-wide, so measure against its value before the unsolicited
	// send instead of assuming it starts at zero.
	unrequestedLabels := map[string]string{"reason": "unrequestedTS"}
	droppedBefore := networkConnectionsDroppedTotal.GetUint64ValueForLabels(unrequestedLabels)

	// send an unrequested block response
	msg := []sendMessage{{
		data:         append([]byte(protocol.TopicMsgRespTag), []byte("foo")...),
		enqueued:     time.Now(),
		peerEnqueued: time.Now(),
		ctx:          context.Background(),
	}}
	netA.peers[0].sendBufferBulk <- sendMessages{msgs: msg}
	require.Eventually(t,
		func() bool {
			return networkConnectionsDroppedTotal.GetUint64ValueForLabels(unrequestedLabels) == droppedBefore+1
		},
		1*time.Second,
		50*time.Millisecond,
	)

	// Confirm that netB disconnected the peer that sent the unrequested block response. The
	// counter is bumped before the read loop returns and cleans up, so poll for the removal.
	require.Eventually(t, func() bool { return netB.NumPeers() == 0 }, 500*time.Millisecond, 25*time.Millisecond)

	netC.Start()
	defer netC.Stop()

	addrC, ok := netC.Address()
	require.True(t, ok)
	gossipC, err := netC.addrToGossipAddr(addrC)
	require.NoError(t, err)

	netA.wg.Add(1)
	netA.tryConnect(addrC, gossipC)
	require.Eventually(t, func() bool { return netA.NumPeers() == 1 }, 500*time.Millisecond, 25*time.Millisecond)
	// Wait until netC has registered the inbound connection before indexing into its peers.
	require.Eventually(t, func() bool { return netC.NumPeers() == 1 }, 500*time.Millisecond, 25*time.Millisecond)
	peerC := netC.peers[0]

	ctx, cancel := context.WithCancel(context.Background())
	// Release the pending request even if an assertion below aborts the test early.
	defer cancel()
	topics := Topics{
		MakeTopic("requestDataType",
			[]byte("fake block and cert value")),
		MakeTopic(
			"blockData",
			[]byte("fake round value")),
	}
	// Send a request for a block and cancel it after the handler has been registered.
	// The result is deliberately ignored: the request is expected to end via cancellation.
	go func() {
		peerC.Request(ctx, protocol.UniEnsBlockReqTag, topics)
	}()
	require.Eventually(
		t,
		func() bool { return peerC.lenResponseChannels() > 0 },
		1*time.Second,
		50*time.Millisecond,
	)
	cancel()

	// confirm that the request was cancelled but that we have registered that we have sent a request
	require.Eventually(
		t,
		func() bool { return peerC.lenResponseChannels() == 0 },
		500*time.Millisecond,
		20*time.Millisecond,
	)
	require.Equal(t, int64(1), atomic.LoadInt64(&peerC.outstandingTopicRequests))

	// Create a buffer to monitor log output from netC
	logBuffer := bytes.NewBuffer(nil)
	netC.log.SetOutput(logBuffer)

	// send a late TS response from A -> C
	netA.peers[0].sendBufferBulk <- sendMessages{msgs: msg}
	require.Eventually(
		t,
		func() bool { return atomic.LoadInt64(&peerC.outstandingTopicRequests) == int64(0) },
		500*time.Millisecond,
		20*time.Millisecond,
	)

	// Stop netC, then confirm from its log that the late response was treated as stale and
	// discarded (the stale case does not disconnect the peer).
	netC.Stop()
	lg := logBuffer.String()
	require.Contains(t, lg, "wsPeer readLoop: received a TS response for a stale request ")
}

func customNetworkIDGen(networkID protocol.NetworkID) *rapid.Generator[protocol.NetworkID] {
return rapid.Custom(func(t *rapid.T) protocol.NetworkID {
// Unused/satisfying rapid requirement
Expand Down
36 changes: 36 additions & 0 deletions network/wsPeer.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ const disconnectRequestReceived disconnectReason = "DisconnectRequest"
const disconnectStaleWrite disconnectReason = "DisconnectStaleWrite"
const disconnectDuplicateConnection disconnectReason = "DuplicateConnection"
const disconnectBadIdentityData disconnectReason = "BadIdentityData"
const disconnectUnexpectedTopicResp disconnectReason = "UnexpectedTopicResp"

// Response is the structure holding the response from the server
type Response struct {
Expand All @@ -182,6 +183,10 @@ type wsPeer struct {
// we want this to be a 64-bit aligned for atomics support on 32bit platforms.
lastPacketTime int64

// outstandingTopicRequests is an atomic counter for the number of outstanding block requests we've made out to this peer
// if a peer sends more blocks than we've requested, we'll disconnect from it.
outstandingTopicRequests int64

// intermittentOutgoingMessageEnqueueTime contains the UnixNano of the message's enqueue time that is currently being written to the
// peer, or zero if no message is being written.
intermittentOutgoingMessageEnqueueTime int64
Expand Down Expand Up @@ -502,6 +507,29 @@ func (wp *wsPeer) readLoop() {
return
}
msg.Tag = Tag(string(tag[:]))

// Skip the message if it's a response to a request we didn't make or has timed out
if msg.Tag == protocol.TopicMsgRespTag && wp.lenResponseChannels() == 0 {
atomic.AddInt64(&wp.outstandingTopicRequests, -1)

// This peer has sent us more responses than we have requested. This is a protocol violation and we should disconnect.
if atomic.LoadInt64(&wp.outstandingTopicRequests) < 0 {
wp.net.log.Errorf("wsPeer readloop: peer %s sent TS response without a request", wp.conn.RemoteAddr().String())
networkConnectionsDroppedTotal.Inc(map[string]string{"reason": "unrequestedTS"})
cleanupCloseError = disconnectUnexpectedTopicResp
return
}
var n int64
// Peer sent us a response to a request we made but we've already timed out -- discard
n, err = io.Copy(io.Discard, reader)
if err != nil {
wp.net.log.Infof("wsPeer readloop: could not discard timed-out TS message from %s : %s", wp.conn.RemoteAddr().String(), err)
wp.reportReadErr(err)
return
}
wp.net.log.Warnf("wsPeer readLoop: received a TS response for a stale request from %s. %d bytes discarded", wp.conn.RemoteAddr().String(), n)
continue
}
slurper.Reset()
err = slurper.Read(reader)
if err != nil {
Expand Down Expand Up @@ -544,6 +572,7 @@ func (wp *wsPeer) readLoop() {
}
continue
case protocol.TopicMsgRespTag: // Handle Topic message
atomic.AddInt64(&wp.outstandingTopicRequests, -1)
topics, err := UnmarshallTopics(msg.Data)
if err != nil {
wp.net.log.Warnf("wsPeer readLoop: could not read the message from: %s %s", wp.conn.RemoteAddr().String(), err)
Expand Down Expand Up @@ -950,6 +979,7 @@ func (wp *wsPeer) Request(ctx context.Context, tag Tag, topics Topics) (resp *Re
ctx: context.Background()}
select {
case wp.sendBufferBulk <- sendMessages{msgs: msg}:
atomic.AddInt64(&wp.outstandingTopicRequests, 1)
case <-wp.closing:
e = fmt.Errorf("peer closing %s", wp.conn.RemoteAddr().String())
return
Expand Down Expand Up @@ -977,6 +1007,12 @@ func (wp *wsPeer) makeResponseChannel(key uint64) (responseChannel chan *Respons
return newChan
}

// lenResponseChannels returns the number of entries currently held in wp.responseChannels.
// The map is read while holding responseChannelsMutex.
func (wp *wsPeer) lenResponseChannels() int {
	wp.responseChannelsMutex.Lock()
	pending := len(wp.responseChannels)
	wp.responseChannelsMutex.Unlock()
	return pending
}

// getAndRemoveResponseChannel returns the channel and deletes the channel from the map
func (wp *wsPeer) getAndRemoveResponseChannel(key uint64) (respChan chan *Response, found bool) {
wp.responseChannelsMutex.Lock()
Expand Down
13 changes: 13 additions & 0 deletions util/metrics/counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,19 @@ func (counter *Counter) GetUint64Value() (x uint64) {
return atomic.LoadUint64(&counter.intValue)
}

// GetUint64ValueForLabels looks up the counter value recorded for the given label set.
// It returns 0 when no value has been recorded for those labels. The lookup runs with
// the counter's lock held.
func (counter *Counter) GetUint64ValueForLabels(labels map[string]string) uint64 {
	counter.Lock()
	defer counter.Unlock()

	if idx, found := counter.valuesIndices[counter.findLabelIndex(labels)]; found {
		return counter.values[idx].counter
	}
	return 0
}

func (counter *Counter) fastAddUint64(x uint64) {
if atomic.AddUint64(&counter.intValue, x) == x {
// What we just added is the whole value, this
Expand Down
20 changes: 20 additions & 0 deletions util/metrics/counter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,23 @@ func TestGetValue(t *testing.T) {
c.Inc(nil)
require.Equal(t, uint64(2), c.GetUint64Value())
}

// TestGetValueForLabels verifies that GetUint64ValueForLabels reports zero for a label set
// that was never incremented, follows each increment of that label set, and keeps values
// for different label sets (including nil labels) separate.
func TestGetValueForLabels(t *testing.T) {
	partitiontest.PartitionTest(t)

	ctr := MakeCounter(MetricName{Name: "testname", Description: "testhelp"})
	ctr.Deregister(nil)

	first := map[string]string{"a": "b"}
	second := map[string]string{"a": "c"}

	// Nothing recorded yet for the first label set.
	require.Equal(t, uint64(0), ctr.GetUint64ValueForLabels(first))

	// Each increment must be visible through the per-label getter.
	for want := uint64(1); want <= 2; want++ {
		ctr.Inc(first)
		require.Equal(t, want, ctr.GetUint64ValueForLabels(first))
	}

	// confirm that the value is not shared between labels
	ctr.Inc(nil)
	require.Equal(t, uint64(2), ctr.GetUint64ValueForLabels(first))

	ctr.Inc(second)
	require.Equal(t, uint64(1), ctr.GetUint64ValueForLabels(second))
}

0 comments on commit f83a656

Please sign in to comment.