Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional improvements to integrate pubsub with Parity #110

Merged
merged 4 commits into from
Apr 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/types/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use serde_json::value::from_value;
use super::{Value, Error};

/// Request parameters
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Clone)]
pub enum Params {
/// Array of values
Array(Vec<Value>),
Expand Down
6 changes: 3 additions & 3 deletions core/src/types/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde_json::value::from_value;
use super::{Id, Value, Error, ErrorCode, Version};

/// Successful response
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct Success {
/// Protocol version
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -17,7 +17,7 @@ pub struct Success {
}

/// Unsuccessful response
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct Failure {
/// Protocol Version
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -29,7 +29,7 @@ pub struct Failure {
}

/// Represents output - failure or success
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Clone)]
pub enum Output {
/// Success
Success(Success),
Expand Down
2 changes: 1 addition & 1 deletion macros/examples/pubsub-macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ fn main() {
{
let subscribers = active_subscriptions.read().unwrap();
for sink in subscribers.values() {
let _ = sink.send("Hello World!".into()).wait();
let _ = sink.notify("Hello World!".into()).wait();
}
}
thread::sleep(::std::time::Duration::from_secs(1));
Expand Down
60 changes: 57 additions & 3 deletions macros/src/pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use jsonrpc_pubsub as pubsub;
use serde;
use util::to_value;

use self::core::futures::{self, Sink as FuturesSink};

pub use self::pubsub::SubscriptionId;

pub struct Subscriber<T> {
Expand All @@ -29,6 +31,7 @@ impl<T> Subscriber<T> {
Ok(Sink {
id: id,
sink: sink,
buffered: None,
_data: PhantomData,
})
}
Expand All @@ -37,16 +40,67 @@ impl<T> Subscriber<T> {
pub struct Sink<T> {
sink: pubsub::Sink,
id: SubscriptionId,
buffered: Option<core::Params>,
_data: PhantomData<T>,
}

impl<T: serde::Serialize> Sink<T> {
pub fn send(&self, val: T) -> pubsub::SinkResult {
pub fn notify(&self, val: T) -> pubsub::SinkResult {
self.sink.notify(self.val_to_params(val))
}

fn val_to_params(&self, val: T) -> core::Params {
let id = self.id.clone().into();
let val = to_value(val);
self.sink.send(core::Params::Map(vec![

core::Params::Map(vec![
("subscription".to_owned(), id),
("result".to_owned(), val),
].into_iter().collect()))
].into_iter().collect())
}

fn poll(&mut self) -> futures::Poll<(), pubsub::TransportError> {
if let Some(item) = self.buffered.take() {
let result = self.sink.start_send(item)?;
if let futures::AsyncSink::NotReady(item) = result {
self.buffered = Some(item);
}
}

if self.buffered.is_some() {
Ok(futures::Async::NotReady)
} else {
Ok(futures::Async::Ready(()))
}
}
}

impl<T: serde::Serialize> futures::sink::Sink for Sink<T> {
type SinkItem = T;
type SinkError = pubsub::TransportError;

fn start_send(&mut self, item: Self::SinkItem) -> futures::StartSend<Self::SinkItem, Self::SinkError> {
// Make sure to always try to process the buffered entry.
// Since we're just a proxy to real `Sink` we don't need
// to schedule a `Task` wakeup. It will be done downstream.
if self.poll()?.is_not_ready() {
return Ok(futures::AsyncSink::NotReady(item));
}

let val = self.val_to_params(item);
self.buffered = Some(val);
self.poll()?;

Ok(futures::AsyncSink::Ready)
}

fn poll_complete(&mut self) -> futures::Poll<(), Self::SinkError> {
self.poll()?;
self.sink.poll_complete()
}

fn close(&mut self) -> futures::Poll<(), Self::SinkError> {
self.poll()?;
self.sink.close()
}
}
2 changes: 1 addition & 1 deletion pubsub/examples/pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn main() {
thread::spawn(move || {
loop {
thread::sleep(time::Duration::from_millis(100));
match sink.send(Params::Array(vec![Value::Number(10.into())])).wait() {
match sink.notify(Params::Array(vec![Value::Number(10.into())])).wait() {
Ok(_) => {},
Err(_) => {
println!("Subscription has ended, finishing.");
Expand Down
5 changes: 1 addition & 4 deletions pubsub/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,4 @@ mod types;

pub use self::handler::{PubSubHandler, SubscribeRpcMethod, UnsubscribeRpcMethod};
pub use self::subscription::{Session, Sink, Subscriber, new_subscription};
pub use self::types::{PubSubMetadata, SubscriptionId};

/// Subscription send result.
pub type SinkResult = core::futures::sink::Send<types::TransportSender>;
pub use self::types::{PubSubMetadata, SubscriptionId, TransportError, SinkResult};
39 changes: 34 additions & 5 deletions pubsub/src/subscription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ use std::sync::Arc;
use parking_lot::Mutex;

use core;
use core::futures::{self, sink, future, Sink as FuturesSink, Future, BoxFuture};
use core::futures::{self, future, Sink as FuturesSink, Future, BoxFuture};
use core::futures::sync::oneshot;

use handler::{SubscribeRpcMethod, UnsubscribeRpcMethod};
use types::{PubSubMetadata, SubscriptionId, TransportSender};
use types::{PubSubMetadata, SubscriptionId, TransportSender, TransportError, SinkResult};

/// RPC client session
/// Keeps track of active subscriptions and unsubscribes from them upon dropping.
Expand Down Expand Up @@ -80,21 +80,50 @@ impl Drop for Session {
}

/// A handle to send notifications directly to subscribed client.
#[derive(Debug, Clone)]
pub struct Sink {
notification: String,
transport: TransportSender
}

impl Sink {
/// Sends a notification to a client.
pub fn send(&self, val: core::Params) -> sink::Send<TransportSender> {
pub fn notify(&self, val: core::Params) -> SinkResult {
let val = self.params_to_string(val);
self.transport.clone().send(val.0)
}

fn params_to_string(&self, val: core::Params) -> (String, core::Params) {
let notification = core::Notification {
jsonrpc: Some(core::Version::V2),
method: self.notification.clone(),
params: Some(val),
};
(
core::to_string(&notification).expect("Notification serialization never fails."),
notification.params.expect("Always Some"),
)
}
}

impl FuturesSink for Sink {
type SinkItem = core::Params;
type SinkError = TransportError;

fn start_send(&mut self, item: Self::SinkItem) -> futures::StartSend<Self::SinkItem, Self::SinkError> {
let (val, params) = self.params_to_string(item);
self.transport.start_send(val).map(|result| match result {
futures::AsyncSink::Ready => futures::AsyncSink::Ready,
futures::AsyncSink::NotReady(_) => futures::AsyncSink::NotReady(params),
})
}

fn poll_complete(&mut self) -> futures::Poll<(), Self::SinkError> {
self.transport.poll_complete()
}

self.transport.clone().send(core::to_string(&notification).expect("Notification serialization never fails."))
fn close(&mut self) -> futures::Poll<(), Self::SinkError> {
self.transport.close()
}
}

Expand Down Expand Up @@ -324,7 +353,7 @@ mod tests {
};

// when
sink.send(core::Params::Array(vec![core::Value::Number(10.into())])).wait().unwrap();
sink.notify(core::Params::Array(vec![core::Value::Number(10.into())])).wait().unwrap();

// then
assert_eq!(
Expand Down
4 changes: 4 additions & 0 deletions pubsub/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ use subscription::Session;

/// Raw transport sink for specific client.
pub type TransportSender = mpsc::Sender<String>;
/// Raw transport error.
pub type TransportError = mpsc::SendError<String>;
/// Subscription send result.
pub type SinkResult = core::futures::sink::Send<TransportSender>;

/// Metadata extension for pub-sub method handling.
pub trait PubSubMetadata: core::Metadata {
Expand Down
3 changes: 3 additions & 0 deletions ws/src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ use core;
use ws;

use session;
use Origin;

/// Request context
pub struct RequestContext {
/// Session id
pub session_id: session::SessionId,
/// Request Origin
pub origin: Option<Origin>,
/// Direct channel to send messages to a client.
pub out: ws::Sender,
}
Expand Down
12 changes: 9 additions & 3 deletions ws/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,11 @@ impl<M: core::Metadata, S: core::Middleware<M>> Drop for Session<M, S> {
}

impl<M: core::Metadata, S: core::Middleware<M>> Session<M, S> {
fn verify_origin(&self, req: &ws::Request) -> Option<ws::Response> {
let origin = req.header("origin").map(|x| &x[..]);
fn read_origin<'a>(&self, req: &'a ws::Request) -> Option<&'a [u8]> {
req.header("origin").map(|x| &x[..])
}

fn verify_origin(&self, origin: Option<&[u8]>) -> Option<ws::Response> {
if !header_is_allowed(&self.allowed_origins, origin) {
warn!(target: "signer", "Blocked connection to Signer API from untrusted origin: {:?}", origin);
Some(forbidden(
Expand Down Expand Up @@ -131,9 +134,10 @@ impl<M: core::Metadata, S: core::Middleware<M>> ws::Handler for Session<M, S> {
MiddlewareAction::Proceed
};

let origin = self.read_origin(req);
if action.should_verify_origin() {
// Verify request origin.
if let Some(response) = self.verify_origin(req) {
if let Some(response) = self.verify_origin(origin) {
return Ok(response);
}
}
Expand All @@ -145,6 +149,7 @@ impl<M: core::Metadata, S: core::Middleware<M>> ws::Handler for Session<M, S> {
}
}

self.context.origin = origin.and_then(|origin| ::std::str::from_utf8(origin).ok()).map(Into::into);
self.metadata = self.meta_extractor.extract(&self.context);

match action {
Expand Down Expand Up @@ -217,6 +222,7 @@ impl<M: core::Metadata, S: core::Middleware<M>> ws::Factory for Factory<M, S> {
Session {
context: metadata::RequestContext {
session_id: self.session_id,
origin: None,
out: sender,
},
handler: self.handler.clone(),
Expand Down