[FEATURE] add min_hash alternate hashers #3052

Merged (37 commits, Oct 30, 2024)

Commits:
b689c63  [TEST] test_minhash_exact_values (andrewgazelka, Oct 28, 2024)
9139530  refactor hashing code (andrewgazelka, Oct 15, 2024)
b41105f  improve window code (andrewgazelka, Oct 15, 2024)
c9b7df8  clean up naming (andrewgazelka, Oct 15, 2024)
5926f8c  add utf-8 windowed tests (andrewgazelka, Oct 15, 2024)
d45f69e  add windowed words ext trait (andrewgazelka, Oct 15, 2024)
3bbe3b4  add more tests (andrewgazelka, Oct 15, 2024)
735e80c  fix bench clippy (andrewgazelka, Oct 16, 2024)
5cd1a39  add many hashers (andrewgazelka, Oct 16, 2024)
cac823c  fmt (andrewgazelka, Oct 16, 2024)
7d4b11a  add visual diagram (andrewgazelka, Oct 16, 2024)
96a578f  improve vectorization of minhash (andrewgazelka, Oct 16, 2024)
bacab6b  fix comp error (andrewgazelka, Oct 16, 2024)
2389c8c  stash (andrewgazelka, Oct 17, 2024)
89bf41d  fix ci (andrewgazelka, Oct 22, 2024)
e3f77a3  update tests (andrewgazelka, Oct 22, 2024)
3bca00d  fix some things in tests (andrewgazelka, Oct 23, 2024)
cfaf9b9  fix some things? (andrewgazelka, Oct 23, 2024)
8bb295c  fix (Oct 23, 2024)
f6c400e  fix default py features (andrewgazelka, Oct 23, 2024)
32bd648  add tests (andrewgazelka, Oct 24, 2024)
e7c9520  not deterministic??? (andrewgazelka, Oct 24, 2024)
aa53ba1  fix many tests (andrewgazelka, Oct 24, 2024)
49f25ea  more dry code in tests (andrewgazelka, Oct 24, 2024)
62e0a65  update edge case (andrewgazelka, Oct 24, 2024)
f48bea8  remove commented out code (andrewgazelka, Oct 24, 2024)
489dc81  add basic benching (andrewgazelka, Oct 24, 2024)
d47281b  improve windowed perf (andrewgazelka, Oct 24, 2024)
7e1eb2e  improve perf (andrewgazelka, Oct 24, 2024)
1abcc82  move to literal syntax (andrewgazelka, Oct 25, 2024)
9457144  remove unused deps (andrewgazelka, Oct 25, 2024)
23daae6  fix test (andrewgazelka, Oct 25, 2024)
3ac2ffc  update (andrewgazelka, Oct 25, 2024)
2d21121  fix hash name (andrewgazelka, Oct 25, 2024)
a945bba  add tests (andrewgazelka, Oct 28, 2024)
296d391  Merge remote-tracking branch 'origin/main' into andrew/hash (andrewgazelka, Oct 30, 2024)
f2d9ad4  add cory tests (andrewgazelka, Oct 30, 2024)
301 changes: 294 additions & 7 deletions Cargo.lock

Large diffs are not rendered by default.

63 changes: 38 additions & 25 deletions Cargo.toml
@@ -14,6 +14,7 @@ daft-csv = {path = "src/daft-csv", default-features = false}
daft-dsl = {path = "src/daft-dsl", default-features = false}
daft-functions = {path = "src/daft-functions", default-features = false}
daft-functions-json = {path = "src/daft-functions-json", default-features = false}
daft-hash = {path = "src/daft-hash", default-features = false}
daft-image = {path = "src/daft-image", default-features = false}
daft-io = {path = "src/daft-io", default-features = false}
daft-json = {path = "src/daft-json", default-features = false}
@@ -37,29 +38,29 @@ sysinfo = {workspace = true}
[features]
# maturin will turn this on
python = [
"dep:pyo3",
"dep:pyo3-log",
"common-daft-config/python",
"common-display/python",
"common-resource-request/python",
"common-system-info/python",
"daft-core/python",
"daft-csv/python",
"daft-dsl/python",
"daft-local-execution/python",
"daft-io/python",
"daft-functions-json/python",
"daft-functions/python",
"daft-image/python",
"daft-io/python",
"daft-json/python",
"daft-local-execution/python",
"daft-micropartition/python",
"daft-parquet/python",
"daft-plan/python",
"daft-scan/python",
"daft-scheduler/python",
"daft-stats/python",
"daft-sql/python",
"daft-stats/python",
"daft-table/python",
"daft-functions/python",
"daft-functions-json/python",
"common-daft-config/python",
"common-system-info/python",
"common-display/python",
"common-resource-request/python"
"dep:pyo3",
"dep:pyo3-log"
]

[lib]
@@ -113,35 +114,38 @@ tikv-jemallocator = {version = "0.5.4", features = [
[workspace]
members = [
"src/arrow2",
"src/parquet2",
"src/common/daft-config",
"src/common/display",
"src/common/error",
"src/common/io-config",
"src/common/treenode",
"src/common/daft-config",
"src/common/system-info",
"src/common/treenode",
"src/daft-core",
"src/daft-local-execution",
"src/daft-io",
"src/daft-image",
"src/daft-parquet",
"src/daft-csv",
"src/daft-json",
"src/daft-dsl",
"src/daft-table",
"src/daft-plan",
"src/daft-physical-plan",
"src/daft-functions",
"src/daft-functions-json",
"src/daft-hash",
"src/daft-image",
"src/daft-io",
"src/daft-json",
"src/daft-local-execution",
"src/daft-micropartition",
"src/daft-parquet",
"src/daft-physical-plan",
"src/daft-plan",
"src/daft-scan",
"src/daft-scheduler",
"src/daft-sketch",
"src/daft-functions",
"src/daft-functions-json",
"src/daft-sql",
"src/hyperloglog"
"src/daft-table",
"src/hyperloglog",
"src/parquet2"
]

[workspace.dependencies]
ahash = "0.8.11"
approx = "0.5.1"
async-compat = "0.2.3"
async-compression = {version = "0.4.12", features = [
"tokio",
@@ -154,7 +158,10 @@ bytes = "1.6.0"
chrono = "0.4.38"
chrono-tz = "0.8.4"
comfy-table = "7.1.1"
common-error = {path = "src/common/error", default-features = false}
daft-hash = {path = "src/daft-hash"}
derivative = "2.2.0"
divan = "0.1.14"
dyn-clone = "1"
futures = "0.3.30"
html-escape = "0.2.13"
@@ -164,20 +171,25 @@ jaq-core = "1.2.0"
jaq-interpret = "1.2.0"
jaq-parse = "1.0.0"
jaq-std = "1.2.0"
mur3 = "0.1.0"
num-derive = "0.3.3"
num-traits = "0.2"
once_cell = "1.19.0"
path_macro = "1.0.0"
pretty_assertions = "1.4.0"
proptest = "1.5.0"
rand = "^0.8"
rayon = "1.10.0"
regex = "1.10.4"
rstest = "0.18.2"
rustc-hash = "2.0.0"
serde_json = "1.0.116"
sha1 = "0.11.0-pre.4"
sketches-ddsketch = {version = "0.2.2", features = ["use_serde"]}
snafu = {version = "0.7.4", features = ["futures"]}
sqlparser = "0.51.0"
sysinfo = "0.30.12"
tango-bench = "0.6.0"
test-log = "0.2.16"
thiserror = "1.0.63"
tiktoken-rs = "0.5.9"
@@ -195,6 +207,7 @@ tokio-stream = {version = "0.1.14", features = ["fs", "io-util", "time"]}
tokio-util = "0.7.11"
tracing = "0.1"
url = "2.4.0"
xxhash-rust = "0.8.12"

[workspace.dependencies.arrow2]
path = "src/arrow2"
9 changes: 8 additions & 1 deletion daft/daft/__init__.pyi
@@ -1212,6 +1212,7 @@ def minhash(
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: Literal["murmurhash3", "xxhash", "sha1"] = "murmurhash3",
) -> PyExpr: ...

# -----
@@ -1347,7 +1348,13 @@ class PySeries:
def sort(self, descending: bool) -> PySeries: ...
def argsort(self, descending: bool) -> PySeries: ...
def hash(self, seed: PySeries | None = None) -> PySeries: ...
def minhash(self, num_hashes: int, ngram_size: int, seed: int = 1) -> PySeries: ...
def minhash(
self,
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: Literal["murmurhash3", "xxhash", "sha1"] = "murmurhash3",
) -> PySeries: ...
def __invert__(self) -> PySeries: ...
def count(self, mode: CountMode) -> PySeries: ...
def sum(self) -> PySeries: ...
9 changes: 7 additions & 2 deletions daft/expressions/expressions.py
@@ -1196,6 +1196,7 @@ def minhash(
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: Literal["murmurhash3", "xxhash", "sha1"] = "murmurhash3",
) -> Expression:
"""
Runs the MinHash algorithm on the series.
@@ -1204,19 +1205,23 @@
repeating with `num_hashes` permutations. Returns as a list of 32-bit unsigned integers.

Tokens for the ngrams are delimited by spaces.
MurmurHash is used for the initial hash by default.
The strings are not normalized or pre-processed, so it is recommended
to normalize the strings yourself.

Args:
num_hashes: The number of hash permutations to compute.
ngram_size: The number of tokens in each shingle/ngram.
seed (optional): Seed used for generating permutations and the initial string hashes. Defaults to 1.
hash_function (optional): Hash function to use for initial string hashing. One of "murmurhash3", "xxhash", or "sha1". Defaults to "murmurhash3".

"""
assert isinstance(num_hashes, int)
assert isinstance(ngram_size, int)
assert isinstance(seed, int)
return Expression._from_pyexpr(native.minhash(self._expr, num_hashes, ngram_size, seed))
assert isinstance(hash_function, str)
assert hash_function in ["murmurhash3", "xxhash", "sha1"], f"Hash function {hash_function} not found"

return Expression._from_pyexpr(native.minhash(self._expr, num_hashes, ngram_size, seed, hash_function))

def name(self) -> builtins.str:
return self._expr.name()
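To make the new expression-level argument concrete, here is a minimal usage sketch. The input data, the "text" column name, and the output column name are illustrative assumptions, not part of this PR; only the `minhash(..., hash_function=...)` signature comes from the diff above.

```python
import daft
from daft import col

# Hypothetical example data; the "text" column is illustrative only.
df = daft.from_pydict({"text": ["the quick brown fox", "the quick brown cat"]})

# Compute 16-permutation MinHash signatures over 3-token shingles,
# using xxHash instead of the default MurmurHash3 for the initial token hashing.
df = df.with_column(
    "minhash_sig",
    col("text").minhash(num_hashes=16, ngram_size=3, seed=1, hash_function="xxhash"),
)
df.show()
```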
13 changes: 11 additions & 2 deletions daft/series.py
@@ -568,6 +568,7 @@
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: Literal["murmurhash3", "xxhash", "sha1"] = "murmurhash3",
) -> Series:
"""
Runs the MinHash algorithm on the series.
@@ -582,15 +583,23 @@
num_hashes: The number of hash permutations to compute.
ngram_size: The number of tokens in each shingle/ngram.
seed (optional): Seed used for generating permutations and the initial string hashes. Defaults to 1.
hash_function (optional): Hash function to use for initial string hashing. One of "murmurhash3", "xxhash", or "sha1". Defaults to "murmurhash3".
"""
if not isinstance(num_hashes, int):
raise ValueError(f"expected an integer for num_hashes but got {type(num_hashes)}")
if not isinstance(ngram_size, int):
raise ValueError(f"expected an integer for ngram_size but got {type(ngram_size)}")
if seed is not None and not isinstance(seed, int):
raise ValueError(f"expected an integer or None for seed but got {type(seed)}")

return Series._from_pyseries(self._series.minhash(num_hashes, ngram_size, seed))
if not isinstance(hash_function, str):
raise ValueError(f"expected str for hash_function but got {type(hash_function)}")

[Codecov / codecov/patch: added line daft/series.py#L595 was not covered by tests]
assert hash_function in [
"murmurhash3",
"xxhash",
"sha1",
], f"hash_function must be one of 'murmurhash3', 'xxhash', 'sha1', got {hash_function}"

return Series._from_pyseries(self._series.minhash(num_hashes, ngram_size, seed, hash_function))

def _to_str_values(self) -> Series:
return Series._from_pyseries(self._series.to_str_values())
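A similar sketch for the Series-level method, again with made-up input strings. It shows that swapping `hash_function` changes the resulting signatures and that an unsupported name is rejected by the assertion added above.

```python
from daft import Series

s = Series.from_pylist(["hello world foo", "hello world bar"])  # illustrative data

sig_murmur = s.minhash(num_hashes=8, ngram_size=2, seed=1, hash_function="murmurhash3")
sig_xx = s.minhash(num_hashes=8, ngram_size=2, seed=1, hash_function="xxhash")

# The two backends hash tokens differently, so the signatures generally differ.
print(sig_murmur.to_pylist())
print(sig_xx.to_pylist())

# An unsupported name fails the assert before reaching the native code.
try:
    s.minhash(num_hashes=8, ngram_size=2, seed=1, hash_function="md5")
except AssertionError as e:
    print(e)
```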
11 changes: 6 additions & 5 deletions src/daft-core/Cargo.toml
@@ -24,6 +24,7 @@ common-display = {path = "../common/display", default-features = false}
common-error = {path = "../common/error", default-features = false}
common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"}
common-py-serde = {path = "../common/py-serde", default-features = false}
daft-hash = {workspace = true}
daft-minhash = {path = "../daft-minhash", default-features = false}
daft-schema = {path = "../daft-schema", default-features = false}
daft-sketch = {path = "../daft-sketch", default-features = false}
@@ -50,17 +51,17 @@ optional = true
version = "0.21.0"

[dependencies.xxhash-rust]
features = ["xxh3", "const_xxh3"]
features = ["xxh3", "const_xxh3", "xxh64"]
version = "0.8.5"

[features]
python = [
"dep:pyo3",
"dep:numpy",
"common-arrow-ffi/python",
"common-error/python",
"common-py-serde/python",
"common-arrow-ffi/python",
"daft-schema/python"
"daft-schema/python",
"dep:numpy",
"dep:pyo3"
]

[lints]
79 changes: 57 additions & 22 deletions src/daft-core/src/array/ops/minhash.rs
@@ -1,4 +1,4 @@
use std::iter::repeat_with;
use std::{collections::VecDeque, hash::BuildHasher, iter::repeat_with};

use arrow2::array::{MutableArray, MutablePrimitiveArray, PrimitiveArray};
use common_error::{DaftError, DaftResult};
@@ -14,7 +14,13 @@ use crate::{
impl DaftMinHash for Utf8Array {
type Output = DaftResult<FixedSizeListArray>;

fn minhash(&self, num_hashes: usize, ngram_size: usize, seed: u32) -> Self::Output {
fn minhash(
&self,
num_hashes: usize,
ngram_size: usize,
seed: u32,
hasher: &impl BuildHasher,
) -> Self::Output {
if num_hashes == 0 {
return Err(DaftError::ValueError(
"Number of hashes must be nonzero".into(),
@@ -24,42 +30,71 @@ impl DaftMinHash for Utf8Array {
return Err(DaftError::ValueError("Ngram size must be nonzero".into()));
}

// generate permutations
// Generate coefficients for MinHash permutation function: (a * x + b) % p
//
// The MinHash algorithm uses a hash function of the form (a * x + b) % p,
// where 'a' and 'b' are permutation coefficients, 'x' is the input hash,
// and 'p' is typically a large prime number.
//
// 1. perm_a (coefficient 'a'):
// - Starts from 1 to ensure 'a' is never zero
// - A non-zero 'a' is crucial for maintaining the bijective property of the permutation
//
// Example of how bijectivity fails if a = 0:
// Let p = 7 (prime number)
// If a = 0, b = 3, the function becomes: (0 * x + 3) % 7 = 3
// This always outputs 3, regardless of the input x, losing the bijective property
//
// 2. perm_b (coefficient 'b'):
// - Range: 0 to (i32::MAX as u64) - 1
// - Can start from 0 as 'b' can be any value without affecting the permutation property
//
// This approach ensures valid and uniformly distributed hash values, which is
// essential for accurate set similarity estimation in MinHash.
let mut rng = fastrand::Rng::with_seed(seed as u64);
let perm_a = repeat_with(|| rng.u64(1..(i32::MAX as u64))).take(num_hashes);
let perm_a_simd = load_simd(perm_a, num_hashes);
let perm_b = repeat_with(|| rng.u64(0..(i32::MAX as u64))).take(num_hashes);
let perm_b_simd = load_simd(perm_b, num_hashes);

let self_arrow = self.as_arrow();
let internal_arrow_representation = self.as_arrow();
let mut output: MutablePrimitiveArray<u32> =
MutablePrimitiveArray::with_capacity(num_hashes * self.len());
for maybe_s in self_arrow {
if let Some(s) = maybe_s {
let minhash_res = daft_minhash::minhash(
s,
(&perm_a_simd, &perm_b_simd),
num_hashes,
ngram_size,
seed,
)?;
output.extend(minhash_res.into_iter().map(Some));
} else {

let mut alloc = VecDeque::new();

for elem in internal_arrow_representation {
let Some(elem) = elem else {
for _ in 0..num_hashes {
output.push_null();
}
}
continue;
};

let minhash_res = daft_minhash::minhash_in(
elem,
(&perm_a_simd, &perm_b_simd),
num_hashes,
ngram_size,
hasher,
&mut alloc,
)?;

output.extend(minhash_res.into_iter().map(Some));
}
let output_immut: PrimitiveArray<u32> = output.into();

let immutable_output: PrimitiveArray<u32> = output.into();
let output_series = Series::from_arrow(
Field::new(self.name(), DataType::UInt32).into(),
Box::new(output_immut),
Box::new(immutable_output),
)?;
let field = Field::new(
self.name(),
DataType::FixedSizeList(Box::new(DataType::UInt32), num_hashes),
);

Ok(FixedSizeListArray::new(
Field::new(
self.name(),
DataType::FixedSizeList(Box::new(DataType::UInt32), num_hashes),
),
field,
output_series,
self.validity().cloned(),
))
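The comment block added to minhash.rs above explains the `(a * x + b) % p` permutation family and why `a` must be nonzero to keep the map bijective. The standalone Python sketch below models that scheme conceptually; the prime, the toy token hashes, and the helper names are assumptions for illustration and do not mirror the SIMD implementation in `daft-minhash`.

```python
import random

MERSENNE_PRIME = (1 << 61) - 1  # a large prime, standing in for 'p'

def make_permutations(num_hashes: int, seed: int) -> list[tuple[int, int]]:
    rng = random.Random(seed)
    # 'a' is drawn from [1, p) so it is never zero, which keeps x -> (a*x + b) % p
    # a bijection; 'b' may be any value in [0, p).
    return [
        (rng.randrange(1, MERSENNE_PRIME), rng.randrange(0, MERSENNE_PRIME))
        for _ in range(num_hashes)
    ]

def minhash_signature(token_hashes: list[int], perms: list[tuple[int, int]]) -> list[int]:
    # For each permutation, keep the minimum permuted hash over all shingle hashes.
    return [
        min((a * x + b) % MERSENNE_PRIME for x in token_hashes)
        for a, b in perms
    ]

# Toy input: pretend these are the initial hashes of one string's ngrams.
hashes = [0xDEADBEEF, 0x12345678, 0xCAFEBABE]
perms = make_permutations(num_hashes=4, seed=1)
print(minhash_signature(hashes, perms))
```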