Skip to content

Commit dee38a1

Browse files
committed
Address PR comments
1 parent b94dd73 commit dee38a1

File tree

1 file changed

+53
-18
lines changed

1 file changed

+53
-18
lines changed

trin-core/src/utp/stream.rs

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,23 @@ pub fn rand() -> u16 {
6666
rand::thread_rng().gen()
6767
}
6868

69+
/// Connection key for storing active uTP connections
70+
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
71+
pub struct ConnectionKey {
72+
node_id: NodeId,
73+
conn_id_recv: ConnId,
74+
}
75+
76+
impl ConnectionKey {
77+
fn new(node_id: NodeId, conn_id_recv: ConnId) -> Self {
78+
Self {
79+
node_id,
80+
conn_id_recv,
81+
}
82+
}
83+
}
84+
85+
/// uTP socket connection state
6986
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
7087
pub enum SocketState {
7188
Uninitialized,
@@ -103,7 +120,7 @@ pub struct UtpListener {
103120
/// Base discv5 layer
104121
discovery: Arc<Discovery>,
105122
/// Store all active connections
106-
utp_connections: HashMap<NodeId, UtpSocket>,
123+
utp_connections: HashMap<ConnectionKey, UtpSocket>,
107124
/// uTP connection ids to listen for
108125
listening: HashMap<ConnId, UtpMessageId>,
109126
/// Receiver for uTP events sent from the main portal event handler
@@ -159,9 +176,14 @@ impl UtpListener {
159176

160177
match Packet::try_from(payload) {
161178
Ok(packet) => {
179+
let connection_id = packet.connection_id();
180+
162181
match packet.get_type() {
163182
PacketType::Reset => {
164-
if let Some(conn) = self.utp_connections.get_mut(node_id) {
183+
if let Some(conn) = self
184+
.utp_connections
185+
.get_mut(&ConnectionKey::new(*node_id, connection_id))
186+
{
165187
if conn.discv5_tx.send(packet).is_ok() {
166188
let mut buf = [0; BUF_SIZE];
167189
if let Err(msg) = conn.recv(&mut buf).await {
@@ -188,7 +210,10 @@ impl UtpListener {
188210
return;
189211
}
190212

191-
self.utp_connections.insert(*node_id, conn.clone());
213+
self.utp_connections.insert(
214+
ConnectionKey::new(*node_id, conn.receiver_connection_id),
215+
conn.clone(),
216+
);
192217

193218
// Get ownership of FindContentData and re-add the receiver connection
194219
let utp_message_id = self.listening.remove(&conn.sender_connection_id);
@@ -230,7 +255,10 @@ impl UtpListener {
230255
}
231256
// Receive DATA and FIN packets
232257
PacketType::Data => {
233-
if let Some(conn) = self.utp_connections.get_mut(node_id) {
258+
if let Some(conn) = self
259+
.utp_connections
260+
.get_mut(&ConnectionKey::new(*node_id, connection_id))
261+
{
234262
if conn.discv5_tx.send(packet.clone()).is_err() {
235263
error!("Unable to send DATA packet to uTP stream handler");
236264
return;
@@ -246,7 +274,10 @@ impl UtpListener {
246274
}
247275
}
248276
PacketType::Fin => {
249-
if let Some(conn) = self.utp_connections.get_mut(node_id) {
277+
if let Some(conn) = self
278+
.utp_connections
279+
.get_mut(&ConnectionKey::new(*node_id, connection_id))
280+
{
250281
if conn.discv5_tx.send(packet).is_err() {
251282
error!("Unable to send FIN packet to uTP stream handler");
252283
return;
@@ -259,7 +290,10 @@ impl UtpListener {
259290
}
260291
}
261292
PacketType::State => {
262-
if let Some(conn) = self.utp_connections.get_mut(node_id) {
293+
if let Some(conn) = self
294+
.utp_connections
295+
.get_mut(&ConnectionKey::new(*node_id, connection_id))
296+
{
263297
if conn.discv5_tx.send(packet).is_err() {
264298
error!("Unable to send STATE packet to uTP stream handler");
265299
}
@@ -311,7 +345,10 @@ impl UtpListener {
311345
if let Some(enr) = self.discovery.discv5.find_enr(&node_id) {
312346
let mut conn = UtpSocket::new(Arc::clone(&self.discovery), enr);
313347
conn.make_connection(connection_id).await;
314-
self.utp_connections.insert(node_id, conn.clone());
348+
self.utp_connections.insert(
349+
ConnectionKey::new(node_id, conn.receiver_connection_id),
350+
conn.clone(),
351+
);
315352
Ok(conn)
316353
} else {
317354
Err(anyhow!("Trying to connect to unknow Enr"))
@@ -1353,12 +1390,11 @@ mod tests {
13531390
let enr = discv5.discv5.local_enr();
13541391
discv5.start().await.unwrap();
13551392

1356-
let discv5_arc = Arc::new(discv5);
1357-
let discv5_arc_clone = Arc::clone(&discv5_arc);
1393+
let discv5 = Arc::new(discv5);
13581394

1359-
let conn = UtpSocket::new(discv5_arc, enr);
1395+
let conn = UtpSocket::new(Arc::clone(&discv5), enr);
13601396
// TODO: Create `Discv5Socket` struct to encapsulate all socket logic
1361-
spawn_socket_recv(discv5_arc_clone, conn.clone());
1397+
spawn_socket_recv(Arc::clone(&discv5), conn.clone());
13621398

13631399
conn
13641400
}
@@ -1374,18 +1410,17 @@ mod tests {
13741410
let mut discv5 = Discovery::new(config).unwrap();
13751411
discv5.start().await.unwrap();
13761412

1377-
let discv5_arc = Arc::new(discv5);
1378-
let discv5_arc_clone = Arc::clone(&discv5_arc);
1413+
let discv5 = Arc::new(discv5);
13791414

1380-
let conn = UtpSocket::new(Arc::clone(&discv5_arc), connected_to);
1381-
spawn_socket_recv(discv5_arc_clone, conn.clone());
1415+
let conn = UtpSocket::new(Arc::clone(&discv5), connected_to);
1416+
spawn_socket_recv(Arc::clone(&discv5), conn.clone());
13821417

1383-
(discv5_arc.local_enr(), conn)
1418+
(discv5.local_enr(), conn)
13841419
}
13851420

1386-
fn spawn_socket_recv(discv5_arc_clone: Arc<Discovery>, conn: UtpSocket) {
1421+
fn spawn_socket_recv(discv5: Arc<Discovery>, conn: UtpSocket) {
13871422
tokio::spawn(async move {
1388-
let mut receiver = discv5_arc_clone.discv5.event_stream().await.unwrap();
1423+
let mut receiver = discv5.discv5.event_stream().await.unwrap();
13891424
while let Some(event) = receiver.recv().await {
13901425
match event {
13911426
Discv5Event::TalkRequest(request) => {

0 commit comments

Comments
 (0)