Skip to content

Commit

Permalink
Merge pull request duckdb#12055 from lnkuiper/aggr_fixes
Browse files Browse the repository at this point in the history
Aggregation bugfixes
  • Loading branch information
Mytherin authored May 15, 2024
2 parents 3358b06 + 2b8f43b commit 53269a3
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
13 changes: 11 additions & 2 deletions src/execution/aggregate_hashtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,8 @@ idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, V

idx_t new_group_count = 0;
idx_t remaining_entries = groups.size();
while (remaining_entries > 0) {
idx_t iteration_count;
for (iteration_count = 0; remaining_entries > 0 && iteration_count < capacity; iteration_count++) {
idx_t new_entry_count = 0;
idx_t need_compare_count = 0;
idx_t no_match_count = 0;
Expand All @@ -370,7 +371,9 @@ idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, V
const auto index = sel_vector->get_index(i);
const auto &salt = hash_salts[index];
auto &ht_offset = ht_offsets[index];
while (true) {

idx_t inner_iteration_count;
for (inner_iteration_count = 0; inner_iteration_count < capacity; inner_iteration_count++) {
auto &entry = entries[ht_offset];
if (entry.IsOccupied()) { // Cell is occupied: Compare salts
if (entry.GetSalt() == salt) {
Expand All @@ -393,6 +396,9 @@ idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, V
break;
}
}
if (inner_iteration_count == capacity) {
throw InternalException("Maximum inner iteration count reached in GroupedAggregateHashTable");
}
}

if (new_entry_count != 0) {
Expand Down Expand Up @@ -440,6 +446,9 @@ idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, V
sel_vector = &state.no_match_vector;
remaining_entries = no_match_count;
}
if (iteration_count == capacity) {
throw InternalException("Maximum outer iteration count reached in GroupedAggregateHashTable");
}

count += new_group_count;
return new_group_count;
Expand Down
9 changes: 6 additions & 3 deletions src/execution/radix_partitioned_hashtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ class RadixHTGlobalSinkState : public GlobalSinkState {
bool finalized;
//! Whether we are doing an external aggregation
atomic<bool> external;
//! Whether the aggregation is single-threaded
//! Threads that have called Sink
atomic<idx_t> active_threads;
//! Number of threads (from TaskScheduler)
const idx_t number_of_threads;
//! If any thread has called combine
atomic<bool> any_combined;
Expand All @@ -192,7 +194,7 @@ class RadixHTGlobalSinkState : public GlobalSinkState {

RadixHTGlobalSinkState::RadixHTGlobalSinkState(ClientContext &context_p, const RadixPartitionedHashTable &radix_ht_p)
: context(context_p), temporary_memory_state(TemporaryMemoryManager::Get(context).Register(context)),
radix_ht(radix_ht_p), config(context, *this), finalized(false), external(false),
radix_ht(radix_ht_p), config(context, *this), finalized(false), external(false), active_threads(0),
number_of_threads(NumericCast<idx_t>(TaskScheduler::GetScheduler(context).NumberOfThreads())),
any_combined(false), finalize_done(0), scan_pin_properties(TupleDataPinProperties::DESTROY_AFTER_DONE),
count_before_combining(0), max_partition_size(0) {
Expand Down Expand Up @@ -438,6 +440,7 @@ void RadixPartitionedHashTable::Sink(ExecutionContext &context, DataChunk &chunk
auto &lstate = input.local_state.Cast<RadixHTLocalSinkState>();
if (!lstate.ht) {
lstate.ht = CreateHT(context.client, gstate.config.sink_capacity, gstate.config.GetRadixBits());
gstate.active_threads++;
}

auto &group_chunk = lstate.group_chunk;
Expand Down Expand Up @@ -512,7 +515,7 @@ void RadixPartitionedHashTable::Finalize(ClientContext &context, GlobalSinkState
gstate.count_before_combining = uncombined_data.Count();

// If true there is no need to combine, it was all done by a single thread in a single HT
const auto single_ht = !gstate.external && gstate.number_of_threads == 1;
const auto single_ht = !gstate.external && gstate.active_threads == 1 && gstate.number_of_threads == 1;

auto &uncombined_partition_data = uncombined_data.GetPartitions();
const auto n_partitions = uncombined_partition_data.size();
Expand Down
13 changes: 13 additions & 0 deletions test/sql/aggregate/distinct/grouped/multiple_grouping_sets.test
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,16 @@ select sum(distinct value), count(*), course, type
10 7 NULL NULL
25 2 Math NULL
30 2 NULL Bachelor

# re-do the first query with one thread (internal issue 2046)
statement ok
set threads=1

query IIII
select course, type, count(*), sum(distinct value) from students group by course, type order by all;
----
CS NULL 2 -5
CS Bachelor 2 30
CS PhD 1 -20
Math NULL 1 15
Math Masters 1 10

0 comments on commit 53269a3

Please sign in to comment.