diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 1766b65d..d4e405fd 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -1,6 +1,8 @@ #![deny(clippy::await_holding_lock)] + use std::{ - collections::HashMap, + collections::{hash_map::Entry, HashMap}, + fmt::Debug, io::BufRead, sync::{Arc, RwLock}, }; @@ -21,7 +23,6 @@ use bytes::Buf; use futures::{stream::SplitStream, SinkExt, StreamExt, TryStreamExt}; use http::{Request, Response, StatusCode}; use hyper::upgrade::Upgraded; -use std::fmt::Debug; use tokio_tungstenite::{ tungstenite::{protocol::Role, Message}, WebSocketStream, @@ -71,24 +72,31 @@ where { let engine = self.clone(); let close_fn = Box::new(move |sid: Sid| engine.close_session(sid)); - let sid = generate_sid(); - let socket = Socket::new( - sid, - ConnectionType::Http, - &self.config, - SocketReq::from(req.into_parts().0), - close_fn, - ); - let socket = Arc::new(socket); - { - self.sockets.write().unwrap().insert(sid, socket.clone()); - } + + let mut lock = self.sockets.write().unwrap(); + let socket = loop { + let sid = generate_sid(); + if let Entry::Vacant(entry) = lock.entry(sid) { + let socket = Socket::new( + sid, + ConnectionType::Http, + &self.config, + SocketReq::from(req.into_parts().0), + close_fn, + ); + let socket = Arc::new(socket); + entry.insert(socket.clone()); + break socket; + } + }; + drop(lock); + socket .clone() .spawn_heartbeat(self.config.ping_interval, self.config.ping_timeout); self.handler.on_connect(&socket); - let packet = OpenPacket::new(TransportType::Polling, sid, &self.config); + let packet = OpenPacket::new(TransportType::Polling, socket.sid, &self.config); let packet: String = Packet::Open(packet).try_into()?; http_response(StatusCode::OK, packet).map_err(Error::Http) } @@ -255,23 +263,31 @@ where } } } else { - let sid = generate_sid(); let engine = self.clone(); let close_fn = Box::new(move |sid: Sid| engine.close_session(sid)); - let socket = Socket::new( - sid, - ConnectionType::WebSocket, - &self.config, - req_data, - close_fn, - ); - let socket = Arc::new(socket); - { - self.sockets.write().unwrap().insert(sid, socket.clone()); - } - debug!("[sid={sid}] new websocket connection"); + + let socket = { + let mut lock = self.sockets.write().unwrap(); + let socket = loop { + let sid = generate_sid(); + if let Entry::Vacant(entry) = lock.entry(sid) { + let socket = Socket::new( + sid, + ConnectionType::WebSocket, + &self.config, + req_data, + close_fn, + ); + let socket = Arc::new(socket); + entry.insert(socket.clone()); + break socket; + } + }; + socket + }; + debug!("[sid={}] new websocket connection", socket.sid); let mut ws = ws_init().await; - self.ws_init_handshake(sid, &mut ws).await?; + self.ws_init_handshake(socket.sid, &mut ws).await?; socket .clone() .spawn_heartbeat(self.config.ping_interval, self.config.ping_timeout);