More aggregator test cleanup #1794

Merged · 2 commits · Aug 22, 2023
24 changes: 5 additions & 19 deletions aggregator/src/aggregator/aggregation_job_continue.rs
@@ -278,13 +278,13 @@ impl VdafOps {
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod test_util {
use crate::aggregator::http_handlers::test_util::{take_problem_details, take_response_body};
use crate::aggregator::http_handlers::test_util::{decode_response_body, take_problem_details};
use janus_aggregator_core::task::Task;
use janus_messages::{AggregationJobContinueReq, AggregationJobId, AggregationJobResp};
use prio::codec::{Decode, Encode};
use prio::codec::Encode;
use serde_json::json;
use trillium::{Handler, KnownHeaderName, Status};
use trillium_testing::{prelude::post, TestConn};
use trillium_testing::{assert_headers, prelude::post, TestConn};

async fn post_aggregation_job(
task: &Task,
@@ -315,15 +315,8 @@ pub mod test_util {
let mut test_conn = post_aggregation_job(task, aggregation_job_id, request, handler).await;

assert_eq!(test_conn.status(), Some(Status::Ok));
assert_eq!(
test_conn
.response_headers()
.get(KnownHeaderName::ContentType)
.unwrap(),
AggregationJobResp::MEDIA_TYPE
);
let body_bytes = take_response_body(&mut test_conn).await;
AggregationJobResp::get_decoded(&body_bytes).unwrap()
assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE));
decode_response_body::<AggregationJobResp>(&mut test_conn).await
}

pub async fn post_aggregation_job_expecting_status(
@@ -358,13 +351,6 @@ pub mod test_util {
)
.await;

assert_eq!(
test_conn
.response_headers()
.get(KnownHeaderName::ContentType)
.unwrap(),
"application/problem+json"
);
assert_eq!(
take_problem_details(&mut test_conn).await,
json!({
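For reference, this is the pattern the file's happy-path helper follows after the change, sketched with the names visible in the hunks above. The enclosing function's name and exact signature are not shown in the diff, so the ones used here are illustrative only:

async fn post_aggregation_job_and_decode_sketch(
    task: &Task,
    aggregation_job_id: &AggregationJobId,
    request: &AggregationJobContinueReq,
    handler: &impl Handler,
) -> AggregationJobResp {
    let mut test_conn = post_aggregation_job(task, aggregation_job_id, request, handler).await;

    assert_eq!(test_conn.status(), Some(Status::Ok));
    // One assert_headers! call replaces the manual KnownHeaderName::ContentType lookup.
    assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE));
    // decode_response_body reads and decodes the body in one step, removing the
    // take_response_body + get_decoded pair from each call site.
    decode_response_body::<AggregationJobResp>(&mut test_conn).await
}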
22 changes: 4 additions & 18 deletions aggregator/src/aggregator/collection_job_tests.rs
@@ -1,7 +1,7 @@
use crate::aggregator::{
http_handlers::{
aggregator_handler,
test_util::{take_problem_details, take_response_body},
test_util::{decode_response_body, take_problem_details},
},
Config,
};
@@ -42,6 +42,7 @@ use serde_json::json;
use std::sync::Arc;
use trillium::{Handler, KnownHeaderName, Status};
use trillium_testing::{
assert_headers,
prelude::{post, put},
TestConn,
};
@@ -352,18 +353,10 @@ async fn collection_job_success_fixed_size() {
}

let mut test_conn = test_case.post_collection_job(&collection_job_id).await;

assert_eq!(test_conn.status(), Some(Status::Ok));
assert_eq!(
test_conn
.response_headers()
.get(KnownHeaderName::ContentType)
.unwrap(),
Collection::<FixedSize>::MEDIA_TYPE
);
let body_bytes = take_response_body(&mut test_conn).await;
let collect_resp = Collection::<FixedSize>::get_decoded(body_bytes.as_ref()).unwrap();
assert_headers!(&test_conn, "content-type" => (Collection::<FixedSize>::MEDIA_TYPE));

let collect_resp: Collection<FixedSize> = decode_response_body(&mut test_conn).await;
assert_eq!(
collect_resp.report_count(),
test_case.task.min_batch_size() + 1
@@ -423,13 +416,6 @@ async fn collection_job_success_fixed_size() {
.put_collection_job(&collection_job_id, &request)
.await;
assert_eq!(test_conn.status(), Some(Status::BadRequest));
assert_eq!(
test_conn
.response_headers()
.get(KnownHeaderName::ContentType)
.unwrap(),
"application/problem+json"
);
assert_eq!(
take_problem_details(&mut test_conn).await,
json!({
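The collection_job_tests.rs hunks apply the same swap for Collection<FixedSize>. Because decode_response_body is generic over the message type (bounded by the Decode trait, per the helper added in http_handlers.rs below), the target type can be supplied either with a turbofish or through the binding's type annotation; both forms appear in this PR. A two-line sketch:

// Both call forms decode the same way; pick whichever reads better at the call site.
let collect_resp: Collection<FixedSize> = decode_response_body(&mut test_conn).await;
let aggregate_resp = decode_response_body::<AggregationJobResp>(&mut test_conn).await;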
91 changes: 31 additions & 60 deletions aggregator/src/aggregator/http_handlers.rs
@@ -613,10 +613,11 @@ fn parse_taskprov_header<C: Clock>(
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod test_util {
use janus_messages::codec::Decode;
use std::borrow::Cow;
use trillium_testing::TestConn;
use trillium_testing::{assert_headers, TestConn};

pub async fn take_response_body(test_conn: &mut TestConn) -> Cow<'_, [u8]> {
async fn take_response_body(test_conn: &mut TestConn) -> Cow<'_, [u8]> {
test_conn
.take_response_body()
.unwrap()
@@ -625,7 +626,12 @@ pub mod test_util {
.unwrap()
}

pub async fn decode_response_body<T: Decode>(test_conn: &mut TestConn) -> T {
T::get_decoded(&take_response_body(test_conn).await).unwrap()
}

pub async fn take_problem_details(test_conn: &mut TestConn) -> serde_json::Value {
assert_headers!(&test_conn, "content-type" => "application/problem+json");
serde_json::from_slice(&take_response_body(test_conn).await).unwrap()
}
}
@@ -642,7 +648,7 @@ mod tests {
empty_batch_aggregations,
http_handlers::{
aggregator_handler, aggregator_handler_with_aggregator,
test_util::{take_problem_details, take_response_body},
test_util::{decode_response_body, take_problem_details},
},
tests::{
create_report, create_report_custom, default_aggregator_config,
@@ -703,7 +709,7 @@ mod tests {
};
use rand::random;
use serde_json::json;
use std::{collections::HashMap, io::Cursor, sync::Arc, time::Duration as StdDuration};
use std::{collections::HashMap, sync::Arc, time::Duration as StdDuration};
use trillium::{Handler, KnownHeaderName, Status};
use trillium_testing::{
assert_headers,
@@ -754,13 +760,6 @@ mod tests {
// No task ID provided and no global keys are configured.
let mut test_conn = get("/hpke_config").run_async(&handler).await;
assert_eq!(test_conn.status(), Some(Status::BadRequest));
assert_eq!(
test_conn
.response_headers()
.get(KnownHeaderName::ContentType)
.unwrap(),
"application/problem+json"
);
assert_eq!(
take_problem_details(&mut test_conn).await,
json!({
@@ -777,13 +776,6 @@ mod tests {
// Expected status and problem type should be per the protocol
// https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.3.1
assert_eq!(test_conn.status(), Some(Status::BadRequest));
assert_eq!(
test_conn
.response_headers()
.get(KnownHeaderName::ContentType)
.unwrap(),
"application/problem+json"
);
assert_eq!(
take_problem_details(&mut test_conn).await,
json!({
@@ -806,8 +798,7 @@ mod tests {
"content-type" => (HpkeConfigList::MEDIA_TYPE),
);

let bytes = take_response_body(&mut test_conn).await;
let hpke_config_list = HpkeConfigList::decode(&mut Cursor::new(&bytes)).unwrap();
let hpke_config_list: HpkeConfigList = decode_response_body(&mut test_conn).await;
assert_eq!(
hpke_config_list.hpke_configs(),
&[want_hpke_key.config().clone()]
@@ -862,8 +853,7 @@ mod tests {
"cache-control" => "max-age=86400",
"content-type" => (HpkeConfigList::MEDIA_TYPE),
);
let bytes = take_response_body(&mut test_conn).await;
let hpke_config_list = HpkeConfigList::decode(&mut Cursor::new(&bytes)).unwrap();
let hpke_config_list: HpkeConfigList = decode_response_body(&mut test_conn).await;
assert_eq!(
hpke_config_list.hpke_configs(),
&[first_hpke_keypair.config().clone()]
@@ -882,8 +872,7 @@ mod tests {
aggregator.refresh_caches().await.unwrap();
let mut test_conn = get("/hpke_config").run_async(&handler).await;
assert_eq!(test_conn.status(), Some(Status::Ok));
let bytes = take_response_body(&mut test_conn).await;
let hpke_config_list = HpkeConfigList::decode(&mut Cursor::new(&bytes)).unwrap();
let hpke_config_list: HpkeConfigList = decode_response_body(&mut test_conn).await;
assert_eq!(
hpke_config_list.hpke_configs(),
&[first_hpke_keypair.config().clone()]
@@ -903,8 +892,7 @@ mod tests {
aggregator.refresh_caches().await.unwrap();
let mut test_conn = get("/hpke_config").run_async(&handler).await;
assert_eq!(test_conn.status(), Some(Status::Ok));
let bytes = take_response_body(&mut test_conn).await;
let hpke_config_list = HpkeConfigList::decode(&mut Cursor::new(&bytes)).unwrap();
let hpke_config_list: HpkeConfigList = decode_response_body(&mut test_conn).await;
// Unordered comparison.
assert_eq!(
HashMap::from_iter(
@@ -939,8 +927,7 @@ mod tests {
aggregator.refresh_caches().await.unwrap();
let mut test_conn = get("/hpke_config").run_async(&handler).await;
assert_eq!(test_conn.status(), Some(Status::Ok));
let bytes = take_response_body(&mut test_conn).await;
let hpke_config_list = HpkeConfigList::decode(&mut Cursor::new(&bytes)).unwrap();
let hpke_config_list: HpkeConfigList = decode_response_body(&mut test_conn).await;
assert_eq!(
hpke_config_list.hpke_configs(),
&[first_hpke_keypair.config().clone()]
@@ -1027,8 +1014,7 @@ mod tests {
.run_async(&handler)
.await;
assert_eq!(test_conn.status(), Some(Status::Ok));
let bytes = take_response_body(&mut test_conn).await;
let hpke_config_list = HpkeConfigList::decode(&mut Cursor::new(&bytes)).unwrap();
let hpke_config_list: HpkeConfigList = decode_response_body(&mut test_conn).await;
assert_eq!(
hpke_config_list.hpke_configs(),
&[first_hpke_keypair.config().clone()]
@@ -1441,18 +1427,12 @@ mod tests {
);

// Check that CORS headers don't bleed over to other routes.
assert!(test_conn
.response_headers()
.get("access-control-allow-origin")
.is_none());
assert!(test_conn
.response_headers()
.get("access-control-allow-methods")
.is_none());
assert!(test_conn
.response_headers()
.get("access-control-max-age")
.is_none());
assert_headers!(
&test_conn,
"access-control-allow-origin" => None,
"access-control-allow-methods" => None,
"access-control-max-age" => None,
);

let test_conn = TestConn::build(
trillium::Method::Options,
@@ -1465,10 +1445,7 @@ mod tests {
.with_request_header(KnownHeaderName::AccessControlRequestMethod, "PUT")
.run_async(&handler)
.await;
assert!(test_conn
.response_headers()
.get(KnownHeaderName::AccessControlAllowMethods)
.is_none());
assert_headers!(&test_conn, "access-control-allow-methods" => None);
}

#[tokio::test]
@@ -1918,8 +1895,7 @@ mod tests {
&test_conn,
"content-type" => (AggregationJobResp::MEDIA_TYPE)
);
let body_bytes = take_response_body(&mut test_conn).await;
let aggregate_resp = AggregationJobResp::get_decoded(&body_bytes).unwrap();
let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await;

// Validate response.
assert_eq!(aggregate_resp.prepare_steps().len(), 9);
@@ -2214,8 +2190,7 @@ mod tests {
let mut test_conn =
put_aggregation_job(&task, &aggregation_job_id, &request, &handler).await;
assert_eq!(test_conn.status(), Some(Status::Ok));
let body_bytes = take_response_body(&mut test_conn).await;
let aggregate_resp = AggregationJobResp::get_decoded(&body_bytes).unwrap();
let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await;

// Validate response.
assert_eq!(aggregate_resp.prepare_steps().len(), 4);
@@ -2294,8 +2269,7 @@ mod tests {
let mut test_conn =
put_aggregation_job(&test_case.task, &random(), &request, &test_case.handler).await;
assert_eq!(test_conn.status(), Some(Status::Ok));
let body_bytes = take_response_body(&mut test_conn).await;
let aggregate_resp = AggregationJobResp::get_decoded(&body_bytes).unwrap();
let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await;

assert_eq!(aggregate_resp.prepare_steps().len(), 1);

@@ -2370,8 +2344,7 @@ mod tests {
&test_conn,
"content-type" => (AggregationJobResp::MEDIA_TYPE)
);
let body_bytes = take_response_body(&mut test_conn).await;
let aggregate_resp = AggregationJobResp::get_decoded(&body_bytes).unwrap();
let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await;

// Validate response.
assert_eq!(aggregate_resp.prepare_steps().len(), 1);
@@ -2425,8 +2398,7 @@ mod tests {
&test_conn,
"content-type" => (AggregationJobResp::MEDIA_TYPE)
);
let body_bytes = take_response_body(&mut test_conn).await;
let aggregate_resp = AggregationJobResp::get_decoded(&body_bytes).unwrap();
let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await;

// Validate response.
assert_eq!(aggregate_resp.prepare_steps().len(), 1);
@@ -4436,8 +4408,7 @@ mod tests {
&test_conn,
"content-type" => (Collection::<TimeInterval>::MEDIA_TYPE)
);
let body_bytes = take_response_body(&mut test_conn).await;
let collect_resp = Collection::<TimeInterval>::get_decoded(body_bytes.as_ref()).unwrap();
let collect_resp: Collection<TimeInterval> = decode_response_body(&mut test_conn).await;

assert_eq!(collect_resp.report_count(), 12);
assert_eq!(collect_resp.interval(), &batch_interval);
@@ -5181,8 +5152,8 @@ mod tests {
&test_conn,
"content-type" => (AggregateShareMessage::MEDIA_TYPE)
);
let body_bytes = take_response_body(&mut test_conn).await;
let aggregate_share_resp = AggregateShareMessage::get_decoded(&body_bytes).unwrap();
let aggregate_share_resp: AggregateShareMessage =
decode_response_body(&mut test_conn).await;

let aggregate_share = hpke::open(
collector_hpke_keypair.config(),
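Pieced together from the http_handlers.rs hunks above, the consolidated test_util helpers read roughly as follows after this PR. This is a reconstruction for readability, not a verbatim copy; the body of the now-private take_response_body helper is collapsed in the diff and is left as a placeholder here rather than guessed at:

pub mod test_util {
    use janus_messages::codec::Decode;
    use std::borrow::Cow;
    use trillium_testing::{assert_headers, TestConn};

    // Now private: external callers go through decode_response_body or take_problem_details.
    async fn take_response_body(test_conn: &mut TestConn) -> Cow<'_, [u8]> {
        unimplemented!("body unchanged from the original helper; collapsed in the diff above")
    }

    // Generic replacement for the repeated take_response_body + T::get_decoded pattern.
    pub async fn decode_response_body<T: Decode>(test_conn: &mut TestConn) -> T {
        T::get_decoded(&take_response_body(test_conn).await).unwrap()
    }

    // The problem-details content-type assertion now lives in the helper itself,
    // which is why the per-test "application/problem+json" checks could be deleted.
    pub async fn take_problem_details(test_conn: &mut TestConn) -> serde_json::Value {
        assert_headers!(&test_conn, "content-type" => "application/problem+json");
        serde_json::from_slice(&take_response_body(test_conn).await).unwrap()
    }
}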