Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: axum extractor #1353

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/json-rpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@ serde_json = { workspace = true, features = ["std", "raw_value"] }
thiserror.workspace = true
tracing.workspace = true
alloy-sol-types.workspace = true

async-trait = { workspace = true, optional = true }
axum = { version = "0.7.6", features = ["json"], optional = true }

[features]
axum = ["dep:axum"]
2 changes: 2 additions & 0 deletions crates/json-rpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ pub use packet::{BorrowedResponsePacket, RequestPacket, ResponsePacket};
mod request;
pub use request::{PartiallySerializedRequest, Request, RequestMeta, SerializedRequest};

mod support;

mod response;
pub use response::{
BorrowedErrorPayload, BorrowedResponse, BorrowedResponsePayload, ErrorPayload, Response,
Expand Down
132 changes: 127 additions & 5 deletions crates/json-rpc/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use crate::{common::Id, RpcParam};
use crate::{common::Id, RpcObject, RpcParam};
use alloy_primitives::{keccak256, B256};
use serde::{de::DeserializeOwned, ser::SerializeMap, Deserialize, Serialize};
use serde::{
de::{DeserializeOwned, MapAccess},
ser::SerializeMap,
Deserialize, Serialize,
};
use serde_json::value::RawValue;
use std::borrow::Cow;
use std::{borrow::Cow, marker::PhantomData, mem::MaybeUninit};

/// `RequestMeta` contains the [`Id`] and method name of a request.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RequestMeta {
/// The method name.
pub method: Cow<'static, str>,
Expand Down Expand Up @@ -48,7 +52,7 @@ impl RequestMeta {
/// ### Note
///
/// The value of `method` should be known at compile time.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Request<Params> {
/// The request metadata (ID and method).
pub meta: RequestMeta,
Expand Down Expand Up @@ -182,6 +186,103 @@ where
}
}

impl<'de, Params> Deserialize<'de> for Request<Params>
where
Params: RpcObject,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct Visitor<Params>(PhantomData<Params>);
impl<'de, Params> serde::de::Visitor<'de> for Visitor<Params>
where
Params: RpcObject,
{
type Value = Request<Params>;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
formatter,
"a JSON-RPC 2.0 request object with params of type {}",
std::any::type_name::<Params>()
)
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut id = None;
let mut params = None;
let mut method = None;
let mut jsonrpc = None;

while let Some(key) = map.next_key()? {
match key {
"id" => {
if id.is_some() {
return Err(serde::de::Error::duplicate_field("id"));
}
id = Some(map.next_value()?);
}
"params" => {
if params.is_some() {
return Err(serde::de::Error::duplicate_field("params"));
}
params = Some(map.next_value()?);
}
"method" => {
if method.is_some() {
return Err(serde::de::Error::duplicate_field("method"));
}
method = Some(map.next_value()?);
}
"jsonrpc" => {
let version: String = map.next_value()?;
if version != "2.0" {
return Err(serde::de::Error::custom(format!(
"unsupported JSON-RPC version: {}",
version
)));
}
jsonrpc = Some(());
}
other => {
return Err(serde::de::Error::unknown_field(
other,
&["id", "params", "method", "jsonrpc"],
));
}
}
}
if jsonrpc.is_none() {
return Err(serde::de::Error::missing_field("jsonrpc"));
}
if method.is_none() {
return Err(serde::de::Error::missing_field("method"));
}

if params.is_none() {
if std::mem::size_of::<Params>() == 0 {
// SAFETY: params is a ZST, so it's safe to fail to initialize it
unsafe { params = Some(MaybeUninit::<Params>::zeroed().assume_init()) }
} else {
return Err(serde::de::Error::missing_field("params"));
}
}

Ok(Request {
meta: RequestMeta::new(method.unwrap(), id.unwrap_or(Id::None)),
params: params.unwrap(),
})
}
}

deserializer.deserialize_map(Visitor(PhantomData))
}
}

/// A JSON-RPC 2.0 request object that has been serialized, with its [`Id`] and
/// method preserved.
///
Expand Down Expand Up @@ -285,3 +386,24 @@ impl Serialize for SerializedRequest {
self.request.serialize(serializer)
}
}

#[cfg(test)]
mod test {
use super::*;

fn test_inner<T: RpcObject + PartialEq>(t: T) {
let ser = serde_json::to_string(&t).unwrap();
let de: T = serde_json::from_str(&ser).unwrap();
let reser = serde_json::to_string(&de).unwrap();
assert_eq!(de, t, "deser error for {}", std::any::type_name::<T>());
assert_eq!(ser, reser, "reser error for {}", std::any::type_name::<T>());
}

#[test]
fn test_ser_deser() {
test_inner(Request::<()>::new("test", 1.into(), ()));
test_inner(Request::<u64>::new("test", "hello".to_string().into(), 1));
test_inner(Request::<String>::new("test", Id::None, "test".to_string()));
test_inner(Request::<Vec<u64>>::new("test", u64::MAX.into(), vec![1, 2, 3]));
}
}
41 changes: 41 additions & 0 deletions crates/json-rpc/src/support/axum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use crate::{ErrorPayload, Id, Request, Response, ResponsePayload, RpcObject};
use axum::extract;

impl From<extract::rejection::JsonRejection> for Response<(), ()> {
fn from(value: extract::rejection::JsonRejection) -> Self {
Response {
id: Id::None,
payload: ResponsePayload::Failure(ErrorPayload {
code: -32600,
message: value.to_string(),
data: None,
}),
}
}
}

impl<Payload, ErrData> axum::response::IntoResponse for Response<Payload, ErrData>
where
Payload: RpcObject,
ErrData: RpcObject,
{
fn into_response(self) -> axum::response::Response {
axum::response::IntoResponse::into_response(axum::response::Json(self))
}
}

#[async_trait::async_trait]
impl<S, Params> extract::FromRequest<S> for Request<Params>
where
axum::body::Bytes: extract::FromRequest<S>,
Params: RpcObject,
S: Send + Sync,
{
type Rejection = Response<(), ()>;

async fn from_request(req: extract::Request, state: &S) -> Result<Self, Self::Rejection> {
let json = extract::Json::<Request<Params>>::from_request(req, state).await?;

Ok(json.0)
}
}
2 changes: 2 additions & 0 deletions crates/json-rpc/src/support/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#[cfg(feature = "axum")]
mod axum;
Loading