Skip to content

Commit

Permalink
Merge pull request #83 from nightly-labs/dashmap-replacement
Browse files Browse the repository at this point in the history
Removal of DashMap from the repository
  • Loading branch information
NorbertBodziony authored Feb 2, 2024
2 parents 0692759 + dce8e32 commit 5b560cd
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 77 deletions.
1 change: 0 additions & 1 deletion server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ strum = { workspace = true }
dotenvy = { workspace = true }
log = { workspace = true }
tower = { workspace = true }
hyper = { wokrspace = true }
axum = { workspace = true }
tower-http = { workspace = true }
tracing-subscriber = { workspace = true }
Expand Down
49 changes: 25 additions & 24 deletions server/src/client/client_handler.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
use std::net::SocketAddr;

use axum::{
extract::{
ws::{Message, WebSocket},
ConnectInfo, State, WebSocketUpgrade,
},
response::Response,
};
use futures::StreamExt;
use log::{debug, info};

use crate::{
errors::NightlyError,
state::{
Expand All @@ -25,14 +13,24 @@ use crate::{
client_initialize::ClientInitializeResponse,
client_messages::{ClientToServer, ServerToClient},
connect::ConnectResponse,
drop_sessions::{self, DropSessionsResponse},
drop_sessions::DropSessionsResponse,
get_info::GetInfoResponse,
get_pending_requests::GetPendingRequestsResponse,
get_sessions::GetSessionsResponse,
},
common::{AckMessage, ErrorMessage, SessionStatus},
},
};
use axum::{
extract::{
ws::{Message, WebSocket},
ConnectInfo, State, WebSocketUpgrade,
},
response::Response,
};
use futures::StreamExt;
use log::{debug, info};
use std::net::SocketAddr;

