Skip to content

Commit

Permalink
fix: Propagate error when building a HTTP request
Browse files Browse the repository at this point in the history
- Add error::Error variant
- Add test for an invalid request
  • Loading branch information
threema-donat committed Apr 26, 2024
1 parent 93dc6be commit e53715c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 27 deletions.
62 changes: 35 additions & 27 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl Client {
/// See [ErrorReason](enum.ErrorReason.html) for possible errors.
#[cfg_attr(feature = "tracing", ::tracing::instrument)]
pub async fn send<T: PayloadLike>(&self, payload: T) -> Result<Response, Error> {
let request = self.build_request(payload);
let request = self.build_request(payload)?;
let requesting = self.http_client.request(request);

let response = requesting.await?;
Expand Down Expand Up @@ -152,7 +152,7 @@ impl Client {
}
}

fn build_request<T: PayloadLike>(&self, payload: T) -> hyper::Request<BoxBody<Bytes, Infallible>> {
fn build_request<T: PayloadLike>(&self, payload: T) -> Result<hyper::Request<BoxBody<Bytes, Infallible>>, Error> {
let path = format!("https://{}/3/device/{}", self.endpoint, payload.get_device_token());

let mut builder = hyper::Request::builder()
Expand Down Expand Up @@ -180,18 +180,16 @@ impl Client {
builder = builder.header("apns-topic", apns_topic.as_bytes());
}
if let Some(ref signer) = self.signer {
let auth = signer
.with_signature(|signature| format!("Bearer {}", signature))
.unwrap();
let auth = signer.with_signature(|signature| format!("Bearer {}", signature))?;

builder = builder.header(AUTHORIZATION, auth.as_bytes());
}

let payload_json = payload.to_json_string().unwrap();
let payload_json = payload.to_json_string()?;
builder = builder.header(CONTENT_LENGTH, format!("{}", payload_json.len()).as_bytes());

let request_body = Full::from(payload_json.into_bytes()).boxed();
builder.body(request_body).unwrap()
builder.body(request_body).map_err(Error::BuildRequestError)
}
}

Expand Down Expand Up @@ -247,7 +245,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let uri = format!("{}", request.uri());

assert_eq!("https://api.push.apple.com/3/device/a_test_id", &uri);
Expand All @@ -258,7 +256,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Sandbox);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let uri = format!("{}", request.uri());

assert_eq!("https://api.development.push.apple.com/3/device/a_test_id", &uri);
Expand All @@ -269,17 +267,27 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_eq!(&Method::POST, request.method());
}

#[test]
fn test_request_invalid() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("\r\n", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);

assert!(matches!(request, Err(Error::BuildRequestError(_))));
}

#[test]
fn test_request_content_type() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap());
}
Expand All @@ -289,7 +297,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload.clone());
let request = client.build_request(payload.clone()).unwrap();
let payload_json = payload.to_json_string().unwrap();
let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap();

Expand All @@ -301,7 +309,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_eq!(None, request.headers().get(AUTHORIZATION));
}
Expand All @@ -319,7 +327,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), Some(signer), Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_ne!(None, request.headers().get(AUTHORIZATION));
}
Expand All @@ -333,7 +341,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
};
let payload = builder.build("a_test_id", options);
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_push_type = request.headers().get("apns-push-type").unwrap();

assert_eq!("background", apns_push_type);
Expand All @@ -344,7 +352,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority");

assert_eq!(None, apns_priority);
Expand All @@ -363,7 +371,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority").unwrap();

assert_eq!("5", apns_priority);
Expand All @@ -382,7 +390,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority").unwrap();

assert_eq!("10", apns_priority);
Expand All @@ -395,7 +403,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_id = request.headers().get("apns-id");

assert_eq!(None, apns_id);
Expand All @@ -414,7 +422,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_id = request.headers().get("apns-id").unwrap();

assert_eq!("a-test-apns-id", apns_id);
Expand All @@ -427,7 +435,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_expiration = request.headers().get("apns-expiration");

assert_eq!(None, apns_expiration);
Expand All @@ -446,7 +454,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_expiration = request.headers().get("apns-expiration").unwrap();

assert_eq!("420", apns_expiration);
Expand All @@ -459,7 +467,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_collapse_id = request.headers().get("apns-collapse-id");

assert_eq!(None, apns_collapse_id);
Expand All @@ -478,7 +486,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap();

assert_eq!("a_collapse_id", apns_collapse_id);
Expand All @@ -491,7 +499,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_topic = request.headers().get("apns-topic");

assert_eq!(None, apns_topic);
Expand All @@ -510,7 +518,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_topic = request.headers().get("apns-topic").unwrap();

assert_eq!("a_topic", apns_topic);
Expand All @@ -521,7 +529,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload.clone());
let request = client.build_request(payload.clone()).unwrap();

let body = request.into_body().collect().await.unwrap().to_bytes();
let body_str = String::from_utf8(body.to_vec()).unwrap();
Expand Down
4 changes: 4 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ pub enum Error {
#[error("Error building TLS config: {0}")]
Tls(#[from] rustls::Error),

/// Error while creating the HTTP request
#[error("Failed to construct HTTP request: {0}")]
BuildRequestError(#[source] http::Error),

/// Unexpected private key (only EC keys are supported).
#[cfg(all(not(feature = "openssl"), feature = "ring"))]
#[error("Unexpected private key: {0}")]
Expand Down

0 comments on commit e53715c

Please sign in to comment.