diff --git a/core/src/types/params.rs b/core/src/types/params.rs index cbd9b819b..6a1527289 100644 --- a/core/src/types/params.rs +++ b/core/src/types/params.rs @@ -10,7 +10,7 @@ use serde_json::value::from_value; use super::{Value, Error}; /// Request parameters -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub enum Params { /// Array of values Array(Vec), diff --git a/core/src/types/response.rs b/core/src/types/response.rs index 8a15a5ffe..d243ca36e 100644 --- a/core/src/types/response.rs +++ b/core/src/types/response.rs @@ -5,7 +5,7 @@ use serde_json::value::from_value; use super::{Id, Value, Error, ErrorCode, Version}; /// Successful response -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct Success { /// Protocol version #[serde(skip_serializing_if = "Option::is_none")] @@ -17,7 +17,7 @@ pub struct Success { } /// Unsuccessful response -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct Failure { /// Protocol Version #[serde(skip_serializing_if = "Option::is_none")] @@ -29,7 +29,7 @@ pub struct Failure { } /// Represents output - failure or success -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub enum Output { /// Success Success(Success), diff --git a/macros/examples/pubsub-macros.rs b/macros/examples/pubsub-macros.rs index 4de1e841a..4fcedccef 100644 --- a/macros/examples/pubsub-macros.rs +++ b/macros/examples/pubsub-macros.rs @@ -99,7 +99,7 @@ fn main() { { let subscribers = active_subscriptions.read().unwrap(); for sink in subscribers.values() { - let _ = sink.send("Hello World!".into()).wait(); + let _ = sink.notify("Hello World!".into()).wait(); } } thread::sleep(::std::time::Duration::from_secs(1)); diff --git a/macros/src/pubsub.rs b/macros/src/pubsub.rs index c4c1e62b2..663d5e3b2 100644 --- a/macros/src/pubsub.rs +++ b/macros/src/pubsub.rs @@ -5,6 +5,8 @@ use jsonrpc_pubsub as pubsub; use serde; use util::to_value; +use self::core::futures::{self, Sink as FuturesSink}; + pub use self::pubsub::SubscriptionId; pub struct Subscriber { @@ -29,6 +31,7 @@ impl Subscriber { Ok(Sink { id: id, sink: sink, + buffered: None, _data: PhantomData, }) } @@ -37,16 +40,67 @@ impl Subscriber { pub struct Sink { sink: pubsub::Sink, id: SubscriptionId, + buffered: Option, _data: PhantomData, } impl Sink { - pub fn send(&self, val: T) -> pubsub::SinkResult { + pub fn notify(&self, val: T) -> pubsub::SinkResult { + self.sink.notify(self.val_to_params(val)) + } + + fn val_to_params(&self, val: T) -> core::Params { let id = self.id.clone().into(); let val = to_value(val); - self.sink.send(core::Params::Map(vec![ + + core::Params::Map(vec![ ("subscription".to_owned(), id), ("result".to_owned(), val), - ].into_iter().collect())) + ].into_iter().collect()) + } + + fn poll(&mut self) -> futures::Poll<(), pubsub::TransportError> { + if let Some(item) = self.buffered.take() { + let result = self.sink.start_send(item)?; + if let futures::AsyncSink::NotReady(item) = result { + self.buffered = Some(item); + } + } + + if self.buffered.is_some() { + Ok(futures::Async::NotReady) + } else { + Ok(futures::Async::Ready(())) + } + } +} + +impl futures::sink::Sink for Sink { + type SinkItem = T; + type SinkError = pubsub::TransportError; + + fn start_send(&mut self, item: Self::SinkItem) -> futures::StartSend { + // Make sure to always try to process the buffered entry. + // Since we're just a proxy to real `Sink` we don't need + // to schedule a `Task` wakeup. It will be done downstream. + if self.poll()?.is_not_ready() { + return Ok(futures::AsyncSink::NotReady(item)); + } + + let val = self.val_to_params(item); + self.buffered = Some(val); + self.poll()?; + + Ok(futures::AsyncSink::Ready) + } + + fn poll_complete(&mut self) -> futures::Poll<(), Self::SinkError> { + self.poll()?; + self.sink.poll_complete() + } + + fn close(&mut self) -> futures::Poll<(), Self::SinkError> { + self.poll()?; + self.sink.close() } } diff --git a/pubsub/examples/pubsub.rs b/pubsub/examples/pubsub.rs index 6084992dc..1353be8b4 100644 --- a/pubsub/examples/pubsub.rs +++ b/pubsub/examples/pubsub.rs @@ -55,7 +55,7 @@ fn main() { thread::spawn(move || { loop { thread::sleep(time::Duration::from_millis(100)); - match sink.send(Params::Array(vec![Value::Number(10.into())])).wait() { + match sink.notify(Params::Array(vec![Value::Number(10.into())])).wait() { Ok(_) => {}, Err(_) => { println!("Subscription has ended, finishing."); diff --git a/pubsub/src/lib.rs b/pubsub/src/lib.rs index 2a01b84ce..ab9f297fb 100644 --- a/pubsub/src/lib.rs +++ b/pubsub/src/lib.rs @@ -15,7 +15,4 @@ mod types; pub use self::handler::{PubSubHandler, SubscribeRpcMethod, UnsubscribeRpcMethod}; pub use self::subscription::{Session, Sink, Subscriber, new_subscription}; -pub use self::types::{PubSubMetadata, SubscriptionId}; - -/// Subscription send result. -pub type SinkResult = core::futures::sink::Send; +pub use self::types::{PubSubMetadata, SubscriptionId, TransportError, SinkResult}; diff --git a/pubsub/src/subscription.rs b/pubsub/src/subscription.rs index 48e7d3480..1fd7f7069 100644 --- a/pubsub/src/subscription.rs +++ b/pubsub/src/subscription.rs @@ -6,11 +6,11 @@ use std::sync::Arc; use parking_lot::Mutex; use core; -use core::futures::{self, sink, future, Sink as FuturesSink, Future, BoxFuture}; +use core::futures::{self, future, Sink as FuturesSink, Future, BoxFuture}; use core::futures::sync::oneshot; use handler::{SubscribeRpcMethod, UnsubscribeRpcMethod}; -use types::{PubSubMetadata, SubscriptionId, TransportSender}; +use types::{PubSubMetadata, SubscriptionId, TransportSender, TransportError, SinkResult}; /// RPC client session /// Keeps track of active subscriptions and unsubscribes from them upon dropping. @@ -80,6 +80,7 @@ impl Drop for Session { } /// A handle to send notifications directly to subscribed client. +#[derive(Debug, Clone)] pub struct Sink { notification: String, transport: TransportSender @@ -87,14 +88,42 @@ pub struct Sink { impl Sink { /// Sends a notification to a client. - pub fn send(&self, val: core::Params) -> sink::Send { + pub fn notify(&self, val: core::Params) -> SinkResult { + let val = self.params_to_string(val); + self.transport.clone().send(val.0) + } + + fn params_to_string(&self, val: core::Params) -> (String, core::Params) { let notification = core::Notification { jsonrpc: Some(core::Version::V2), method: self.notification.clone(), params: Some(val), }; + ( + core::to_string(¬ification).expect("Notification serialization never fails."), + notification.params.expect("Always Some"), + ) + } +} + +impl FuturesSink for Sink { + type SinkItem = core::Params; + type SinkError = TransportError; + + fn start_send(&mut self, item: Self::SinkItem) -> futures::StartSend { + let (val, params) = self.params_to_string(item); + self.transport.start_send(val).map(|result| match result { + futures::AsyncSink::Ready => futures::AsyncSink::Ready, + futures::AsyncSink::NotReady(_) => futures::AsyncSink::NotReady(params), + }) + } + + fn poll_complete(&mut self) -> futures::Poll<(), Self::SinkError> { + self.transport.poll_complete() + } - self.transport.clone().send(core::to_string(¬ification).expect("Notification serialization never fails.")) + fn close(&mut self) -> futures::Poll<(), Self::SinkError> { + self.transport.close() } } @@ -324,7 +353,7 @@ mod tests { }; // when - sink.send(core::Params::Array(vec![core::Value::Number(10.into())])).wait().unwrap(); + sink.notify(core::Params::Array(vec![core::Value::Number(10.into())])).wait().unwrap(); // then assert_eq!( diff --git a/pubsub/src/types.rs b/pubsub/src/types.rs index 9c301b6ba..67b5d6df7 100644 --- a/pubsub/src/types.rs +++ b/pubsub/src/types.rs @@ -6,6 +6,10 @@ use subscription::Session; /// Raw transport sink for specific client. pub type TransportSender = mpsc::Sender; +/// Raw transport error. +pub type TransportError = mpsc::SendError; +/// Subscription send result. +pub type SinkResult = core::futures::sink::Send; /// Metadata extension for pub-sub method handling. pub trait PubSubMetadata: core::Metadata { diff --git a/ws/src/metadata.rs b/ws/src/metadata.rs index de41c654b..7cfdf2848 100644 --- a/ws/src/metadata.rs +++ b/ws/src/metadata.rs @@ -2,11 +2,14 @@ use core; use ws; use session; +use Origin; /// Request context pub struct RequestContext { /// Session id pub session_id: session::SessionId, + /// Request Origin + pub origin: Option, /// Direct channel to send messages to a client. pub out: ws::Sender, } diff --git a/ws/src/session.rs b/ws/src/session.rs index 7b91b1d9c..d75f5f544 100644 --- a/ws/src/session.rs +++ b/ws/src/session.rs @@ -95,8 +95,11 @@ impl> Drop for Session { } impl> Session { - fn verify_origin(&self, req: &ws::Request) -> Option { - let origin = req.header("origin").map(|x| &x[..]); + fn read_origin<'a>(&self, req: &'a ws::Request) -> Option<&'a [u8]> { + req.header("origin").map(|x| &x[..]) + } + + fn verify_origin(&self, origin: Option<&[u8]>) -> Option { if !header_is_allowed(&self.allowed_origins, origin) { warn!(target: "signer", "Blocked connection to Signer API from untrusted origin: {:?}", origin); Some(forbidden( @@ -131,9 +134,10 @@ impl> ws::Handler for Session { MiddlewareAction::Proceed }; + let origin = self.read_origin(req); if action.should_verify_origin() { // Verify request origin. - if let Some(response) = self.verify_origin(req) { + if let Some(response) = self.verify_origin(origin) { return Ok(response); } } @@ -145,6 +149,7 @@ impl> ws::Handler for Session { } } + self.context.origin = origin.and_then(|origin| ::std::str::from_utf8(origin).ok()).map(Into::into); self.metadata = self.meta_extractor.extract(&self.context); match action { @@ -217,6 +222,7 @@ impl> ws::Factory for Factory { Session { context: metadata::RequestContext { session_id: self.session_id, + origin: None, out: sender, }, handler: self.handler.clone(),