diff --git a/src/api/rps.rs b/src/api/rps.rs index e61a9fe..7cfadcf 100644 --- a/src/api/rps.rs +++ b/src/api/rps.rs @@ -9,8 +9,11 @@ use onion::{Peer, RsaPrivateKey, RsaPublicKey}; use std::net::SocketAddr; use tokio::net::TcpStream; use tokio::sync::mpsc; +use tokio::time; +use tokio::time::Duration; const PEER_BUFFER_SIZE: usize = 20; +const QUERY_TIMEOUT: Duration = Duration::from_secs(2); pub enum RpsModule { Socket(SocketRpsModule), @@ -73,20 +76,40 @@ impl SocketRpsModule { async fn query(&mut self) -> Result { self.socket.write(RpsRequest::Query).await?; - if let Some(msg) = self.socket.read_next().await? { - match msg { - RpsResponse::Peer(_port, portmap, peer_addr, peer_hostkey) => { - let (_, peer_port) = portmap - .iter() - .find(|(m, _)| *m == Module::Onion) - .ok_or_else(|| anyhow!("Peer does not expose onion port"))?; - let peer_addr = SocketAddr::new(peer_addr, *peer_port); - let peer_hostkey = RsaPublicKey::new(peer_hostkey.as_ref()); - Ok(Peer::new(peer_addr, peer_hostkey)) - } + let msg = time::timeout(QUERY_TIMEOUT, self.socket.read_next()) + .await + .map_err(|_| anyhow!("RPS query timed out"))? + .map_err(|e| anyhow!("RPS query failed: {}", e))?; + + match msg { + RpsResponse::Peer(_port, portmap, peer_addr, peer_hostkey) => { + let (_, peer_port) = portmap + .iter() + .find(|(m, _)| *m == Module::Onion) + .ok_or_else(|| anyhow!("Peer does not expose onion port"))?; + let peer_addr = SocketAddr::new(peer_addr, *peer_port); + let peer_hostkey = RsaPublicKey::new(peer_hostkey.as_ref()); + Ok(Peer::new(peer_addr, peer_hostkey)) } - } else { - Err(anyhow!("rps query failed")) } } } + +#[cfg(test)] +mod tests { + use crate::api::config::RpsConfig; + use crate::api::rps::RpsModule; + + #[tokio::test] + #[ignore = "requires a running RPS instance listening on 127.0.0.1:7101"] + async fn test_rps_query() { + let config = RpsConfig { + api_address: Some("127.0.0.1:7101".parse().unwrap()), + peers: None, + }; + + let mut rps = RpsModule::new(&config).await.unwrap(); + println!("Connected to RPS"); + println!("{:?}", rps.query().await); + } +} diff --git a/src/api/socket.rs b/src/api/socket.rs index 439c586..7f4ca4c 100644 --- a/src/api/socket.rs +++ b/src/api/socket.rs @@ -18,7 +18,7 @@ impl ApiSocket { } impl ApiSocket { - pub(crate) async fn read_next>(&mut self) -> Result> { + pub async fn read_next>(&mut self) -> Result { let mut size_buf = [0u8; 2]; self.stream.read_exact(&mut size_buf).await?; let size = u16::from_be_bytes(size_buf) as usize; @@ -27,12 +27,12 @@ impl ApiSocket { self.buf[0] = size_buf[0]; self.buf[1] = size_buf[1]; self.stream.read_exact(&mut self.buf[2..]).await?; - Ok(Some(M::try_read_from(&mut self.buf)?)) + Ok(M::try_read_from(&mut self.buf)?) } } impl ApiSocket { - pub(crate) async fn write(&mut self, message: M) -> Result<()> { + pub async fn write(&mut self, message: M) -> Result<()> { self.buf.clear(); self.buf.reserve(message.size()); message.write_to(&mut self.buf); diff --git a/src/main.rs b/src/main.rs index f9a4bb7..c72ebf7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -74,7 +74,8 @@ impl OnionModule { .insert(client_addr, ApiSocket::new(write_stream)); let mut socket = ApiSocket::new(read_stream); - while let Some(msg) = socket.read_next::().await? { + loop { + let msg = socket.read_next::().await?; trace!("Handling {:?}", msg); let _msg_id = msg.id(); match msg { @@ -128,7 +129,6 @@ impl OnionModule { } } } - Ok(()) } /// Handles P2P protocol events and notifies interested API clients