Skip to content

Commit

Permalink
fix: Count batch size via batch_aggregations rather than report_aggregations (#3345)
Browse files Browse the repository at this point in the history

* fix: Count batch size via batch_aggregations rather than report_aggregations

* Better test

* Update aggregator_core/src/datastore.rs

Co-authored-by: Brandon Pitman <[email protected]>

* Update aggregator_core/src/datastore.rs

Co-authored-by: Brandon Pitman <[email protected]>

---------

Co-authored-by: Brandon Pitman <[email protected]>
  • Loading branch information
inahga and branlwyd authored Aug 7, 2024
1 parent 3258350 commit a9eb6e2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 17 deletions.
37 changes: 21 additions & 16 deletions aggregator_core/src/datastore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4444,7 +4444,8 @@ ON CONFLICT(task_id, batch_identifier, aggregation_param) DO UPDATE
// Note that this ignores aggregation parameter, as `outstanding_batches` does not need to
// worry about aggregation parameters.
//
// TODO(#225): reevaluate whether we can ignore aggregation parameter here once we have experience with VDAFs requiring multiple aggregations per batch.
// TODO(#225): reevaluate whether we can ignore aggregation parameter here once we have
// experience with VDAFs requiring multiple aggregations per batch.
let stmt = self
.prepare_cached(
"-- put_outstanding_batch()
Expand Down Expand Up @@ -4570,7 +4571,7 @@ WHERE task_id = $1

try_join_all(rows.into_iter().map(|row| async move {
let batch_id = BatchId::get_decoded(row.get("batch_id"))?;
let size = self.read_batch_size(task_id, &batch_id).await?;
let size = self.read_batch_size(task_info.pkey, &batch_id).await?;
Ok(OutstandingBatch::new(*task_id, batch_id, size))
}))
.await
Expand All @@ -4583,34 +4584,38 @@ WHERE task_id = $1
// aggregations in the batch which are in a non-failure state (START/WAITING/FINISHED).
async fn read_batch_size(
&self,
task_id: &TaskId,
task_pkey: i64,
batch_id: &BatchId,
) -> Result<RangeInclusive<usize>, Error> {
// TODO(#1467): fix this to work in presence of GC.
let stmt = self
.prepare_cached(
"-- read_batch_size()
WITH batch_report_aggregation_statuses AS
(SELECT report_aggregations.state, COUNT(*) AS count FROM report_aggregations
JOIN aggregation_jobs
WITH report_aggregations_count AS (
SELECT COUNT(*) AS count FROM report_aggregations
JOIN aggregation_jobs
ON report_aggregations.aggregation_job_id = aggregation_jobs.id
WHERE aggregation_jobs.task_id = (SELECT id FROM tasks WHERE task_id = $1)
AND report_aggregations.task_id = aggregation_jobs.task_id
AND aggregation_jobs.batch_id = $2
GROUP BY report_aggregations.state)
WHERE aggregation_jobs.task_id = $1
AND report_aggregations.task_id = aggregation_jobs.task_id
AND aggregation_jobs.batch_id = $2
AND report_aggregations.state in ('START', 'WAITING')
),
batch_aggregation_count AS (
SELECT SUM(report_count) AS count FROM batch_aggregations
WHERE batch_aggregations.task_id = $1
AND batch_aggregations.batch_identifier = $2
)
SELECT
(SELECT SUM(count)::BIGINT FROM batch_report_aggregation_statuses
WHERE state IN ('FINISHED')) AS min_size,
(SELECT SUM(count)::BIGINT FROM batch_report_aggregation_statuses
WHERE state IN ('START', 'WAITING', 'FINISHED')) AS max_size",
(SELECT count FROM batch_aggregation_count)::BIGINT AS min_size,
(SELECT count FROM report_aggregations_count)::BIGINT
+ (SELECT count FROM batch_aggregation_count)::BIGINT AS max_size",
)
.await?;

let row = self
.query_one(
&stmt,
&[
/* task_id */ task_id.as_ref(),
/* task_id */ &task_pkey,
/* batch_id */ batch_id.as_ref(),
],
)
Expand Down
5 changes: 4 additions & 1 deletion aggregator_core/src/datastore/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5176,6 +5176,9 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) {
.unwrap(),
BatchAggregationState::Aggregating {
aggregate_share: Some(dummy::AggregateShare(0)),
// Let report_count be 1 without an accompanying report_aggregation in a
// terminal state. This captures the case where a FINISHED report_aggregation
// was garbage collected and no longer exists in the database.
report_count: 1,
checksum: ReportIdChecksum::default(),
aggregation_jobs_created: 4,
Expand Down Expand Up @@ -5429,7 +5432,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) {
Vec::from([OutstandingBatch::new(
task_id_2,
batch_id_2,
RangeInclusive::new(0, 1)
RangeInclusive::new(1, 2)
)])
);
assert_eq!(outstanding_batches_task_2_after_mark, Vec::new());
Expand Down

0 comments on commit a9eb6e2

Please sign in to comment.