diff --git a/engineioxide/src/str.rs b/engineioxide/src/str.rs index 8655f44d..beeea1ed 100644 --- a/engineioxide/src/str.rs +++ b/engineioxide/src/str.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use bytes::Bytes; /// A custom [`Bytes`] wrapper to efficiently store string packets @@ -46,6 +48,23 @@ impl From for Str { } } +impl From> for Str { + fn from(s: Cow<'static, str>) -> Self { + match s { + Cow::Borrowed(s) => Str::from(s), + Cow::Owned(s) => Str::from(s), + } + } +} +impl From<&Cow<'static, str>> for Str { + fn from(s: &Cow<'static, str>) -> Self { + match s { + Cow::Borrowed(s) => Str::from(*s), + Cow::Owned(s) => Str(Bytes::copy_from_slice(s.as_bytes())), + } + } +} + impl From for Bytes { fn from(s: Str) -> Self { s.0 diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index 97b82235..d95b6ac2 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -49,13 +49,13 @@ impl Client { fn sock_connect( &self, auth: Option, - ns_path: &str, + ns_path: Str, esocket: &Arc>>, ) -> Result<(), Error> { #[cfg(feature = "tracing")] tracing::debug!("auth: {:?}", auth); - if let Some(ns) = self.get_ns(ns_path) { + if let Some(ns) = self.get_ns(&ns_path) { let esocket = esocket.clone(); tokio::spawn(async move { if ns.connect(esocket.id, esocket.clone(), auth).await.is_ok() { @@ -246,7 +246,7 @@ impl EngineIoHandler for Client { if protocol == ProtocolVersion::V4 { #[cfg(feature = "tracing")] tracing::debug!("connecting to default namespace for v4"); - self.sock_connect(None, "/", &socket).unwrap(); + self.sock_connect(None, Str::from("/"), &socket).unwrap(); } if protocol == ProtocolVersion::V5 { @@ -299,7 +299,7 @@ impl EngineIoHandler for Client { let res: Result<(), Error> = match packet.inner { PacketData::Connect(auth) => self - .sock_connect(auth, &packet.ns, &socket) + .sock_connect(auth, packet.ns, &socket) .map_err(Into::into), PacketData::BinaryEvent(_, _, _) | PacketData::BinaryAck(_, _) => { // Cache-in the socket data until all the binary payloads are received diff --git a/socketioxide/src/extract/socket.rs b/socketioxide/src/extract/socket.rs index 1be53612..8dc89741 100644 --- a/socketioxide/src/extract/socket.rs +++ b/socketioxide/src/extract/socket.rs @@ -122,7 +122,7 @@ impl AckSender { return Err(e.with_value(data).into()); } }; - let ns = self.socket.ns(); + let ns = &self.socket.ns.path; let data = serde_json::to_value(data)?; let packet = if self.binary.is_empty() { Packet::ack(ns, data, ack_id) diff --git a/socketioxide/src/packet.rs b/socketioxide/src/packet.rs index 435b7184..f5e65507 100644 --- a/socketioxide/src/packet.rs +++ b/socketioxide/src/packet.rs @@ -18,70 +18,70 @@ pub struct Packet<'a> { /// The packet data pub inner: PacketData<'a>, /// The namespace the packet belongs to - pub ns: Cow<'a, str>, + pub ns: Str, } impl<'a> Packet<'a> { /// Send a connect packet with a default payload for v5 and no payload for v4 pub fn connect( - ns: &'a str, + ns: impl Into, #[allow(unused_variables)] sid: Sid, #[allow(unused_variables)] protocol: ProtocolVersion, ) -> Self { #[cfg(not(feature = "v4"))] { - Self::connect_v5(ns, sid) + Self::connect_v5(ns.into(), sid) } #[cfg(feature = "v4")] { match protocol { - ProtocolVersion::V4 => Self::connect_v4(ns), - ProtocolVersion::V5 => Self::connect_v5(ns, sid), + ProtocolVersion::V4 => Self::connect_v4(ns.into()), + ProtocolVersion::V5 => Self::connect_v5(ns.into(), sid), } } } /// Sends a connect packet without payload. #[cfg(feature = "v4")] - fn connect_v4(ns: &'a str) -> Self { + fn connect_v4(ns: Str) -> Self { Self { inner: PacketData::Connect(None), - ns: Cow::Borrowed(ns), + ns, } } /// Sends a connect packet with payload. - fn connect_v5(ns: &'a str, sid: Sid) -> Self { + fn connect_v5(ns: Str, sid: Sid) -> Self { let val = serde_json::to_string(&ConnectPacket { sid }).unwrap(); Self { inner: PacketData::Connect(Some(val)), - ns: Cow::Borrowed(ns), + ns, } } /// Create a disconnect packet for the given namespace - pub fn disconnect(ns: &'a str) -> Self { + pub fn disconnect(ns: impl Into) -> Self { Self { inner: PacketData::Disconnect, - ns: Cow::Borrowed(ns), + ns: ns.into(), } } } impl<'a> Packet<'a> { /// Create a connect error packet for the given namespace with a message - pub fn connect_error(ns: &'a str, message: &str) -> Self { + pub fn connect_error(ns: impl Into, message: &str) -> Self { let message = serde_json::to_string(message).unwrap(); let packet = format!(r#"{{"message":{}}}"#, message); Self { inner: PacketData::ConnectError(packet), - ns: Cow::Borrowed(ns), + ns: ns.into(), } } /// Create an event packet for the given namespace - pub fn event(ns: impl Into>, e: impl Into>, data: Value) -> Self { + pub fn event(ns: impl Into, e: impl Into>, data: Value) -> Self { Self { inner: PacketData::Event(e.into(), data, None), ns: ns.into(), @@ -90,7 +90,7 @@ impl<'a> Packet<'a> { /// Create a binary event packet for the given namespace pub fn bin_event( - ns: impl Into>, + ns: impl Into, e: impl Into>, data: Value, bin: Vec, @@ -105,20 +105,20 @@ impl<'a> Packet<'a> { } /// Create an ack packet for the given namespace - pub fn ack(ns: &'a str, data: Value, ack: i64) -> Self { + pub fn ack(ns: impl Into, data: Value, ack: i64) -> Self { Self { inner: PacketData::EventAck(data, ack), - ns: Cow::Borrowed(ns), + ns: ns.into(), } } /// Create a binary ack packet for the given namespace - pub fn bin_ack(ns: &'a str, data: Value, bin: Vec, ack: i64) -> Self { + pub fn bin_ack(ns: impl Into, data: Value, bin: Vec, ack: i64) -> Self { debug_assert!(!bin.is_empty()); let packet = BinaryPacket::outgoing(data, bin); Self { inner: PacketData::BinaryAck(packet, ack), - ns: Cow::Borrowed(ns), + ns: ns.into(), } } @@ -466,19 +466,19 @@ impl<'a> TryFrom for Packet<'a> { match chars.get(i) { Some(b',') => { i += 1; - break Cow::Owned(value[start_index..i - 1].to_string()); + break value.slice(start_index..i - 1); } // It maybe possible depending on clients that ns does not end with a comma // if it is the end of the packet // e.g `1/custom` None => { - break Cow::Owned(value[start_index..i].to_string()); + break value.slice(start_index..i); } Some(_) => i += 1, } } } else { - Cow::Borrowed("/") + Str::from("/") }; let start_index = i; diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index 35ba2d99..7743ac01 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -319,7 +319,7 @@ impl Socket { } }; - let ns = self.ns(); + let ns = &self.ns.path; let data = serde_json::to_value(data)?; permit.send(Packet::event(ns, event.into(), data)); Ok(()) @@ -392,8 +392,9 @@ impl Socket { return Err(e.with_value(data).into()); } }; + let ns = &self.ns.path; let data = serde_json::to_value(data)?; - let packet = Packet::event(self.ns(), event.into(), data); + let packet = Packet::event(ns, event.into(), data); let rx = self.send_with_ack_permit(packet, permit); let stream = AckInnerStream::send(rx, self.get_io().config().ack_timeout, self.id); Ok(AckStream::::from(stream)) diff --git a/socketioxide/tests/connect.rs b/socketioxide/tests/connect.rs index ba77d21b..b3978c37 100644 --- a/socketioxide/tests/connect.rs +++ b/socketioxide/tests/connect.rs @@ -7,7 +7,11 @@ use socketioxide::{ }; use tokio::sync::mpsc; -fn create_msg(ns: &str, event: &str, data: impl Into) -> engineioxide::Packet { +fn create_msg( + ns: &'static str, + event: &str, + data: impl Into, +) -> engineioxide::Packet { let packet: String = Packet::event(ns, event, data.into()).into(); Message(packet.into()) } diff --git a/socketioxide/tests/extractors.rs b/socketioxide/tests/extractors.rs index e9b38116..296d5fdd 100644 --- a/socketioxide/tests/extractors.rs +++ b/socketioxide/tests/extractors.rs @@ -25,7 +25,7 @@ async fn timeout_rcv_err(srx: &mut tokio::sync::mpsc::Receiv .unwrap_err(); } -fn create_msg(ns: &str, event: &str, data: impl Into) -> EioPacket { +fn create_msg(ns: &'static str, event: &str, data: impl Into) -> EioPacket { let packet: String = Packet::event(ns, event, data.into()).into(); EioPacket::Message(packet.into()) }