pub async fn on_new_client_connection(
ConnectInfo(ip): ConnectInfo<SocketAddr>,
Expand Down Expand Up @@ -115,7 +113,7 @@ pub async fn client_handler(
Err(_e) => {
let user_disconnected_event =
ServerToApp::UserDisconnectedEvent(UserDisconnectedEvent {});
let user_sessions = client_to_sessions.get_sessions(client_id.clone());
let user_sessions = client_to_sessions.get_sessions(client_id.clone()).await;
for session_id in user_sessions {
let mut sessions = sessions.write().await;
let session = match sessions.get_mut(&session_id) {
Expand All @@ -142,7 +140,7 @@ pub async fn client_handler(
None => {
let user_disconnected_event =
ServerToApp::UserDisconnectedEvent(UserDisconnectedEvent {});
let user_sessions = client_to_sessions.get_sessions(client_id.clone());
let user_sessions = client_to_sessions.get_sessions(client_id.clone()).await;
for session_id in user_sessions {
let mut sessions = sessions.write().await;
let session = match sessions.get_mut(&session_id) {
Expand Down Expand Up @@ -174,7 +172,7 @@ pub async fn client_handler(
Message::Close(None) | Message::Close(Some(_)) => {
let user_disconnected_event =
ServerToApp::UserDisconnectedEvent(UserDisconnectedEvent {});
let user_sessions = client_to_sessions.get_sessions(client_id.clone());
let user_sessions = client_to_sessions.get_sessions(client_id.clone()).await;
for session_id in user_sessions {
let mut sessions = sessions.write().await;
let session = match sessions.get_mut(&session_id) {
Expand Down Expand Up @@ -251,10 +249,12 @@ pub async fn client_handler(
session.send_to_app(app_event).await.unwrap_or_default();

// Insert new session id into client_to_sessions
client_to_sessions.add_session(
connect_request.client_id.clone(),
connect_request.session_id.clone(),
);
client_to_sessions
.add_session(
connect_request.client_id.clone(),
connect_request.session_id.clone(),
)
.await;

let client_reponse = ServerToClient::ConnectResponse(ConnectResponse {
response_id: connect_request.response_id,
Expand Down Expand Up @@ -384,7 +384,7 @@ pub async fn client_handler(
.unwrap_or_default();
}
ClientToServer::GetSessionsRequest(get_sessions_request) => {
let sessions = client_to_sessions.get_sessions(client_id.clone());
let sessions = client_to_sessions.get_sessions(client_id.clone()).await;
let response = ServerToClient::GetSessionsResponse(GetSessionsResponse {
sessions,
response_id: get_sessions_request.response_id,
Expand All @@ -401,9 +401,10 @@ pub async fn client_handler(
if sessions.disconnect_user(session_id.clone()).await.is_ok() {
dropped_sessions.push(session_id.clone());
};
if let Some(sessions) = client_to_sessions.get_mut(&client_id) {
sessions.remove(&session_id);
}

client_to_sessions
.remove_session(client_id.clone(), session_id.clone())
.await;
}
let response = ServerToClient::DropSessionsResponse(DropSessionsResponse {
dropped_sessions,
Expand Down
6 changes: 4 additions & 2 deletions server/src/client/connect_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub async fn connect_session(
Json(request): Json<HttpConnectSessionRequest>,
) -> Result<Json<HttpConnectSessionResponse>, (StatusCode, String)> {
let mut sessions = sessions.write().await;
let mut session = match sessions.get_mut(&request.session_id) {
let session = match sessions.get_mut(&request.session_id) {
Some(session) => session,
None => {
return Err((
Expand Down Expand Up @@ -69,6 +69,8 @@ pub async fn connect_session(
}
};
// Insert new session id into client_to_sessions
client_to_sessions.add_session(request.client_id.clone(), request.session_id.clone());
client_to_sessions
.add_session(request.client_id.clone(), request.session_id.clone())
.await;
return Ok(Json(HttpConnectSessionResponse {}));
}
19 changes: 9 additions & 10 deletions server/src/client/drop_sessions.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
use crate::state::{
ClientId, ClientToSessions, DisconnectUser, ModifySession, SessionId, Sessions,
};
use axum::{extract::State, http::StatusCode, Json};
use serde::{Deserialize, Serialize};
use ts_rs::TS;

use crate::{
state::{ClientId, ClientToSessions, DisconnectUser, SessionId, Sessions},
structs::app_messages::{
app_messages::ServerToApp, user_disconnected_event::UserDisconnectedEvent,
},
};

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TS)]
#[ts(export)]
pub struct HttpDropSessionsRequest {
#[serde(rename = "clientId")]
pub client_id: ClientId,
pub sessions: Vec<SessionId>,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TS)]
#[ts(export)]
pub struct HttpDropSessionsResponse {
#[serde(rename = "droppedSessions")]
pub dropped_sessions: Vec<SessionId>,
}

pub async fn drop_sessions(
State(sessions): State<Sessions>,
State(client_to_sessions): State<ClientToSessions>,
Expand All @@ -33,9 +31,10 @@ pub async fn drop_sessions(
if sessions.disconnect_user(session_id.clone()).await.is_ok() {
dropped_sessions.push(session_id.clone());
};
if let Some(sessions) = client_to_sessions.get_mut(&request.client_id) {
sessions.remove(&session_id);
}

client_to_sessions
.remove_session(request.client_id.clone(), session_id)
.await;
}
Ok(Json(HttpDropSessionsResponse { dropped_sessions }))
}
11 changes: 7 additions & 4 deletions server/src/client/get_pending_requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@ pub async fn get_pending_requests(
NightlyError::UserNotConnected.to_string(),
));
}
let mut pending_requests = Vec::new();
for (key, pending_request) in session.pending_requests.iter() {
pending_requests.push(pending_request.clone());
}

let pending_requests = session
.pending_requests
.values()
.cloned()
.collect::<Vec<_>>();

Ok(Json(HttpGetPendingRequestsResponse { pending_requests }))
}
2 changes: 1 addition & 1 deletion server/src/client/get_sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ pub async fn get_sessions(
State(client_to_sessions): State<ClientToSessions>,
Json(request): Json<HttpGetSessionsRequest>,
) -> Result<Json<HttpGetSessionsResponse>, (StatusCode, String)> {
let sessions = client_to_sessions.get_sessions(request.client_id);
let sessions = client_to_sessions.get_sessions(request.client_id).await;
Ok(Json(HttpGetSessionsResponse { sessions }))
}
6 changes: 2 additions & 4 deletions server/src/handle_error.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use axum::response::IntoResponse;
use hyper::StatusCode;
use crate::errors::NightlyError;
use axum::{http::StatusCode, response::IntoResponse};
use log::error;
use tower::BoxError;

use crate::errors::NightlyError;

pub async fn handle_error(error: BoxError) -> impl IntoResponse {
error!("Request error {:?}", error);
if error.is::<tower::timeout::error::Elapsed>() {
Expand Down
4 changes: 3 additions & 1 deletion server/src/sesssion_cleaner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ pub fn start_cleaning_sessions(sessions: Sessions, client_to_sessions: ClientToS
// Remove session from client_to_sessions
match &session.client_state.client_id {
Some(client_id) => {
client_to_sessions.remove_session(client_id.clone(), session_id.clone());
client_to_sessions
.remove_session(client_id.clone(), session_id.clone())
.await;
}
None => {}
}
Expand Down
85 changes: 56 additions & 29 deletions server/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::{collections::HashMap, sync::Arc};

use crate::structs::{
app_messages::{app_messages::ServerToApp, user_disconnected_event::UserDisconnectedEvent},
client_messages::client_messages::ServerToClient,
Expand All @@ -11,9 +9,13 @@ use axum::extract::{
ws::{Message, WebSocket},
FromRef,
};
use dashmap::{DashMap, DashSet};
use dashmap::DashMap;
use futures::{stream::SplitSink, SinkExt};
use log::info;
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::RwLock;

pub type SessionId = String;
Expand All @@ -28,7 +30,7 @@ pub trait DisconnectUser {
impl DisconnectUser for Sessions {
async fn disconnect_user(&self, session_id: SessionId) -> Result<()> {
let mut sessions = self.write().await;
let mut session = match sessions.get_mut(&session_id) {
let session = match sessions.get_mut(&session_id) {
Some(session) => session,
None => return Err(anyhow::anyhow!("Session does not exist")), // Session does not exist
};
Expand Down Expand Up @@ -76,39 +78,59 @@ impl SendToClient for ClientSockets {
}
}
}
pub type ClientToSessions = Arc<DashMap<ClientId, DashSet<SessionId>>>;
pub type ClientToSessions = Arc<RwLock<HashMap<ClientId, RwLock<HashSet<SessionId>>>>>;
#[derive(Clone, FromRef)]
pub struct ServerState {
pub sessions: Sessions,
pub client_to_sockets: ClientSockets, // Holds only live sockets
pub client_to_sessions: ClientToSessions,
}

#[async_trait]
pub trait ModifySession {
fn remove_session(&self, client_id: ClientId, session_id: SessionId);
fn add_session(&self, client_id: ClientId, session_id: SessionId);
fn get_sessions(&self, client_id: ClientId) -> Vec<SessionId>;
async fn remove_session(&self, client_id: ClientId, session_id: SessionId);
async fn add_session(&self, client_id: ClientId, session_id: SessionId);
async fn get_sessions(&self, client_id: ClientId) -> Vec<SessionId>;
}

#[async_trait]
impl ModifySession for ClientToSessions {
fn remove_session(&self, client_id: ClientId, session_id: SessionId) {
let entry = match self.get(&client_id) {
Some(sessions) => sessions,
async fn remove_session(&self, client_id: ClientId, session_id: SessionId) {
let mut clients_write = self.write().await;

let client_sessions_lock = match clients_write.get_mut(&client_id) {
Some(entry) => entry,
None => return,
};
entry.remove(&session_id);
let is_empty = entry.is_empty();
drop(entry); // drop the lock

let mut client_sessions_write = client_sessions_lock.write().await;
client_sessions_write.remove(&session_id);

let is_empty = client_sessions_write.is_empty();
drop(client_sessions_write);

if is_empty {
self.remove(&client_id);
clients_write.remove(&client_id);
}
}
fn add_session(&self, client_id: ClientId, session_id: SessionId) {
self.entry(client_id)
.or_insert_with(|| DashSet::new())
.insert(session_id);

async fn add_session(&self, client_id: ClientId, session_id: SessionId) {
let mut clients_write = self.write().await;
let client_sessions = clients_write
.entry(client_id)
.or_insert_with(|| RwLock::new(HashSet::new()));

client_sessions.write().await.insert(session_id);
}
fn get_sessions(&self, client_id: ClientId) -> Vec<SessionId> {
match self.get(&client_id) {
Some(sessions) => sessions.iter().map(|session| session.clone()).collect(),

async fn get_sessions(&self, client_id: ClientId) -> Vec<SessionId> {
let clients_read = self.read().await;
match clients_read.get(&client_id) {
Some(sessions) => {
let client_sessions = sessions.read().await;

client_sessions.iter().cloned().collect()
}
None => vec![],
}
}
Expand All @@ -117,29 +139,34 @@ impl ModifySession for ClientToSessions {
mod tests {
use super::*;

#[test]
fn test_modify_session() {
#[tokio::test]
async fn test_modify_session() {
// Create a new ClientToSessions instance for testing
let client_to_sessions = ClientToSessions::default();

// Add a session
let client_id = "client1".to_string();
let session_id = "session1".to_string();
client_to_sessions.add_session(client_id.clone(), session_id.clone());
client_to_sessions
.add_session(client_id.clone(), session_id.clone())
.await;

// Get sessions for the client
let sessions = client_to_sessions.get_sessions(client_id.clone());
let sessions = client_to_sessions.get_sessions(client_id.clone()).await;
assert_eq!(sessions, vec![session_id.clone()]);

// Remove the session
client_to_sessions.remove_session(client_id.clone(), session_id.clone());
client_to_sessions
.remove_session(client_id.clone(), session_id.clone())
.await;

// Ensure the session is removed
let sessions = client_to_sessions.get_sessions(client_id.clone());
let sessions = client_to_sessions.get_sessions(client_id.clone()).await;
assert!(sessions.is_empty());

// Ensure the client is removed
let maybe_sessions = client_to_sessions.get(&client_id);
let client_to_sessions_read = client_to_sessions.read().await;
let maybe_sessions = client_to_sessions_read.get(&client_id);
assert!(maybe_sessions.is_none());
}
}
2 changes: 1 addition & 1 deletion server/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use hyper::{header, Method};
use axum::http::{header, Method};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tower_http::cors::{Any, CorsLayer};

Expand Down

0 comments on commit 5b560cd

Please sign in to comment.