Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize COUNT( DISTINCT ...) for strings (up to 9x faster) #8849

Merged
merged 37 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
9c44d04
chkp
jayzhan211 Jan 10, 2024
6cb8bbe
chkp
jayzhan211 Jan 13, 2024
9d662a7
draft
jayzhan211 Jan 13, 2024
1744cb3
iter done
jayzhan211 Jan 13, 2024
e3b0568
short string test
jayzhan211 Jan 13, 2024
12cf50c
add test
jayzhan211 Jan 13, 2024
4f9a3f0
remove unused
jayzhan211 Jan 13, 2024
626b1cb
to_string directly
jayzhan211 Jan 13, 2024
2e80cb7
rewrite evaluate
jayzhan211 Jan 13, 2024
d2d1d6d
return Vec<String>
jayzhan211 Jan 13, 2024
ebb8726
fmt
jayzhan211 Jan 13, 2024
98a9cd1
add more queries
jayzhan211 Jan 16, 2024
07831fa
add group by query and rewrite evalute with state()
jayzhan211 Jan 17, 2024
62c8084
move evaluate back
jayzhan211 Jan 17, 2024
e3b65c8
upd test
jayzhan211 Jan 17, 2024
3f0e9a9
add row sort
jayzhan211 Jan 17, 2024
4bc483a
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb Jan 20, 2024
0475687
Update benchmarks/queries/clickbench/README.md
alamb Jan 20, 2024
a764e99
Rework set to avoid copies
alamb Jan 20, 2024
bde49c6
Merge branch 'bytes-distinctcount' of github.com:jayzhan211/arrow-dat…
alamb Jan 20, 2024
a101b62
Simplify offset construction
alamb Jan 20, 2024
0f2fa02
fmt
alamb Jan 20, 2024
489e130
Improve comments
alamb Jan 21, 2024
c39988a
Improve comments
alamb Jan 21, 2024
0e33b12
add fuzz test
jayzhan211 Jan 22, 2024
b3bcc68
Add support for LargeStringArray
alamb Jan 22, 2024
d7efcf6
Merge branch 'bytes-distinctcount' of github.com:jayzhan211/arrow-dat…
alamb Jan 22, 2024
a80b39c
refine fuzz test
alamb Jan 22, 2024
3e9289a
Add tests for size accounting
alamb Jan 22, 2024
7b9d067
Split into new module
alamb Jan 22, 2024
d405744
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb Jan 24, 2024
3a6a066
Remove use of Mutex
alamb Jan 24, 2024
f177aed
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb Jan 25, 2024
8640907
revert changes
alamb Jan 25, 2024
214ba5b
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb Jan 27, 2024
1e10b9c
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb Jan 28, 2024
f5e268d
Use reference rather than owned ArrayRef
alamb Jan 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions benchmarks/queries/clickbench/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,25 @@ SELECT
FROM hits;
```

### Q1
alamb marked this conversation as resolved.
Show resolved Hide resolved
Models initial data exploration, to understand some statistics of the data.
Tests distinct count for strings; all three columns contain small strings (length either 1 or 2).
alamb marked this conversation as resolved.
Show resolved Hide resolved

```sql
SELECT
COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage")
FROM hits;
```

### Q2
Models initial data exploration, to understand some statistics of the data.
Extends Q1 with a `GROUP BY` clause.

```sql
SELECT
"BrowserCountry", COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage")
FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10;
```



4 changes: 3 additions & 1 deletion benchmarks/queries/clickbench/extended.sql
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits;
SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits;
SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits;
SELECT "BrowserCountry", COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10;
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion/physical-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
chrono = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed to use RawTableAlloc trait

datafusion-expr = { workspace = true }
half = { version = "2.1", default-features = false }
hashbrown = { version = "0.14", features = ["raw"] }
Expand Down
217 changes: 214 additions & 3 deletions datafusion/physical-expr/src/aggregate/count_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ use arrow_array::types::{
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::PrimitiveArray;
use arrow_array::{PrimitiveArray, StringArray};
use arrow_buffer::{BufferBuilder, MutableBuffer, OffsetBuffer};

use std::any::Any;
use std::cmp::Eq;
use std::fmt::Debug;
use std::hash::Hash;
use std::mem;
use std::sync::Arc;

use ahash::RandomState;
Expand All @@ -38,9 +40,10 @@ use std::collections::HashSet;
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{Result, ScalarValue};
use datafusion_execution::memory_pool::proxy::RawTableAllocExt;
use datafusion_expr::Accumulator;

type DistinctScalarValues = ScalarValue;
Expand Down Expand Up @@ -152,6 +155,8 @@ impl AggregateExpr for DistinctCount {
Float32 => float_distinct_count_accumulator!(Float32Type),
Float64 => float_distinct_count_accumulator!(Float64Type),

Utf8 => Ok(Box::new(StringDistinctCountAccumulator::new())),
alamb marked this conversation as resolved.
Show resolved Hide resolved

_ => Ok(Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: self.state_data_type.clone(),
Expand Down Expand Up @@ -244,7 +249,7 @@ impl Accumulator for DistinctCountAccumulator {
assert_eq!(states.len(), 1, "array_agg states must be singleton!");
let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
for scalars in scalar_vec.into_iter() {
self.values.extend(scalars)
self.values.extend(scalars);
}
Ok(())
}
Expand Down Expand Up @@ -438,6 +443,212 @@ where
}
}

/// Specialized `COUNT(DISTINCT)` accumulator for `Utf8` columns, backed by a
/// short-string-optimized hash set rather than a `HashSet<ScalarValue>`.
#[derive(Debug)]
struct StringDistinctCountAccumulator(SSOStringHashSet);

impl StringDistinctCountAccumulator {
    /// Creates an accumulator with an empty set of seen strings.
    fn new() -> Self {
        Self(SSOStringHashSet::default())
    }
}

impl Accumulator for StringDistinctCountAccumulator {
    /// Serializes the set of distinct strings seen so far as a single-element
    /// `List` scalar, suitable for `merge_batch` on another accumulator.
    fn state(&self) -> Result<Vec<ScalarValue>> {
        let distinct_strings = self.0.state();
        let list_array = array_into_list_array(Arc::new(distinct_strings));
        Ok(vec![ScalarValue::List(Arc::new(list_array))])
    }

    /// Inserts every non-null string from the first input column into the set.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if let Some(array) = values.first() {
            let strings = as_string_array(array)?;
            for value in strings.iter().flatten() {
                self.0.insert(value);
            }
        }
        Ok(())
    }

    /// Merges the serialized state produced by `state()` on other partitions.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        assert_eq!(
            states.len(),
            1,
            "count_distinct states must be single array"
        );

        let lists = as_list_array(&states[0])?;
        for maybe_list in lists.iter() {
            if let Some(list) = maybe_list {
                for value in as_string_array(&list)?.iter().flatten() {
                    self.0.insert(value);
                }
            }
        }
        Ok(())
    }

    /// The distinct count is simply the number of entries in the set.
    fn evaluate(&self) -> Result<ScalarValue> {
        let count = self.0.len() as i64;
        Ok(ScalarValue::Int64(Some(count)))
    }

    /// Memory used: the accumulator struct itself plus the set's accounting.
    fn size(&self) -> usize {
        std::mem::size_of_val(self) + self.0.size()
    }
}

/// Maximum length (in bytes) of a string that is stored inline in
/// `offset_or_inline` instead of in the shared long-string buffer.
const SHORT_STRING_LEN: usize = std::mem::size_of::<usize>();

// Allow Copy since the fields are all native types
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct SSOStringHeader {
    /// hash of the string value (used when resizing table)
    hash: u64,
    /// length of the string
    len: usize,
    /// short strings are stored inline, long strings are stored in the buffer
    offset_or_inline: usize,
}

impl SSOStringHeader {
    /// Reconstructs the original string this header represents.
    ///
    /// `buffer` is the shared long-string buffer; it is only consulted for
    /// strings longer than `SHORT_STRING_LEN`.
    fn evaluate(&self, buffer: &[u8]) -> String {
        if self.len <= SHORT_STRING_LEN {
            // Short strings were packed big-endian into the usize by
            // `insert` (`acc << 8 | byte`), so the original bytes are the
            // trailing `len` bytes of the big-endian representation.
            // (Previously this returned `offset_or_inline.to_string()`,
            // i.e. the decimal rendering of the packed integer — wrong.)
            let be_bytes = self.offset_or_inline.to_be_bytes();
            let bytes = &be_bytes[SHORT_STRING_LEN - self.len..];
            // SAFETY: the bytes were packed from a valid &str, and we
            // recover exactly the bytes that were packed.
            unsafe { std::str::from_utf8_unchecked(bytes) }.to_string()
        } else {
            let offset = self.offset_or_inline;
            // SAFETY: buffer is only appended to, and we correctly inserted values
            unsafe {
                std::str::from_utf8_unchecked(
                    buffer.get_unchecked(offset..offset + self.len),
                )
            }
            .to_string()
        }
    }
}

/// Short String Optimized HashSet for strings.
///
/// Equivalent to `HashSet<String>` but with better memory usage: strings of
/// at most `SHORT_STRING_LEN` bytes are packed inline into the header, while
/// longer strings are stored once in a shared append-only buffer.
#[derive(Default)]
struct SSOStringHashSet {
    /// Core of the HashSet, it stores both the short and long string headers
    header_set: HashSet<SSOStringHeader>,
    /// Used to check if the long string already exists
    long_string_map: hashbrown::raw::RawTable<SSOStringHeader>,
    /// Total size of the map in bytes
    map_size: usize,
    /// Buffer containing all long strings
    buffer: BufferBuilder<u8>,
    /// The random state used to generate hashes
    state: RandomState,
    /// Used for capacity calculation; sum of the lengths of all strings
    /// passed to `insert` (duplicates included)
    size_hint: usize,
}

impl SSOStringHashSet {
    /// Creates an empty set.
    fn new() -> Self {
        Self::default()
    }

    /// Inserts `value` into the set if not already present.
    ///
    /// NOTE(review): short and long strings could plausibly share a single
    /// RawTable keyed on the hashed bytes; follow-up work found that faster.
    fn insert(&mut self, value: &str) {
        let value_len = value.len();
        self.size_hint += value_len;
        let value_bytes = value.as_bytes();

        if value_len <= SHORT_STRING_LEN {
            // Pack the bytes big-endian into a usize; distinct short strings
            // of the same length map to distinct packed values, so the
            // HashSet deduplicates them by value with no buffer lookup.
            let inline = value_bytes
                .iter()
                .fold(0usize, |acc, &x| acc << 8 | x as usize);
            let short_string_header = SSOStringHeader {
                hash: 0, // no need for short string cases
                len: value_len,
                offset_or_inline: inline,
            };
            self.header_set.insert(short_string_header);
        } else {
            let hash = self.state.hash_one(value_bytes);

            // Probe the long-string table for an entry whose hash matches
            // AND whose buffered bytes equal `value_bytes`.
            let entry = self.long_string_map.get_mut(hash, |header| {
                // if hash matches, check if the bytes match
                let offset = header.offset_or_inline;
                let len = header.len;

                // SAFETY: buffer is only appended to, and we correctly inserted values
                let existing_value =
                    unsafe { self.buffer.as_slice().get_unchecked(offset..offset + len) };

                value_bytes == existing_value
            });

            // Only append to the buffer and both tables when genuinely new.
            if entry.is_none() {
                let offset = self.buffer.len();
                self.buffer.append_slice(value_bytes);
                let header = SSOStringHeader {
                    hash,
                    len: value_len,
                    offset_or_inline: offset,
                };
                // insert_accounted tracks table allocations in `map_size`
                self.long_string_map.insert_accounted(
                    header,
                    |header| header.hash,
                    &mut self.map_size,
                );
                self.header_set.insert(header);
            }
        }
    }

    /// Returns a StringArray with the current contents of the set.
    fn state(&self) -> StringArray {
        // One offset per distinct string plus the leading 0.
        // (Previously this reserved `size_hint + 1` — the total *byte* count
        // of all inserted strings, duplicates included — which is the wrong
        // unit for the offsets vector and could badly over-allocate.)
        let mut offsets = Vec::with_capacity(self.header_set.len() + 1);
        offsets.push(0);

        let mut values = MutableBuffer::new(0);
        let buffer = self.buffer.as_slice();

        for header in self.header_set.iter() {
            let s = header.evaluate(buffer);
            values.extend_from_slice(s.as_bytes());
            // i32 offsets: assumes total string data stays under 2 GiB,
            // consistent with StringArray (Utf8) semantics
            offsets.push(values.len() as i32);
        }

        let value_offsets = OffsetBuffer::<i32>::new(offsets.into());
        StringArray::new(value_offsets, values.into(), None)
    }

    /// Number of distinct strings in the set.
    fn len(&self) -> usize {
        self.header_set.len()
    }

    /// Approximate memory use in bytes: headers + accounted map allocations
    /// + long-string buffer. NOTE: counts only header payload, not the
    /// HashSet's own capacity overhead, so this is a lower bound.
    fn size(&self) -> usize {
        self.header_set.len() * mem::size_of::<SSOStringHeader>()
            + self.map_size
            + self.buffer.len()
    }
}

// Manual Debug impl: RawTable does not implement Debug, so the struct
// cannot simply derive it; `long_string_map` is omitted from the output.
impl Debug for SSOStringHashSet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SSOStringHashSet")
            .field("header_set", &self.header_set)
            // TODO: Print long_string_map
            .field("map_size", &self.map_size)
            .field("buffer", &self.buffer)
            .field("state", &self.state)
            .finish()
    }
}
#[cfg(test)]
mod tests {
use crate::expressions::NoOp;
Expand Down
30 changes: 30 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -3260,3 +3260,33 @@ query I
select count(*) from (select count(*) a, count(*) b from (select 1));
----
1

# Distinct Count for string

# UTF8 string matters for string to &[u8] conversion, add it to prevent regression
statement ok
create table distinct_count_string_table as values
(1, 'a', 'longstringtest_a', '台灣'),
(2, 'b', 'longstringtest_b1', '日本'),
(2, 'b', 'longstringtest_b2', '中國'),
(3, 'c', 'longstringtest_c1', '美國'),
(3, 'c', 'longstringtest_c2', '歐洲'),
(3, 'c', 'longstringtest_c3', '韓國')
;

# run through update_batch
query IIII
select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_string_table;
----
3 3 6 6

# run through merge_batch
query IIII rowsort
select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_string_table group by column1;
----
1 1 1 1
1 1 2 2
1 1 3 3

statement ok
drop table distinct_count_string_table;
3 changes: 3 additions & 0 deletions datafusion/sqllogictest/test_files/clickbench.slt
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,6 @@ SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hit
query PI
SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000;
----

statement ok
drop table hits;