Skip to content

Commit

Permalink
Merge pull request #123 from Totodore/fix-different-sid-for-socketio-ns
Browse files Browse the repository at this point in the history
fix(socketio/socket): create a different socket id for each ns
  • Loading branch information
Totodore authored Oct 21, 2023
2 parents 260231d + 72ba75b commit 0f9e0c1
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 38 deletions.
7 changes: 2 additions & 5 deletions engineioxide/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,15 +399,12 @@ impl<D> Socket<D>
where
D: Default + Send + Sync + 'static,
{
pub fn new_dummy(
sid: Sid,
close_fn: Box<dyn Fn(Sid, DisconnectReason) + Send + Sync>,
) -> Socket<D> {
pub fn new_dummy(close_fn: Box<dyn Fn(Sid, DisconnectReason) + Send + Sync>) -> Socket<D> {
let (internal_tx, internal_rx) = mpsc::channel(200);
let (heartbeat_tx, heartbeat_rx) = mpsc::channel(1);

Self {
id: sid,
id: Sid::new(),
protocol: ProtocolVersion::V4,
transport: AtomicU8::new(TransportType::Websocket as u8),

Expand Down
23 changes: 5 additions & 18 deletions socketioxide/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,13 @@ impl<A: Adapter> Client<A> {
auth: Option<String>,
ns_path: String,
esocket: &Arc<engineioxide::Socket<SocketData>>,
) -> Result<(), serde_json::Error> {
) -> Result<(), Error> {
debug!("auth: {:?}", auth);
let sid = esocket.id;
if let Some(ns) = self.get_ns(&ns_path) {
let protocol: ProtocolVersion = esocket.protocol.into();
ns.connect(sid, esocket.clone(), auth, self.config.clone())?;

// cancel the connect timeout task for v5
#[cfg(feature = "v5")]
if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() {
tx.send(()).unwrap();
}

let connect_packet = Packet::connect(ns_path, sid, protocol);
if let Err(err) = esocket.emit(connect_packet.try_into()?) {
debug!("sending error during socket connection: {err:?}");
}
ns.connect(sid, esocket.clone(), auth, self.config.clone())?;
if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() {
tx.send(()).unwrap();
}
Expand Down Expand Up @@ -108,12 +98,9 @@ impl<A: Adapter> Client<A> {

/// Propagate a packet to a its target namespace
fn sock_propagate_packet(&self, packet: Packet, sid: Sid) -> Result<(), Error> {
if let Some(ns) = self.get_ns(&packet.ns) {
ns.recv(sid, packet.inner)
} else {
debug!("invalid namespace requested: {}", packet.ns);
Ok(())
}
self.get_ns(&packet.ns)
.ok_or(Error::InvalidNamespace(packet.ns))?
.recv(sid, packet.inner)
}

/// Spawn a task that will close the socket if it is not connected to a namespace
Expand Down
7 changes: 5 additions & 2 deletions socketioxide/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ pub enum Error {
InvalidEventName,

#[error("invalid namespace")]
InvalidNamespace,
InvalidNamespace(String),

#[error("cannot find socketio socket")]
SocketGone(Sid),

#[error("send error: {0}")]
SendError(#[from] SendError),

/// An engineio error
#[error("engineio error: {0}")]
EngineIoError(#[from] engineioxide::errors::Error),
Expand All @@ -44,7 +47,7 @@ impl From<&Error> for Option<EIoDisconnectReason> {
Error::SerializeError(_) | Error::InvalidPacketType | Error::InvalidEventName => {
Some(PacketParsingError)
}
Error::Adapter(_) | Error::InvalidNamespace => None,
Error::Adapter(_) | Error::InvalidNamespace(_) | Error::SendError(_) => None,
}
}
}
Expand Down
15 changes: 10 additions & 5 deletions socketioxide/src/ns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use crate::{
adapter::Adapter,
errors::Error,
handler::{BoxedNamespaceHandler, CallbackHandler},
packet::PacketData,
packet::{Packet, PacketData},
socket::Socket,
SocketIoConfig,
ProtocolVersion, SocketIoConfig,
};
use crate::{client::SocketData, errors::AdapterError};
use engineioxide::sid::Sid;
Expand Down Expand Up @@ -46,10 +46,15 @@ impl<A: Adapter> Namespace<A> {
esocket: Arc<engineioxide::Socket<SocketData>>,
auth: Option<String>,
config: Arc<SocketIoConfig>,
) -> Result<(), serde_json::Error> {
let socket: Arc<Socket<A>> = Socket::new(sid, self.clone(), esocket, config).into();
) -> Result<(), Error> {
let protocol: ProtocolVersion = esocket.protocol.into();
let socket: Arc<Socket<A>> = Socket::new(self.clone(), esocket, config).into();
self.sockets.write().unwrap().insert(sid, socket.clone());
self.handler.call(socket, auth)

socket.send(Packet::connect(self.path.clone(), socket.id, protocol))?;

self.handler.call(socket, auth)?;
Ok(())
}

/// Remove a socket from a namespace and propagate the event to the adapter
Expand Down
25 changes: 17 additions & 8 deletions socketioxide/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
time::Duration,
};

use engineioxide::{sid::Sid, socket::DisconnectReason as EIoDisconnectReason};
use engineioxide::{sid::Sid, socket::DisconnectReason as EIoDisconnectReason, ProtocolVersion};
use futures::{future::BoxFuture, Future};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
Expand Down Expand Up @@ -110,18 +110,22 @@ pub struct Socket<A: Adapter> {

impl<A: Adapter> Socket<A> {
pub(crate) fn new(
sid: Sid,
ns: Arc<Namespace<A>>,
esocket: Arc<engineioxide::Socket<SocketData>>,
config: Arc<SocketIoConfig>,
) -> Self {
let id = if esocket.protocol == ProtocolVersion::V3 {
esocket.id
} else {
Sid::new()
};
Self {
ns,
message_handlers: RwLock::new(HashMap::new()),
disconnect_handler: Mutex::new(None),
ack_message: Mutex::new(HashMap::new()),
ack_counter: AtomicI64::new(0),
id: sid,
id,
extensions: Extensions::new(),
config,
esocket,
Expand Down Expand Up @@ -583,11 +587,16 @@ impl<A: Adapter> Debug for Socket<A> {
impl<A: Adapter> Socket<A> {
pub fn new_dummy(sid: Sid, ns: Arc<Namespace<A>>) -> Socket<A> {
let close_fn = Box::new(move |_, _| ());
Socket::new(
sid,
Socket {
id: sid,
ns,
engineioxide::Socket::new_dummy(sid, close_fn).into(),
Arc::new(SocketIoConfig::default()),
)
ack_counter: AtomicI64::new(0),
ack_message: Mutex::new(HashMap::new()),
message_handlers: RwLock::new(HashMap::new()),
disconnect_handler: Mutex::new(None),
config: Arc::new(SocketIoConfig::default()),
extensions: Extensions::new(),
esocket: engineioxide::Socket::new_dummy(close_fn).into(),
}
}
}

0 comments on commit 0f9e0c1

Please sign in to comment.