Skip to content

Commit

Permalink
WT WASM receiving loop
Browse files Browse the repository at this point in the history
  • Loading branch information
aecsocket committed Oct 13, 2023
1 parent 8c5ea67 commit e7522e2
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 35 deletions.
2 changes: 1 addition & 1 deletion aeronet_wt_native/src/client/front.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl<C2S, S2C> WebTransportClient<C2S, S2C> {
/// Requests the client to connect to a given URL.
///
/// If the client is [connected], this request has no effect.
///
///
/// [connected]: aeronet::ClientTransport::connected
pub fn connect(&self, url: impl Into<String>) {
let _ = self.send.try_send(Request::Connect { url: url.into() });
Expand Down
2 changes: 1 addition & 1 deletion aeronet_wt_native/src/server/back.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use aeronet::{ClientId, Message, SessionError, TryFromBytes, TryIntoBytes};
use tokio::sync::{broadcast, mpsc};
use tracing::{debug, debug_span, Instrument};
use wtransport::{
endpoint::{IncomingSession, endpoint_side::Server},
endpoint::{endpoint_side::Server, IncomingSession},
Connection, Endpoint, ServerConfig,
};

Expand Down
2 changes: 1 addition & 1 deletion aeronet_wt_native/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ where
/// Allows converting a [`Message`] into a [`StreamMessage`].
///
/// This is automatically implemented for all types.
///
///
/// [`Message`]: aeronet::Message
pub trait OnStream<S>: Sized {
/// Converts this into a [`StreamMessage`] by providing the stream along which the
Expand Down
2 changes: 2 additions & 0 deletions aeronet_wt_wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ features = [
"Blob",
"Url",
"Worker",
"ReadableStreamDefaultReader",
"MessageEvent",
]

[dev-dependencies]
Expand Down
147 changes: 127 additions & 20 deletions aeronet_wt_wasm/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,61 +1,168 @@
use js_sys::Array;
use wasm_bindgen::JsValue;
use std::{marker::PhantomData, sync::mpsc};

use aeronet::{ClientEvent, ClientTransport, Message, RecvError, SessionError, TryFromBytes};
use js_sys::{Array, Reflect, Uint8Array};
use wasm_bindgen::{prelude::Closure, JsCast, JsValue};
use wasm_bindgen_futures::JsFuture;
use web_sys::{Blob, Url, Worker};
use web_sys::{Blob, MessageEvent, ReadableStreamDefaultReader, Url, Worker};

use crate::{bindings::WebTransport, WebTransportErrorUnused, WebTransportOptions};

use crate::{bindings::WebTransport, WebTransportError, WebTransportOptions};
/*
implementation notes:
* use our WebTransport bindings for wasm
* spawn a web worker which:
* receives a WebTransportDatagramDuplexStream.readable from rust
* gets the readable.getReader() and saves it into `reader`
* in an infinite loop, reads from `reader`
* posts the reader data back to rust
* this is because we can't use getReader().read() from rust yet
* on the rust side:
* uhhhh idk?
*/

pub struct WebTransportClient {
pub struct WebTransportClient<C2S, S2C> {
transport: WebTransport,
worker: Worker,
recv_s2c: mpsc::Receiver<ClientEvent<S2C>>,
_phantom_c2s: PhantomData<C2S>,
}

#[derive(Debug, Clone, thiserror::Error)]
pub enum WebTransportClientError {
pub enum WebTransportError {
#[error("failed to create transport")]
CreateTransport(#[source] WebTransportError),
CreateTransport,
#[error("failed to create worker")]
CreateWorker,
}

impl WebTransportClient {
pub async fn new(
impl<C2S, S2C> WebTransportClient<C2S, S2C>
where
C2S: Message,
S2C: Message + TryFromBytes,
{
pub fn new(
url: impl AsRef<str>,
options: WebTransportOptions,
) -> Result<Self, WebTransportClientError> {
) -> Result<Self, WebTransportError> {
let url = url.as_ref();
let worker = create_worker().map_err(|_| WebTransportClientError::CreateWorker)?;
let transport = create_transport(url, options)
.map_err(|err| WebTransportClientError::CreateTransport(err))?;
let _ = JsFuture::from(transport.ready()).await;
let transport =
create_transport(url, options).map_err(|_| WebTransportError::CreateTransport)?;
let worker = create_worker().map_err(|_| WebTransportError::CreateWorker)?;

let (send_s2c, recv_s2c) = mpsc::channel::<ClientEvent<S2C>>();
let reader = ReadableStreamDefaultReader::from(JsValue::from(
transport.datagrams().readable().get_reader(),
));

wasm_bindgen_futures::spawn_local(async move {
let _ = match Self::recv_from_reader(reader).await {
Ok(msg) => send_s2c.send(ClientEvent::Recv { msg }),
Err(reason) => send_s2c.send(ClientEvent::Disconnected { reason }),
};
});

Ok(Self {
transport,
worker,
recv_s2c,
_phantom_c2s: PhantomData::default(),
})
}

async fn recv_from_reader(
reader: ReadableStreamDefaultReader,
) -> Result<S2C, SessionError> {
let (payload, done) = JsFuture::from(reader.read())
.await
.and_then(|js| {
let value =
Uint8Array::from(Reflect::get(&js, &JsValue::from("value")).unwrap());
let done = Reflect::get(&js, &JsValue::from("done"))
.unwrap()
.as_bool()
.unwrap();
Ok((value, done))
})
.unwrap();
if done {
// todo turn this into its own error type
return Err(SessionError::Transport(anyhow::anyhow!("closed")));
}

Ok(Self { transport, worker })
let payload = payload.to_vec();
let msg = S2C::try_from_bytes(payload.as_slice())
.map_err(|err| SessionError::Transport(err.into()))?;
Ok(msg)
}
}

impl Drop for WebTransportClient {
impl<C2S, S2C> Drop for WebTransportClient<C2S, S2C> {
fn drop(&mut self) {
self.transport.close();
self.worker.terminate();
}
}

impl<C2S, S2C> ClientTransport<C2S, S2C> for WebTransportClient<C2S, S2C>
where
C2S: Message,
S2C: Message,
{
type Info = ();

fn recv(&mut self) -> Result<ClientEvent<S2C>, RecvError> {
self.recv_s2c.try_recv().map_err(|err| match err {
mpsc::TryRecvError::Empty => RecvError::Empty,
_ => RecvError::Closed,
})
}

fn send(&mut self, msg: impl Into<C2S>) {
todo!()
}

fn info(&self) -> Option<Self::Info> {
todo!()
}

fn connected(&self) -> bool {
todo!()
}
}

fn create_transport(
url: &str,
options: WebTransportOptions,
) -> Result<WebTransport, WebTransportError> {
) -> Result<WebTransport, WebTransportErrorUnused> {
let options = options.as_js();
WebTransport::new_with_options(url, &options).map_err(|js| WebTransportError::from_js(js))
WebTransport::new_with_options(url, &options).map_err(|js| WebTransportErrorUnused::from_js(js))
}

const WORKER_SCRIPT: &str = "
function wait(ms) {
let reader = null;
function sleep(ms) {
return new Promise(res => setTimeout(res, ms));
}
self.onmessage = function(event) {
if (event.data) {
reader = event.data.getReader();
}
};
async function read() {
while (true) {
await wait(100);
if (reader) {
const { value, done } = await reader.read();
if (done) {
break;
}
self.postMessage(value);
} else {
await sleep(100);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions aeronet_wt_wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ mod wrappers;

pub use client::WebTransportClient;
pub use wrappers::{
CongestionControl, ServerCertificateHash, ServerCertificateHashAlgorithm, WebTransportError,
WebTransportErrorSource, WebTransportOptions,
CongestionControl, ServerCertificateHash, ServerCertificateHashAlgorithm,
WebTransportErrorSource, WebTransportErrorUnused, WebTransportOptions,
};
20 changes: 10 additions & 10 deletions aeronet_wt_wasm/src/wrappers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// https://developer.mozilla.org/en-US/docs/Web/API/WebTransport/WebTransport

use js_sys::{Array, Object, Reflect, Uint8Array};
use wasm_bindgen::JsValue;

type JsWebTransportOptions = crate::bindings::WebTransportOptions;
Expand Down Expand Up @@ -54,16 +55,16 @@ pub struct ServerCertificateHash {
}

impl ServerCertificateHash {
pub(crate) fn as_js(&self) -> js_sys::Object {
let res = js_sys::Object::new();
pub(crate) fn as_js(&self) -> Object {
let res = Object::new();

let algorithm = match self.algorithm {
ServerCertificateHashAlgorithm::Sha256 => "sha-256",
};
let _ = js_sys::Reflect::set(&res, &JsValue::from("algorithm"), &JsValue::from(algorithm));
let _ = Reflect::set(&res, &JsValue::from("algorithm"), &JsValue::from(algorithm));

let value = js_sys::Uint8Array::from(self.value.as_slice());
let _ = js_sys::Reflect::set(&res, &JsValue::from("value"), &value);
let value = Uint8Array::from(self.value.as_slice());
let _ = Reflect::set(&res, &JsValue::from("value"), &value);

res
}
Expand Down Expand Up @@ -114,9 +115,8 @@ impl WebTransportOptions {

res.require_unreliable(self.require_unreliable);

let hashes = js_sys::Array::new_with_length(
self.server_certificate_hashes.len().try_into().unwrap(),
);
let hashes =
Array::new_with_length(self.server_certificate_hashes.len().try_into().unwrap());
for (i, cert) in self.server_certificate_hashes.iter().enumerate() {
hashes.set(i.try_into().unwrap(), cert.as_js().into());
}
Expand All @@ -141,7 +141,7 @@ pub enum WebTransportErrorSource {
/// problems, or client-initiated abort operations.
#[derive(Debug, Clone, thiserror::Error)]
#[error("web transport {source} error (code {stream_error_code:?})")]
pub struct WebTransportError {
pub struct WebTransportErrorUnused {
/// Description of this error.
pub message: String,
/// Source of the error.
Expand All @@ -150,7 +150,7 @@ pub struct WebTransportError {
pub stream_error_code: Option<u8>,
}

impl WebTransportError {
impl WebTransportErrorUnused {
pub(crate) fn from_js(js: JsValue) -> Self {
let js = JsWebTransportError::from(js);

Expand Down

0 comments on commit e7522e2

Please sign in to comment.