diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 2cfb19e00..e10d247d3 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -90,6 +90,7 @@ use prio::{ xof::XofTurboShake128, }, }; +use rayon::iter::{IndexedParallelIterator as _, IntoParallelRefIterator as _, ParallelIterator}; use reqwest::Client; use ring::{ digest::{digest, SHA256}, @@ -97,19 +98,15 @@ use ring::{ signature::{EcdsaKeyPair, Signature}, }; use std::{ - any::Any, collections::{HashMap, HashSet}, fmt::Debug, hash::Hash, - panic::{catch_unwind, resume_unwind, AssertUnwindSafe}, + panic, path::PathBuf, sync::{Arc, Mutex as SyncMutex}, time::{Duration as StdDuration, Instant}, }; -use tokio::{ - sync::oneshot::{self, error::RecvError}, - try_join, -}; +use tokio::{sync::mpsc, try_join}; use tracing::{debug, info, info_span, trace_span, warn, Level, Span}; use url::Url; @@ -1922,7 +1919,22 @@ impl VdafOps { } } - let prep_closure = { + // Compute the next aggregation step. + // + // We validate that each prepare_init can be represented by a `u64` ord value here, so that + // inside the parallel iterator we can unwrap. A conversion failure here will fail the + // entire aggregation. However, this is desirable: this can only happen if we receive too + // many report shares in an aggregation job for us to store, which is a whole-aggregation + // problem rather than a per-report problem. (separately, this would require more than + // u64::MAX report shares in a single aggregation job, which is practically impossible.) + u64::try_from(req.prepare_inits().len())?; + + // Shutdown on cancellation: if this request is cancelled, the `receiver` will be dropped. + // This will cause any attempts to send on `sender` to return a `SendError`, which will be + // returned from the function passed to `try_for_each_with`; `try_for_each_with` will + // terminate early on receiving an error. + let (sender, mut receiver) = mpsc::unbounded_channel(); + let producer_task = tokio::task::spawn_blocking({ let parent_span = Span::current(); let global_hpke_keypairs = global_hpke_keypairs.view(); let vdaf = Arc::clone(&vdaf); @@ -1932,345 +1944,331 @@ impl VdafOps { let aggregation_job_id = *aggregation_job_id; let verify_key = *verify_key; let agg_param = Arc::clone(&agg_param); - move || -> Result>, Error> { - let span = info_span!( - parent: parent_span, - "handle_aggregate_init_generic threadpool task" - ); - let _entered = span.enter(); - - // Decrypt shares & prepare initialization states. (§4.4.4.1) - let mut report_share_data = Vec::new(); - for (ord, prepare_init) in req.prepare_inits().iter().enumerate() { - // If decryption fails, then the aggregator MUST fail with error `hpke-decrypt-error`. 
(§4.4.2.2) - let global_hpke_keypair = global_hpke_keypairs.keypair( - prepare_init - .report_share() - .encrypted_input_share() - .config_id(), - ); - - let task_hpke_keypair = task.hpke_keys().get( - prepare_init - .report_share() - .encrypted_input_share() - .config_id(), - ); - let check_keypairs = if task_hpke_keypair.is_none() - && global_hpke_keypair.is_none() - { - debug!( - config_id = %prepare_init.report_share().encrypted_input_share().config_id(), - "Helper encrypted input share references unknown HPKE config ID" + move || { + let span = info_span!(parent: parent_span, "handle_aggregate_init_generic threadpool task"); + + req + .prepare_inits() + .par_iter() + .enumerate() + .try_for_each_with((sender, span), |(sender, span), (ord, prepare_init)| { + let _entered = span.enter(); + + // If decryption fails, then the aggregator MUST fail with error `hpke-decrypt-error`. (§4.4.2.2) + let global_hpke_keypair = global_hpke_keypairs.keypair( + prepare_init + .report_share() + .encrypted_input_share() + .config_id(), ); - metrics - .aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "unknown_hpke_config_id")]); - Err(PrepareError::HpkeUnknownConfigId) - } else { - Ok(()) - }; - - let input_share_aad = check_keypairs.and_then(|_| { - InputShareAad::new( - *task.id(), - prepare_init.report_share().metadata().clone(), - prepare_init.report_share().public_share().to_vec(), - ) - .get_encoded() - .map_err(|err| { - debug!( - task_id = %task.id(), - metadata = ?prepare_init.report_share().metadata(), - ?err, - "Couldn't encode input share AAD" - ); - metrics.aggregate_step_failure_counter.add( - 1, - &[KeyValue::new("type", "input_share_aad_encode_failure")], - ); - // HpkeDecryptError isn't strictly accurate, but given that this - // fallible encoding is part of the HPKE decryption process, I think - // this is as close as we can get to a meaningful error signal. - PrepareError::HpkeDecryptError - }) - }); - let plaintext = input_share_aad.and_then(|input_share_aad| { - let try_hpke_open = |hpke_keypair| { - hpke::open( - hpke_keypair, - &HpkeApplicationInfo::new( - &Label::InputShare, - &Role::Client, - &Role::Helper, - ), - prepare_init.report_share().encrypted_input_share(), - &input_share_aad, - ) - }; + let task_hpke_keypair = task.hpke_keys().get( + prepare_init + .report_share() + .encrypted_input_share() + .config_id(), + ); - match (task_hpke_keypair, global_hpke_keypair) { - (None, None) => unreachable!("already checked this condition"), - (None, Some(global_hpke_keypair)) => { - try_hpke_open(&global_hpke_keypair) - } - (Some(task_hpke_keypair), None) => try_hpke_open(task_hpke_keypair), - (Some(task_hpke_keypair), Some(global_hpke_keypair)) => { - try_hpke_open(task_hpke_keypair).or_else(|error| match error { - // Only attempt second trial if _decryption_ fails, and not some - // error in server-side HPKE configuration. 
- hpke::Error::Hpke(_) => try_hpke_open(&global_hpke_keypair), - error => Err(error), - }) - } - } - .map_err(|error| { + let check_keypairs = if task_hpke_keypair.is_none() + && global_hpke_keypair.is_none() + { debug!( - task_id = %task.id(), - metadata = ?prepare_init.report_share().metadata(), - ?error, - "Couldn't decrypt helper's report share" + config_id = %prepare_init.report_share().encrypted_input_share().config_id(), + "Helper encrypted input share references unknown HPKE config ID" ); metrics .aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "decrypt_failure")]); - PrepareError::HpkeDecryptError - }) - }); + .add(1, &[KeyValue::new("type", "unknown_hpke_config_id")]); + Err(PrepareError::HpkeUnknownConfigId) + } else { + Ok(()) + }; - let plaintext_input_share = plaintext.and_then(|plaintext| { - let plaintext_input_share = PlaintextInputShare::get_decoded(&plaintext) - .map_err(|error| { + let input_share_aad = check_keypairs.and_then(|_| { + InputShareAad::new( + *task.id(), + prepare_init.report_share().metadata().clone(), + prepare_init.report_share().public_share().to_vec(), + ) + .get_encoded() + .map_err(|err| { debug!( task_id = %task.id(), metadata = ?prepare_init.report_share().metadata(), - ?error, "Couldn't decode helper's plaintext input share", + ?err, + "Couldn't encode input share AAD" ); metrics.aggregate_step_failure_counter.add( 1, - &[KeyValue::new( - "type", - "plaintext_input_share_decode_failure", - )], + &[KeyValue::new("type", "input_share_aad_encode_failure")], ); - PrepareError::InvalidMessage - })?; - - // Build map of extension type to extension data, checking for duplicates. - let mut extensions = HashMap::new(); - if !plaintext_input_share.extensions().iter().all(|extension| { - extensions - .insert(*extension.extension_type(), extension.extension_data()) - .is_none() - }) { - debug!( - task_id = %task.id(), - metadata = ?prepare_init.report_share().metadata(), - "Received report share with duplicate extensions", - ); - metrics - .aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "duplicate_extension")]); - return Err(PrepareError::InvalidMessage); - } + // HpkeDecryptError isn't strictly accurate, but given that this + // fallible encoding is part of the HPKE decryption process, I think + // this is as close as we can get to a meaningful error signal. + PrepareError::HpkeDecryptError + }) + }); + + let plaintext = input_share_aad.and_then(|input_share_aad| { + let try_hpke_open = |hpke_keypair| { + hpke::open( + hpke_keypair, + &HpkeApplicationInfo::new( + &Label::InputShare, + &Role::Client, + &Role::Helper, + ), + prepare_init.report_share().encrypted_input_share(), + &input_share_aad, + ) + }; - if require_taskprov_extension { - let valid_taskprov_extension_present = extensions - .get(&ExtensionType::Taskprov) - .map(|data| data.is_empty()) - .unwrap_or(false); - if !valid_taskprov_extension_present { + match (task_hpke_keypair, global_hpke_keypair) { + (None, None) => unreachable!("already checked this condition"), + (None, Some(global_hpke_keypair)) => { + try_hpke_open(&global_hpke_keypair) + } + (Some(task_hpke_keypair), None) => try_hpke_open(task_hpke_keypair), + (Some(task_hpke_keypair), Some(global_hpke_keypair)) => { + try_hpke_open(task_hpke_keypair).or_else(|error| match error { + // Only attempt second trial if _decryption_ fails, and not some + // error in server-side HPKE configuration. 
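                                                // `hpke::Error::Hpke` wraps a failure of the HPKE
                                                // operation itself (e.g. a ciphertext that does not
                                                // decrypt); any other variant indicates a server-side
                                                // configuration problem that retrying with another
                                                // keypair cannot fix.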
+ hpke::Error::Hpke(_) => try_hpke_open(&global_hpke_keypair), + error => Err(error), + }) + } + } + .map_err(|error| { debug!( task_id = %task.id(), metadata = ?prepare_init.report_share().metadata(), - "Taskprov task received report with missing or malformed \ - taskprov extension", + ?error, + "Couldn't decrypt helper's report share" ); - metrics.aggregate_step_failure_counter.add( - 1, - &[KeyValue::new( - "type", - "missing_or_malformed_taskprov_extension", - )], + metrics + .aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "decrypt_failure")]); + PrepareError::HpkeDecryptError + }) + }); + + let plaintext_input_share = plaintext.and_then(|plaintext| { + let plaintext_input_share = PlaintextInputShare::get_decoded(&plaintext) + .map_err(|error| { + debug!( + task_id = %task.id(), + metadata = ?prepare_init.report_share().metadata(), + ?error, "Couldn't decode helper's plaintext input share", + ); + metrics.aggregate_step_failure_counter.add( + 1, + &[KeyValue::new( + "type", + "plaintext_input_share_decode_failure", + )], + ); + PrepareError::InvalidMessage + })?; + + // Build map of extension type to extension data, checking for duplicates. + let mut extensions = HashMap::new(); + if !plaintext_input_share.extensions().iter().all(|extension| { + extensions + .insert(*extension.extension_type(), extension.extension_data()) + .is_none() + }) { + debug!( + task_id = %task.id(), + metadata = ?prepare_init.report_share().metadata(), + "Received report share with duplicate extensions", ); + metrics + .aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "duplicate_extension")]); return Err(PrepareError::InvalidMessage); } - } else if extensions.contains_key(&ExtensionType::Taskprov) { - // taskprov not enabled, but the taskprov extension is present. - debug!( - task_id = %task.id(), - metadata = ?prepare_init.report_share().metadata(), - "Non-taskprov task received report with unexpected taskprov \ - extension", - ); - metrics - .aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "unexpected_taskprov_extension")]); - return Err(PrepareError::InvalidMessage); - } - Ok(plaintext_input_share) - }); + if require_taskprov_extension { + let valid_taskprov_extension_present = extensions + .get(&ExtensionType::Taskprov) + .map(|data| data.is_empty()) + .unwrap_or(false); + if !valid_taskprov_extension_present { + debug!( + task_id = %task.id(), + metadata = ?prepare_init.report_share().metadata(), + "Taskprov task received report with missing or malformed \ + taskprov extension", + ); + metrics.aggregate_step_failure_counter.add( + 1, + &[KeyValue::new( + "type", + "missing_or_malformed_taskprov_extension", + )], + ); + return Err(PrepareError::InvalidMessage); + } + } else if extensions.contains_key(&ExtensionType::Taskprov) { + // taskprov not enabled, but the taskprov extension is present. 
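                                // Rejecting the report (rather than silently dropping the
                                // extension) surfaces the client misconfiguration instead of
                                // aggregating a report that was intended for a taskprov task.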
+ debug!( + task_id = %task.id(), + metadata = ?prepare_init.report_share().metadata(), + "Non-taskprov task received report with unexpected taskprov \ + extension", + ); + metrics + .aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "unexpected_taskprov_extension")]); + return Err(PrepareError::InvalidMessage); + } + + Ok(plaintext_input_share) + }); - let input_share = plaintext_input_share.and_then(|plaintext_input_share| { - A::InputShare::get_decoded_with_param( - &(&vdaf, Role::Helper.index().unwrap()), - plaintext_input_share.payload(), + let input_share = plaintext_input_share.and_then(|plaintext_input_share| { + A::InputShare::get_decoded_with_param( + &(&vdaf, Role::Helper.index().unwrap()), + plaintext_input_share.payload(), + ) + .map_err(|error| { + debug!( + task_id = %task.id(), + metadata = ?prepare_init.report_share().metadata(), + ?error, "Couldn't decode helper's input share", + ); + metrics + .aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "input_share_decode_failure")]); + PrepareError::InvalidMessage + }) + }); + + let public_share = A::PublicShare::get_decoded_with_param( + &vdaf, + prepare_init.report_share().public_share(), ) .map_err(|error| { debug!( task_id = %task.id(), metadata = ?prepare_init.report_share().metadata(), - ?error, "Couldn't decode helper's input share", + ?error, "Couldn't decode public share", ); metrics .aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "input_share_decode_failure")]); + .add(1, &[KeyValue::new("type", "public_share_decode_failure")]); PrepareError::InvalidMessage - }) - }); - - let public_share = A::PublicShare::get_decoded_with_param( - &vdaf, - prepare_init.report_share().public_share(), - ) - .map_err(|error| { - debug!( - task_id = %task.id(), - metadata = ?prepare_init.report_share().metadata(), - ?error, "Couldn't decode public share", - ); - metrics - .aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "public_share_decode_failure")]); - PrepareError::InvalidMessage - }); - - let shares = - input_share.and_then(|input_share| Ok((public_share?, input_share))); - - // Reject reports from too far in the future. - let shares = shares.and_then(|shares| { - if prepare_init - .report_share() - .metadata() - .time() - .is_after(&report_deadline) - { - return Err(PrepareError::ReportTooEarly); - } - Ok(shares) - }); - - // Next, the aggregator runs the preparation-state initialization algorithm for the VDAF - // associated with the task and computes the first state transition. [...] If either - // step fails, then the aggregator MUST fail with error `vdaf-prep-error`. (§4.4.2.2) - let init_rslt = shares.and_then(|(public_share, input_share)| { - trace_span!("VDAF preparation (helper initialization)").in_scope(|| { - vdaf.helper_initialized( - verify_key.as_bytes(), - &agg_param, - /* report ID is used as VDAF nonce */ - prepare_init.report_share().metadata().id().as_ref(), - &public_share, - &input_share, - prepare_init.message(), - ) - .and_then(|transition| transition.evaluate(&vdaf)) - .map_err(|error| { - handle_ping_pong_error( - task.id(), - Role::Helper, - prepare_init.report_share().metadata().id(), - error, - &metrics.aggregate_step_failure_counter, + }); + + let shares = + input_share.and_then(|input_share| Ok((public_share?, input_share))); + + // Reject reports from too far in the future. 
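                        // (`report_deadline` is the latest client timestamp we accept:
                        // roughly the current time advanced by the configured tolerable
                        // clock skew. Anything later is rejected as `ReportTooEarly`.)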
+ let shares = shares.and_then(|shares| { + if prepare_init + .report_share() + .metadata() + .time() + .is_after(&report_deadline) + { + return Err(PrepareError::ReportTooEarly); + } + Ok(shares) + }); + + // Next, the aggregator runs the preparation-state initialization algorithm for the VDAF + // associated with the task and computes the first state transition. [...] If either + // step fails, then the aggregator MUST fail with error `vdaf-prep-error`. (§4.4.2.2) + let init_rslt = shares.and_then(|(public_share, input_share)| { + trace_span!("VDAF preparation (helper initialization)").in_scope(|| { + vdaf.helper_initialized( + verify_key.as_bytes(), + &agg_param, + /* report ID is used as VDAF nonce */ + prepare_init.report_share().metadata().id().as_ref(), + &public_share, + &input_share, + prepare_init.message(), ) + .and_then(|transition| transition.evaluate(&vdaf)) + .map_err(|error| { + handle_ping_pong_error( + task.id(), + Role::Helper, + prepare_init.report_share().metadata().id(), + error, + &metrics.aggregate_step_failure_counter, + ) + }) }) - }) - }); - - let (report_aggregation_state, prepare_step_result, output_share) = - match init_rslt { - Ok((PingPongState::Continued(prepare_state), outgoing_message)) => { - // Helper is not finished. Await the next message from the Leader to advance to - // the next step. - ( - ReportAggregationState::WaitingHelper { prepare_state }, + }); + + let (report_aggregation_state, prepare_step_result, output_share) = + match init_rslt { + Ok((PingPongState::Continued(prepare_state), outgoing_message)) => { + // Helper is not finished. Await the next message from the Leader to advance to + // the next step. + ( + ReportAggregationState::WaitingHelper { prepare_state }, + PrepareStepResult::Continue { + message: outgoing_message, + }, + None, + ) + } + Ok((PingPongState::Finished(output_share), outgoing_message)) => ( + ReportAggregationState::Finished, PrepareStepResult::Continue { message: outgoing_message, }, + Some(output_share), + ), + Err(prepare_error) => ( + ReportAggregationState::Failed { prepare_error }, + PrepareStepResult::Reject(prepare_error), None, - ) - } - Ok((PingPongState::Finished(output_share), outgoing_message)) => ( - ReportAggregationState::Finished, - PrepareStepResult::Continue { - message: outgoing_message, - }, - Some(output_share), - ), - Err(prepare_error) => ( - ReportAggregationState::Failed { prepare_error }, - PrepareStepResult::Reject(prepare_error), - None, - ), - }; - - report_share_data.push(ReportShareData { - report_share: prepare_init.report_share().clone(), - report_aggregation: WritableReportAggregation::new( - ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *prepare_init.report_share().metadata().id(), - *prepare_init.report_share().metadata().time(), - // A conversion failure here will fail the entire aggregation. - // However, this is desirable: this can only happen if we receive - // too many report shares in an aggregation job for us to store, - // which is a whole-aggregation problem rather than a per-report - // problem. (separately, this would require more than u64::MAX - // report shares in a single aggregation job, which is practically - // impossible.) 
- ord.try_into()?, - Some(PrepareResp::new( + ), + }; + + sender.send(ReportShareData { + report_share: prepare_init.report_share().clone(), + report_aggregation: WritableReportAggregation::new( + ReportAggregation::::new( + *task.id(), + aggregation_job_id, *prepare_init.report_share().metadata().id(), - prepare_step_result, - )), - report_aggregation_state, + *prepare_init.report_share().metadata().time(), + // Unwrap safety: we checked that all ordinal values are representable + // as a u64 before entering the parallel iterator. + ord.try_into().unwrap(), + Some(PrepareResp::new( + *prepare_init.report_share().metadata().id(), + prepare_step_result, + )), + report_aggregation_state, + ), + output_share, ), - output_share, - ), - }); - } - Ok(report_share_data) + }) + }) } - }; - let (sender, receiver) = oneshot::channel(); - rayon::spawn(|| { - // Do nothing if the result cannot be sent back. This would happen if the initiating - // future is cancelled. - // - // We need to catch panics here because rayon's default threadpool panic handler will - // abort, and it would be preferable to propagate the panic in the original future to - // avoid behavior changes. - // - // Using `AssertUnwindSafe` is OK here, because the only interior mutability we make use - // of is OTel instruments, which can't be put into inconsistent states merely by - // incrementing counters, and the global HPKE keypair cache, which we only read from. - // Using `AssertUnwindSafe` is easier than adding infectious `UnwindSafe` trait bounds - // to various VDAF associated types throughout the codebase. - let _ = sender.send(catch_unwind(AssertUnwindSafe(prep_closure))); }); - let report_share_data = receiver - .await - .map_err(|_recv_error: RecvError| { - Error::Internal("threadpool failed to send VDAF preparation result".into()) - })? - .unwrap_or_else(|panic_cause: Box| { - resume_unwind(panic_cause); - })?; + + let mut report_share_data = Vec::with_capacity(req.prepare_inits().len()); + while receiver.recv_many(&mut report_share_data, 10).await > 0 {} + + // Await the producer task to resume any panics that may have occurred, and to ensure we can + // unwrap the aggregation parameter's Arc in a few lines. The only other errors that can + // occur are: a `JoinError` indicating cancellation, which is impossible because we do not + // cancel the task; and a `SendError`, which can only happen if this future is cancelled (in + // which case we will not run this code at all). + let _ = producer_task.await.map_err(|join_error| { + if let Ok(reason) = join_error.try_into_panic() { + panic::resume_unwind(reason); + } + }); + assert_eq!(report_share_data.len(), req.prepare_inits().len()); // TODO: Use Arc::unwrap_or_clone() once the MSRV is at least 1.76.0. 
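        // The producer task held the only other clone of this `Arc` and has been awaited
        // above, so `try_unwrap` is expected to succeed; cloning is a defensive fallback.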
let agg_param = Arc::try_unwrap(agg_param).unwrap_or_else(|arc| arc.as_ref().clone()); diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index 1f566ef8f..6ea51e654 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -24,12 +24,9 @@ use prio::{ topology::ping_pong::{PingPongContinuedValue, PingPongState, PingPongTopology}, vdaf, }; -use std::{ - any::Any, - panic::{catch_unwind, resume_unwind, AssertUnwindSafe}, - sync::Arc, -}; -use tokio::sync::oneshot::{self, error::RecvError}; +use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; +use std::{panic, sync::Arc}; +use tokio::sync::mpsc; use tracing::{info_span, trace_span, Span}; impl VdafOps { @@ -59,161 +56,34 @@ impl VdafOps { { let request_step = req.step(); - let prep_closure = { - let parent_span = Span::current(); - let task_id = *task.id(); - let vdaf = Arc::clone(&vdaf); - let aggregation_parameter = aggregation_job.aggregation_parameter().clone(); - let metrics = metrics.clone(); - move || { - let span = info_span!(parent: parent_span, "step_aggregation_job threadpool task"); - let _entered = span.enter(); - - // Handle each transition in the request. - let mut report_aggregations_iter = report_aggregations.into_iter(); - let mut report_aggregations_to_write = Vec::new(); - for prep_step in req.prepare_steps() { - // Match preparation step received from leader to stored report aggregation, and extract - // the stored preparation step. - let report_aggregation = loop { - let report_agg = report_aggregations_iter.next().ok_or_else(|| { - datastore::Error::User( - Error::InvalidMessage( - Some(task_id), - "leader sent unexpected, duplicate, or out-of-order prepare \ - steps", - ) - .into(), - ) - })?; - if report_agg.report_id() != prep_step.report_id() { - // This report was omitted by the leader because of a prior failure. Note that - // the report was dropped (if it's not already in an error state) and continue. - if matches!( - report_agg.state(), - ReportAggregationState::WaitingHelper { .. } - ) { - report_aggregations_to_write.push(WritableReportAggregation::new( - report_agg - .with_state(ReportAggregationState::Failed { - prepare_error: PrepareError::ReportDropped, - }) - .with_last_prep_resp(None), - None, - )); - } - continue; - } - break report_agg; - }; - - let prep_state = match report_aggregation.state() { - ReportAggregationState::WaitingHelper { prepare_state } => prepare_state, - ReportAggregationState::WaitingLeader { .. } => { - return Err(datastore::Error::User( - Error::Internal( - "helper encountered unexpected \ - ReportAggregationState::WaitingLeader" - .to_string(), - ) - .into(), - )) - } - _ => { - return Err(datastore::Error::User( - Error::InvalidMessage( - Some(task_id), - "leader sent prepare step for non-WAITING report aggregation", - ) - .into(), - )); - } - }; - - let (report_aggregation_state, prepare_step_result, output_share) = - trace_span!("VDAF preparation (helper continuation)") - .in_scope(|| { - // Continue with the incoming message. 
- vdaf.helper_continued( - PingPongState::Continued(prep_state.clone()), - &aggregation_parameter, - prep_step.message(), - ) - .and_then(|continued_value| match continued_value { - PingPongContinuedValue::WithMessage { transition } => { - let (new_state, message) = - transition.evaluate(vdaf.as_ref())?; - let (report_aggregation_state, output_share) = - match new_state { - // Helper did not finish. Store the new state and - // await the next message from the Leader to advance - // preparation. - PingPongState::Continued(prepare_state) => ( - ReportAggregationState::WaitingHelper { - prepare_state, - }, - None, - ), - // Helper finished. Commit the output share. - PingPongState::Finished(output_share) => ( - ReportAggregationState::Finished, - Some(output_share), - ), - }; - - Ok(( - report_aggregation_state, - // Helper has an outgoing message for Leader - PrepareStepResult::Continue { message }, - output_share, - )) - } - PingPongContinuedValue::FinishedNoMessage { output_share } => { - Ok(( - ReportAggregationState::Finished, - PrepareStepResult::Finished, - Some(output_share), - )) - } - }) - }) - .map_err(|error| { - handle_ping_pong_error( - &task_id, - Role::Leader, - prep_step.report_id(), - error, - &metrics.aggregate_step_failure_counter, - ) - }) - .unwrap_or_else(|prepare_error| { - ( - ReportAggregationState::Failed { prepare_error }, - PrepareStepResult::Reject(prepare_error), - None, - ) - }); - - report_aggregations_to_write.push(WritableReportAggregation::new( - report_aggregation - .with_state(report_aggregation_state) - .with_last_prep_resp(Some(PrepareResp::new( - *prep_step.report_id(), - prepare_step_result, - ))), - output_share, - )); - } - - for report_aggregation in report_aggregations_iter { + // Match preparation step received from leader to stored report aggregation, and extract + // the stored preparation step. + let report_aggregation_count = report_aggregations.len(); + let mut report_aggregations_iter = report_aggregations.into_iter(); + + let mut prep_steps_and_ras = Vec::with_capacity(req.prepare_steps().len()); // matched to prep_steps + let mut report_aggregations_to_write = Vec::with_capacity(report_aggregation_count); + for prep_step in req.prepare_steps() { + let report_aggregation = loop { + let report_agg = report_aggregations_iter.next().ok_or_else(|| { + datastore::Error::User( + Error::InvalidMessage( + Some(*task.id()), + "leader sent unexpected, duplicate, or out-of-order prepare steps", + ) + .into(), + ) + })?; + if report_agg.report_id() != prep_step.report_id() { // This report was omitted by the leader because of a prior failure. Note that // the report was dropped (if it's not already in an error state) and continue. if matches!( - report_aggregation.state(), + report_agg.state(), ReportAggregationState::WaitingHelper { .. } ) { report_aggregations_to_write.push(WritableReportAggregation::new( - report_aggregation + report_agg + .clone() .with_state(ReportAggregationState::Failed { prepare_error: PrepareError::ReportDropped, }) @@ -221,39 +91,189 @@ impl VdafOps { None, )); } + continue; + } + break report_agg; + }; + + let prep_state = match report_aggregation.state() { + ReportAggregationState::WaitingHelper { prepare_state } => prepare_state.clone(), + ReportAggregationState::WaitingLeader { .. 
} => { + return Err(datastore::Error::User( + Error::Internal( + "helper encountered unexpected ReportAggregationState::WaitingLeader" + .to_string(), + ) + .into(), + )) + } + _ => { + return Err(datastore::Error::User( + Error::InvalidMessage( + Some(*task.id()), + "leader sent prepare step for non-WAITING report aggregation", + ) + .into(), + )); } - Ok(report_aggregations_to_write) + }; + + prep_steps_and_ras.push((prep_step.clone(), report_aggregation, prep_state)); + } + + for report_aggregation in report_aggregations_iter { + // This report was omitted by the leader because of a prior failure. Note that + // the report was dropped (if it's not already in an error state) and continue. + if matches!( + report_aggregation.state(), + ReportAggregationState::WaitingHelper { .. } + ) { + report_aggregations_to_write.push(WritableReportAggregation::new( + report_aggregation + .clone() + .with_state(ReportAggregationState::Failed { + prepare_error: PrepareError::ReportDropped, + }) + .with_last_prep_resp(None), + None, + )); + } + } + + // Compute the next aggregation step. + // + // Shutdown on cancellation: if this request is cancelled, the `receiver` will be dropped. + // This will cause any attempts to send on `sender` to return a `SendError`, which will be + // returned from the function passed to `try_for_each_with`; `try_for_each_with` will + // terminate early on receiving an error. + let (sender, mut receiver) = mpsc::unbounded_channel(); + let aggregation_job = Arc::new(aggregation_job); + let producer_task = tokio::task::spawn_blocking({ + let parent_span = Span::current(); + let metrics = metrics.clone(); + let task = Arc::clone(&task); + let vdaf = Arc::clone(&vdaf); + let aggregation_job = Arc::clone(&aggregation_job); + + move || { + let span = info_span!(parent: parent_span, "step_aggregation_job threadpool task"); + + prep_steps_and_ras.into_par_iter().try_for_each_with( + (sender, span), + |(sender, span), (prep_step, report_aggregation, prep_state)| { + let _entered = span.enter(); + + let (report_aggregation_state, prepare_step_result, output_share) = + trace_span!("VDAF preparation (helper continuation)") + .in_scope(|| { + // Continue with the incoming message. + vdaf.helper_continued( + PingPongState::Continued(prep_state.clone()), + aggregation_job.aggregation_parameter(), + prep_step.message(), + ) + .and_then( + |continued_value| { + match continued_value { + PingPongContinuedValue::WithMessage { + transition, + } => { + let (new_state, message) = + transition.evaluate(vdaf.as_ref())?; + let (report_aggregation_state, output_share) = + match new_state { + // Helper did not finish. Store the new + // state and await the next message from + // the Leader to advance preparation. + PingPongState::Continued(prepare_state) => ( + ReportAggregationState::WaitingHelper { + prepare_state, + }, + None, + ), + // Helper finished. Commit the output + // share. 
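                                                            // (The output share is threaded
                                                            // out of the match so it can be
                                                            // persisted with the new state.)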
+ PingPongState::Finished(output_share) => ( + ReportAggregationState::Finished, + Some(output_share), + ), + }; + + Ok(( + report_aggregation_state, + // Helper has an outgoing message for Leader + PrepareStepResult::Continue { message }, + output_share, + )) + } + + PingPongContinuedValue::FinishedNoMessage { + output_share, + } => Ok(( + ReportAggregationState::Finished, + PrepareStepResult::Finished, + Some(output_share), + )), + } + }, + ) + }) + .map_err(|error| { + handle_ping_pong_error( + task.id(), + Role::Leader, + prep_step.report_id(), + error, + &metrics.aggregate_step_failure_counter, + ) + }) + .unwrap_or_else(|prepare_error| { + ( + ReportAggregationState::Failed { prepare_error }, + PrepareStepResult::Reject(prepare_error), + None, + ) + }); + + sender.send(WritableReportAggregation::new( + report_aggregation + .clone() + .with_state(report_aggregation_state) + .with_last_prep_resp(Some(PrepareResp::new( + *prep_step.report_id(), + prepare_step_result, + ))), + output_share, + )) + }, + ) } - }; - let (sender, receiver) = oneshot::channel(); - rayon::spawn(|| { - // Do nothing if the result cannot be sent back. This would happen if the initiating - // future is cancelled. - // - // We need to catch panics here because rayon's default threadpool panic handler will - // abort, and it would be preferable to propagate the panic in the original future to - // avoid behavior changes. - // - // Using `AssertUnwindSafe` is OK here, because the only interior mutability we make use - // of is OTel instruments, and those can't be put into inconsistent states merely by - // incrementing counters. Using `AssertUnwindSafe` is easier than adding infectious - // `UnwindSafe` trait bounds to various VDAF associated types throughout the codebase. - let _ = sender.send(catch_unwind(AssertUnwindSafe(prep_closure))); }); - let report_aggregations_to_write = receiver + + while receiver + .recv_many(&mut report_aggregations_to_write, 10) .await - .map_err(|_recv_error: RecvError| { - datastore::Error::User( - Error::Internal("threadpool failed to send VDAF preparation result".into()) - .into(), - ) - })? - .unwrap_or_else(|panic_cause: Box| resume_unwind(panic_cause))?; + > 0 + {} + + // Await the producer task to resume any panics that may have occurred, and to ensure we can + // unwrap the aggregation job's Arc in a few lines. The only other errors that can occur + // are: a `JoinError` indicating cancellation, which is impossible because we do not cancel + // the task; and a `SendError`, which can only happen if this future is cancelled (in which + // case we will not run this code at all). + let _ = producer_task.await.map_err(|join_error| { + if let Ok(reason) = join_error.try_into_panic() { + panic::resume_unwind(reason); + } + }); + assert_eq!(report_aggregations_to_write.len(), report_aggregation_count); // Write accumulated aggregation values back to the datastore; this will mark any reports // that can't be aggregated because the batch is collected with error BatchCollected. let aggregation_job_id = *aggregation_job.id(); - let aggregation_job = aggregation_job + // TODO: Use Arc::unwrap_or_clone() once the MSRV is at least 1.76.0. 
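        // As above: the producer task held the only other clone of this `Arc` and has
        // completed, so `try_unwrap` should succeed; the clone path is defensive only.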
+ let aggregation_job = Arc::try_unwrap(aggregation_job) + .unwrap_or_else(|arc| arc.as_ref().clone()) .with_step(request_step) // Advance the job to the leader's step .with_last_request_hash(request_hash); let mut aggregation_job_writer = diff --git a/aggregator/src/aggregator/aggregation_job_writer.rs b/aggregator/src/aggregator/aggregation_job_writer.rs index 059919a0a..d322fbcf4 100644 --- a/aggregator/src/aggregator/aggregation_job_writer.rs +++ b/aggregator/src/aggregator/aggregation_job_writer.rs @@ -120,10 +120,12 @@ where pub fn put( &mut self, aggregation_job: AggregationJob, - report_aggregations: Vec, + mut report_aggregations: Vec, ) -> Result<(), Error> { self.update_aggregation_parameter(aggregation_job.aggregation_parameter()); + report_aggregations.sort_unstable_by_key(RA::ord); + // Compute batch identifiers first, since computing the batch identifier is fallible and // it's nicer to not have to unwind state modifications if we encounter an error. let batch_identifiers = report_aggregations @@ -817,6 +819,9 @@ impl> pub trait ReportAggregationUpdate>: Clone + Send + Sync { + /// Returns the order of this report aggregation in its aggregation job. + fn ord(&self) -> u64; + /// Returns the report ID associated with this report aggregation. fn report_id(&self) -> &ReportId; @@ -862,6 +867,10 @@ where A::PrepareMessage: Send + Sync, A::PublicShare: Send + Sync, { + fn ord(&self) -> u64 { + self.report_aggregation.ord() + } + fn report_id(&self) -> &ReportId { self.report_aggregation.report_id() } @@ -922,6 +931,10 @@ impl ReportAggregationUpdate for Report where A: vdaf::Aggregator, { + fn ord(&self) -> u64 { + self.ord() + } + fn report_id(&self) -> &ReportId { self.report_id() } diff --git a/aggregator/src/aggregator/error.rs b/aggregator/src/aggregator/error.rs index 4f1ec94c0..f99feb46c 100644 --- a/aggregator/src/aggregator/error.rs +++ b/aggregator/src/aggregator/error.rs @@ -143,6 +143,8 @@ pub enum Error { /// An error occurred when trying to ensure differential privacy. #[error("differential privacy error: {0}")] DifferentialPrivacy(VdafError), + #[error("client disconnected")] + ClientDisconnected, } /// A newtype around `Arc`. 
This is needed to host a customized implementation of @@ -309,6 +311,7 @@ impl Error { Error::BadRequest(_) => "bad_request", Error::InvalidTask(_, _) => "invalid_task", Error::DifferentialPrivacy(_) => "differential_privacy", + Error::ClientDisconnected => "client_disconnected", } } } diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index 905fb4dae..2d32e1f5f 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -153,6 +153,7 @@ async fn run_error_handler(error: &Error, mut conn: Conn) -> Conn { &ProblemDocument::new_dap(DapProblemType::InvalidTask).with_task_id(task_id), ), Error::DifferentialPrivacy(_) => conn.with_status(Status::InternalServerError), + Error::ClientDisconnected => conn, }; if matches!(conn.status(), Some(status) if status.is_server_error()) { @@ -373,9 +374,12 @@ async fn hpke_config( ) -> Result<(), Error> { let query = serde_urlencoded::from_str::(conn.querystring()) .map_err(|err| Error::BadRequest(format!("couldn't parse query string: {err}")))?; - let (encoded_hpke_config_list, signature) = aggregator - .handle_hpke_config(query.task_id.as_ref().map(AsRef::as_ref)) - .await?; + let (encoded_hpke_config_list, signature) = conn + .cancel_on_disconnect( + aggregator.handle_hpke_config(query.task_id.as_ref().map(AsRef::as_ref)), + ) + .await + .ok_or(Error::ClientDisconnected)??; // Handle CORS, if the request header is present. if let Some(origin) = conn.request_headers().get(KnownHeaderName::Origin) { @@ -425,7 +429,9 @@ async fn upload( validate_content_type(conn, Report::MEDIA_TYPE).map_err(Arc::new)?; let task_id = parse_task_id(conn).map_err(Arc::new)?; - aggregator.handle_upload(&task_id, &body).await?; + conn.cancel_on_disconnect(aggregator.handle_upload(&task_id, &body)) + .await + .ok_or(Arc::new(Error::ClientDisconnected))??; // Handle CORS, if the request header is present. 
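    // (Browser clients submit reports cross-origin; when an `Origin` header is present,
    // it is reflected back in the CORS response headers.)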
if let Some(origin) = conn.request_headers().get(KnownHeaderName::Origin) { @@ -471,15 +477,16 @@ async fn aggregation_jobs_put( let aggregation_job_id = parse_aggregation_job_id(conn)?; let auth_token = parse_auth_token(&task_id, conn)?; let taskprov_task_config = parse_taskprov_header(&aggregator, &task_id, conn)?; - let response = aggregator - .handle_aggregate_init( + let response = conn + .cancel_on_disconnect(aggregator.handle_aggregate_init( &task_id, &aggregation_job_id, &body, auth_token, taskprov_task_config.as_ref(), - ) - .await?; + )) + .await + .ok_or(Error::ClientDisconnected)??; Ok(EncodedBody::new(response, AggregationJobResp::MEDIA_TYPE)) } @@ -495,15 +502,16 @@ async fn aggregation_jobs_post( let aggregation_job_id = parse_aggregation_job_id(conn)?; let auth_token = parse_auth_token(&task_id, conn)?; let taskprov_task_config = parse_taskprov_header(&aggregator, &task_id, conn)?; - let response = aggregator - .handle_aggregate_continue( + let response = conn + .cancel_on_disconnect(aggregator.handle_aggregate_continue( &task_id, &aggregation_job_id, &body, auth_token, taskprov_task_config.as_ref(), - ) - .await?; + )) + .await + .ok_or(Error::ClientDisconnected)??; Ok(EncodedBody::new(response, AggregationJobResp::MEDIA_TYPE)) } @@ -518,14 +526,14 @@ async fn aggregation_jobs_delete( let auth_token = parse_auth_token(&task_id, conn)?; let taskprov_task_config = parse_taskprov_header(&aggregator, &task_id, conn)?; - aggregator - .handle_aggregate_delete( - &task_id, - &aggregation_job_id, - auth_token, - taskprov_task_config.as_ref(), - ) - .await?; + conn.cancel_on_disconnect(aggregator.handle_aggregate_delete( + &task_id, + &aggregation_job_id, + auth_token, + taskprov_task_config.as_ref(), + )) + .await + .ok_or(Error::ClientDisconnected)??; Ok(Status::NoContent) } @@ -539,9 +547,14 @@ async fn collection_jobs_put( let task_id = parse_task_id(conn)?; let collection_job_id = parse_collection_job_id(conn)?; let auth_token = parse_auth_token(&task_id, conn)?; - aggregator - .handle_create_collection_job(&task_id, &collection_job_id, &body, auth_token) - .await?; + conn.cancel_on_disconnect(aggregator.handle_create_collection_job( + &task_id, + &collection_job_id, + &body, + auth_token, + )) + .await + .ok_or(Error::ClientDisconnected)??; Ok(Status::Created) } @@ -554,9 +567,14 @@ async fn collection_jobs_post( let task_id = parse_task_id(conn)?; let collection_job_id = parse_collection_job_id(conn)?; let auth_token = parse_auth_token(&task_id, conn)?; - let response_opt = aggregator - .handle_get_collection_job(&task_id, &collection_job_id, auth_token) - .await?; + let response_opt = conn + .cancel_on_disconnect(aggregator.handle_get_collection_job( + &task_id, + &collection_job_id, + auth_token, + )) + .await + .ok_or(Error::ClientDisconnected)??; match response_opt { Some(response_bytes) => { conn.response_headers_mut().insert( @@ -579,9 +597,13 @@ async fn collection_jobs_delete( let task_id = parse_task_id(conn)?; let collection_job_id = parse_collection_job_id(conn)?; let auth_token = parse_auth_token(&task_id, conn)?; - aggregator - .handle_delete_collection_job(&task_id, &collection_job_id, auth_token) - .await?; + conn.cancel_on_disconnect(aggregator.handle_delete_collection_job( + &task_id, + &collection_job_id, + auth_token, + )) + .await + .ok_or(Error::ClientDisconnected)??; Ok(Status::NoContent) } @@ -595,9 +617,15 @@ async fn aggregate_shares( let task_id = parse_task_id(conn)?; let auth_token = parse_auth_token(&task_id, conn)?; let taskprov_task_config = 
parse_taskprov_header(&aggregator, &task_id, conn)?; - let share = aggregator - .handle_aggregate_share(&task_id, &body, auth_token, taskprov_task_config.as_ref()) - .await?; + let share = conn + .cancel_on_disconnect(aggregator.handle_aggregate_share( + &task_id, + &body, + auth_token, + taskprov_task_config.as_ref(), + )) + .await + .ok_or(Error::ClientDisconnected)??; Ok(EncodedBody::new(share, AggregateShare::MEDIA_TYPE)) }
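
The pattern introduced above — a rayon parallel iterator running inside `tokio::task::spawn_blocking`, streaming its results over an unbounded mpsc channel, with panics propagated through `JoinError::try_into_panic` — is the same in both aggregation handlers. A minimal, self-contained sketch of it follows; it is illustrative only (the `process_all` and `expensive_step` names and the `recv_many` batch size of 10 are stand-ins, not part of this change):

    use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
    use std::panic;
    use tokio::sync::mpsc;

    async fn process_all(inputs: Vec<u64>) -> Vec<u64> {
        let expected = inputs.len();
        let (sender, mut receiver) = mpsc::unbounded_channel();

        // CPU-bound work runs on rayon's threadpool, inside a blocking task so
        // it doesn't stall the async executor.
        let producer = tokio::task::spawn_blocking(move || {
            inputs
                .into_par_iter()
                .try_for_each_with(sender, |sender, input| {
                    // If the async side is cancelled, the receiver is dropped,
                    // `send` returns `SendError`, and `try_for_each_with`
                    // terminates early.
                    sender.send(expensive_step(input))
                })
        });

        // Drain results in batches to amortize task wakeups. `recv_many`
        // returns 0 once every sender clone has been dropped.
        let mut outputs = Vec::with_capacity(expected);
        while receiver.recv_many(&mut outputs, 10).await > 0 {}

        // Resume any panic raised on the rayon workers; a `SendError` from the
        // producer only ever means this future was cancelled.
        if let Err(join_error) = producer.await {
            if let Ok(reason) = join_error.try_into_panic() {
                panic::resume_unwind(reason);
            }
        }
        assert_eq!(outputs.len(), expected);
        outputs
    }

    // Illustrative stand-in for the per-report VDAF preparation work.
    fn expensive_step(input: u64) -> u64 {
        input.wrapping_mul(2_654_435_761)
    }

    #[tokio::main]
    async fn main() {
        let outputs = process_all((0..100).collect()).await;
        assert_eq!(outputs.len(), 100);
    }

The unbounded channel is safe here because the number of in-flight items is bounded by the size of the request being handled. Note that completion order is not preserved across rayon workers, which is why `aggregation_job_writer` now sorts report aggregations by `ord` before writing.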