diff --git a/core/src/message/handlers/connection.rs b/core/src/message/handlers/connection.rs index ce50764e2..d510b2792 100644 --- a/core/src/message/handlers/connection.rs +++ b/core/src/message/handlers/connection.rs @@ -927,58 +927,4 @@ pub mod tests { Ok(()) } - - #[tokio::test] - async fn test_already_connect_fixture() -> Result<()> { - // NodeA-NodeB-NodeC - let keys = gen_ordered_keys(3); - let (key1, key2, key3) = (keys[0], keys[1], keys[2]); - let (node1, _path1) = prepare_node(key1).await; - let (node2, _path2) = prepare_node(key2).await; - let (node3, _path3) = prepare_node(key3).await; - test_only_two_nodes_establish_connection(&node1, &node2).await?; - assert_no_more_msg(&node1, &node2, &node3).await; - - test_only_two_nodes_establish_connection(&node3, &node2).await?; - assert_no_more_msg(&node1, &node2, &node3).await; - // Node 1 -- Node 2 -- Node 3 - println!("node1 connect node2 twice here"); - let _ = node1.connect(node3.did()).await.unwrap(); - let _ = node1.connect(node3.did()).await.unwrap(); - // ConnectNodeSend - let _ = node2.listen_once().await.unwrap(); - let _ = node2.listen_once().await.unwrap(); - // ConnectNodeSend - let _ = node3.listen_once().await.unwrap(); - let _ = node3.listen_once().await.unwrap(); - // ConnectNodeReport - // `self.push_pending_transport(&trans)?;` - let _ = node2.listen_once().await.unwrap(); - let _ = node2.listen_once().await.unwrap(); - // ConnectNodeReport - // self.register(&relay.sender(), transport).await - let _ = node1.listen_once().await.unwrap(); - let _ = node1.listen_once().await.unwrap(); - println!("wait for handshake here"); - sleep(Duration::from_secs(3)).await; - // transport got from node1 for node3 - // transport got from node3 for node - // JoinDHT twice here - let ev3 = node3.listen_once().await.unwrap().0; - assert!(matches!(ev3.data, Message::JoinDHT(_))); - let _ = node3.listen_once().await.is_none(); - - // JoinDHT twice here - let ev1 = node1.listen_once().await.unwrap().0; - assert!(matches!(ev1.data, Message::JoinDHT(_))); - let _ = node1.listen_once().await.is_none(); - - let t1_3 = node1.get_connection(node3.did()).unwrap(); - assert!(t1_3.is_connected().await); - - let t3_1 = node3.get_connection(node1.did()).unwrap(); - assert!(t3_1.is_connected().await); - - Ok(()) - } } diff --git a/core/src/swarm/impls.rs b/core/src/swarm/impls.rs index e0ad7fc5e..d2179b695 100644 --- a/core/src/swarm/impls.rs +++ b/core/src/swarm/impls.rs @@ -331,12 +331,7 @@ impl ConnectionManager for Swarm { /// else try prepare offer and establish connection by dht. /// This function may returns a pending connection or connected connection. async fn connect(&self, did: Did) -> Result { - if let Some(t) = self.get_and_check_connection(did).await { - return Ok(t); - } - tracing::info!("Try connect Did {:?}", &did); - let conn = self.new_connection(did).await?; let offer = conn.webrtc_create_offer().await.map_err(Error::Transport)?; diff --git a/transport/src/connections/native_webrtc/mod.rs b/transport/src/connections/native_webrtc/mod.rs index 120c05462..3e1855ff7 100644 --- a/transport/src/connections/native_webrtc/mod.rs +++ b/transport/src/connections/native_webrtc/mod.rs @@ -3,6 +3,7 @@ use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; +use dashmap::mapref::entry::Entry; use webrtc::data_channel::data_channel_message::DataChannelMessage; use webrtc::data_channel::data_channel_state::RTCDataChannelState; use webrtc::data_channel::RTCDataChannel; @@ -142,6 +143,17 @@ impl SharedTransport for Transport { where CE: std::error::Error + Send + Sync + 'static, { + if let Ok(existed_conn) = self.get_connection(cid) { + if matches!( + existed_conn.webrtc_connection_state(), + WebrtcConnectionState::New + | WebrtcConnectionState::Connecting + | WebrtcConnectionState::Connected + ) { + return Err(Error::ConnectionAlreadyExists(cid.to_string())); + } + } + // // Setup webrtc connection env // @@ -221,7 +233,34 @@ impl SharedTransport for Transport { // Construct the Connection // let conn = WebrtcConnection::new(webrtc_conn, webrtc_data_channel); - self.connections.insert(cid.to_string(), conn.clone()); + + // + // Safely insert + // + // The implementation of match statement refers to Entry::insert in dashmap. + // An extra check is added to see if the connection is already connected. + // See also: https://docs.rs/dashmap/latest/dashmap/mapref/entry/enum.Entry.html#method.insert + // + let Some(entry) = self.connections.try_entry(cid.to_string()) else { + return Err(Error::ConnectionAlreadyExists(cid.to_string())); + }; + match entry { + Entry::Occupied(mut entry) => { + let existed_conn = entry.get(); + if matches!( + existed_conn.webrtc_connection_state(), + WebrtcConnectionState::New + | WebrtcConnectionState::Connecting + | WebrtcConnectionState::Connected + ) { + return Err(Error::ConnectionAlreadyExists(cid.to_string())); + } + + entry.insert(conn.clone()); + entry.into_ref() + } + Entry::Vacant(entry) => entry.insert(conn.clone()), + }; Ok(conn) } diff --git a/transport/src/error.rs b/transport/src/error.rs index 109841fd2..784664461 100644 --- a/transport/src/error.rs +++ b/transport/src/error.rs @@ -33,6 +33,9 @@ pub enum Error { #[error("WebRTC local SDP generation error")] WebrtcLocalSdpGenerationError, + #[error("Connection {0} already exists")] + ConnectionAlreadyExists(String), + #[error("Connection {0} not found, should handshake first")] ConnectionNotFound(String), }