Skip to content

Commit

Permalink
Prevent create connection when existed
Browse files Browse the repository at this point in the history
  • Loading branch information
Ma233 committed Sep 4, 2023
1 parent 0370019 commit 8303909
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 60 deletions.
54 changes: 0 additions & 54 deletions core/src/message/handlers/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -927,58 +927,4 @@ pub mod tests {

Ok(())
}

#[tokio::test]
async fn test_already_connect_fixture() -> Result<()> {
// NodeA-NodeB-NodeC
let keys = gen_ordered_keys(3);
let (key1, key2, key3) = (keys[0], keys[1], keys[2]);
let (node1, _path1) = prepare_node(key1).await;
let (node2, _path2) = prepare_node(key2).await;
let (node3, _path3) = prepare_node(key3).await;
test_only_two_nodes_establish_connection(&node1, &node2).await?;
assert_no_more_msg(&node1, &node2, &node3).await;

test_only_two_nodes_establish_connection(&node3, &node2).await?;
assert_no_more_msg(&node1, &node2, &node3).await;
// Node 1 -- Node 2 -- Node 3
println!("node1 connect node2 twice here");
let _ = node1.connect(node3.did()).await.unwrap();
let _ = node1.connect(node3.did()).await.unwrap();
// ConnectNodeSend
let _ = node2.listen_once().await.unwrap();
let _ = node2.listen_once().await.unwrap();
// ConnectNodeSend
let _ = node3.listen_once().await.unwrap();
let _ = node3.listen_once().await.unwrap();
// ConnectNodeReport
// `self.push_pending_transport(&trans)?;`
let _ = node2.listen_once().await.unwrap();
let _ = node2.listen_once().await.unwrap();
// ConnectNodeReport
// self.register(&relay.sender(), transport).await
let _ = node1.listen_once().await.unwrap();
let _ = node1.listen_once().await.unwrap();
println!("wait for handshake here");
sleep(Duration::from_secs(3)).await;
// transport got from node1 for node3
// transport got from node3 for node
// JoinDHT twice here
let ev3 = node3.listen_once().await.unwrap().0;
assert!(matches!(ev3.data, Message::JoinDHT(_)));
let _ = node3.listen_once().await.is_none();

// JoinDHT twice here
let ev1 = node1.listen_once().await.unwrap().0;
assert!(matches!(ev1.data, Message::JoinDHT(_)));
let _ = node1.listen_once().await.is_none();

let t1_3 = node1.get_connection(node3.did()).unwrap();
assert!(t1_3.is_connected().await);

let t3_1 = node3.get_connection(node1.did()).unwrap();
assert!(t3_1.is_connected().await);

Ok(())
}
}
5 changes: 0 additions & 5 deletions core/src/swarm/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,12 +331,7 @@ impl ConnectionManager for Swarm {
/// else try prepare offer and establish connection by dht.
/// This function may returns a pending connection or connected connection.
async fn connect(&self, did: Did) -> Result<Connection> {
if let Some(t) = self.get_and_check_connection(did).await {
return Ok(t);
}

tracing::info!("Try connect Did {:?}", &did);

let conn = self.new_connection(did).await?;

let offer = conn.webrtc_create_offer().await.map_err(Error::Transport)?;
Expand Down
41 changes: 40 additions & 1 deletion transport/src/connections/native_webrtc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::time::Duration;

use async_trait::async_trait;
use bytes::Bytes;
use dashmap::mapref::entry::Entry;
use webrtc::data_channel::data_channel_message::DataChannelMessage;
use webrtc::data_channel::data_channel_state::RTCDataChannelState;
use webrtc::data_channel::RTCDataChannel;
Expand Down Expand Up @@ -142,6 +143,17 @@ impl SharedTransport for Transport<WebrtcConnection> {
where
CE: std::error::Error + Send + Sync + 'static,
{
if let Ok(existed_conn) = self.get_connection(cid) {
if matches!(
existed_conn.webrtc_connection_state(),
WebrtcConnectionState::New
| WebrtcConnectionState::Connecting
| WebrtcConnectionState::Connected
) {
return Err(Error::ConnectionAlreadyExists(cid.to_string()));
}
}

//
// Setup webrtc connection env
//
Expand Down Expand Up @@ -221,7 +233,34 @@ impl SharedTransport for Transport<WebrtcConnection> {
// Construct the Connection
//
let conn = WebrtcConnection::new(webrtc_conn, webrtc_data_channel);
self.connections.insert(cid.to_string(), conn.clone());

//
// Safely insert
//
// The implementation of match statement refers to Entry::insert in dashmap.
// An extra check is added to see if the connection is already connected.
// See also: https://docs.rs/dashmap/latest/dashmap/mapref/entry/enum.Entry.html#method.insert
//
let Some(entry) = self.connections.try_entry(cid.to_string()) else {
return Err(Error::ConnectionAlreadyExists(cid.to_string()));
};
match entry {
Entry::Occupied(mut entry) => {
let existed_conn = entry.get();
if matches!(
existed_conn.webrtc_connection_state(),
WebrtcConnectionState::New
| WebrtcConnectionState::Connecting
| WebrtcConnectionState::Connected
) {
return Err(Error::ConnectionAlreadyExists(cid.to_string()));
}

entry.insert(conn.clone());
entry.into_ref()
}
Entry::Vacant(entry) => entry.insert(conn.clone()),
};

Ok(conn)
}
Expand Down
3 changes: 3 additions & 0 deletions transport/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ pub enum Error {
#[error("WebRTC local SDP generation error")]
WebrtcLocalSdpGenerationError,

#[error("Connection {0} already exists")]
ConnectionAlreadyExists(String),

#[error("Connection {0} not found, should handshake first")]
ConnectionNotFound(String),
}

0 comments on commit 8303909

Please sign in to comment.