From d2215bc1054395d6e097ba3df23594461bab5d2f Mon Sep 17 00:00:00 2001 From: Arham Khan Date: Sat, 2 Mar 2024 20:42:40 +0000 Subject: [PATCH 1/6] add get minhash from lshforest --- datasketch/lshforest.py | 22 ++++++++++++++++++++++ test/test_lshforest.py | 8 ++++++++ 2 files changed, 30 insertions(+) diff --git a/datasketch/lshforest.py b/datasketch/lshforest.py index 9f3455ba..33cb4edd 100644 --- a/datasketch/lshforest.py +++ b/datasketch/lshforest.py @@ -1,5 +1,6 @@ from collections import defaultdict from typing import Hashable, List +import numpy as np from datasketch.minhash import MinHash @@ -128,6 +129,27 @@ def query(self, minhash: MinHash, k: int) -> List[Hashable]: r -= 1 return list(results) + def get_minhash_from_key(self, key: Hashable) -> MinHash: + """ + Returns the MinHash value that corresponds to the given key in the LSHForest, if it exists. This is useful for when we want to manually check the Jaccard Similarity for the top-k results from a query. + + Args: + key (Hashable): The key whose MinHash we want to retrieve. + + Returns: + MinHash: The corresponding MinHash value for the provided key. + """ + byteslist = self.keys.get(key, None) + if byteslist is None: + raise KeyError("The provided key does not exist in the LSHForest: {key}") + hashvalues = np.array([], dtype=np.uint64) + for item in byteslist: + # unswap the bytes, as their representation is flipped during storage + hv_segment = np.frombuffer(item, dtype=np.uint64).byteswap() + hashvalues = np.append(hashvalues, hv_segment) + minhash = MinHash(hashvalues=hashvalues) + return minhash + def _binary_search(self, n, func): """ https://golang.org/src/sort/search.go?s=2247:2287#L49 diff --git a/test/test_lshforest.py b/test/test_lshforest.py index 77e7bf43..f62f1615 100644 --- a/test/test_lshforest.py +++ b/test/test_lshforest.py @@ -62,6 +62,14 @@ def test_query(self): results = forest.query(data[key], 10) self.assertIn(key, results) + def test_get_minhash(self): + forest, data = self._setup() + for key in data: + minhash_ori = data[key] + minhash_retrieved = forest.get_minhash_from_key(key) + self.assertEqual(minhash_retrieved.jaccard(minhash_ori), 1.0) + + def test_pickle(self): forest, _ = self._setup() forest2 = pickle.loads(pickle.dumps(forest)) From dd46f545d1114be000f043782addabc6d6a12a66 Mon Sep 17 00:00:00 2001 From: Arham Khan Date: Sat, 2 Mar 2024 20:46:20 +0000 Subject: [PATCH 2/6] format --- datasketch/lshforest.py | 6 ++++-- test/test_lshforest.py | 1 - 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/datasketch/lshforest.py b/datasketch/lshforest.py index 33cb4edd..492c0614 100644 --- a/datasketch/lshforest.py +++ b/datasketch/lshforest.py @@ -131,8 +131,10 @@ def query(self, minhash: MinHash, k: int) -> List[Hashable]: def get_minhash_from_key(self, key: Hashable) -> MinHash: """ - Returns the MinHash value that corresponds to the given key in the LSHForest, if it exists. This is useful for when we want to manually check the Jaccard Similarity for the top-k results from a query. - + Returns the MinHash value that corresponds to the given key in the LSHForest, + if it exists. This is useful for when we want to manually check the + Jaccard Similarity for the top-k results from a query. + Args: key (Hashable): The key whose MinHash we want to retrieve. diff --git a/test/test_lshforest.py b/test/test_lshforest.py index f62f1615..fd03faed 100644 --- a/test/test_lshforest.py +++ b/test/test_lshforest.py @@ -68,7 +68,6 @@ def test_get_minhash(self): minhash_ori = data[key] minhash_retrieved = forest.get_minhash_from_key(key) self.assertEqual(minhash_retrieved.jaccard(minhash_ori), 1.0) - def test_pickle(self): forest, _ = self._setup() From 9c94b60fb149c15593d3cb14966d3db964dc65d6 Mon Sep 17 00:00:00 2001 From: Arham Khan Date: Sat, 2 Mar 2024 20:52:23 +0000 Subject: [PATCH 3/6] fix format string --- datasketch/lshforest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasketch/lshforest.py b/datasketch/lshforest.py index 492c0614..7e8b269b 100644 --- a/datasketch/lshforest.py +++ b/datasketch/lshforest.py @@ -143,7 +143,7 @@ def get_minhash_from_key(self, key: Hashable) -> MinHash: """ byteslist = self.keys.get(key, None) if byteslist is None: - raise KeyError("The provided key does not exist in the LSHForest: {key}") + raise KeyError(f"The provided key does not exist in the LSHForest: {key}") hashvalues = np.array([], dtype=np.uint64) for item in byteslist: # unswap the bytes, as their representation is flipped during storage From 31adb6a8f3c8d98520a66d87a596df0ea71eb49e Mon Sep 17 00:00:00 2001 From: Arham Khan Date: Mon, 4 Mar 2024 00:30:57 +0000 Subject: [PATCH 4/6] return hashvalues instead of MinHash --- datasketch/lshforest.py | 15 +++++++-------- test/test_lshforest.py | 5 +++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/datasketch/lshforest.py b/datasketch/lshforest.py index 7e8b269b..cbd58b0c 100644 --- a/datasketch/lshforest.py +++ b/datasketch/lshforest.py @@ -129,17 +129,17 @@ def query(self, minhash: MinHash, k: int) -> List[Hashable]: r -= 1 return list(results) - def get_minhash_from_key(self, key: Hashable) -> MinHash: + def get_minhash_hashvalues(self, key: Hashable) -> np.ndarray: """ - Returns the MinHash value that corresponds to the given key in the LSHForest, - if it exists. This is useful for when we want to manually check the - Jaccard Similarity for the top-k results from a query. + Returns the hashvalues from the MinHash object that corresponds to the given key in the LSHForest, + if it exists. This is useful for when we want to reconstruct the original MinHash + object to manually check the Jaccard Similarity for the top-k results from a query. Args: - key (Hashable): The key whose MinHash we want to retrieve. + key (Hashable): The key whose MinHash hashvalues we want to retrieve. Returns: - MinHash: The corresponding MinHash value for the provided key. + hashvalues: The hashvalues for the MinHash object corresponding to the given key. """ byteslist = self.keys.get(key, None) if byteslist is None: @@ -149,8 +149,7 @@ def get_minhash_from_key(self, key: Hashable) -> MinHash: # unswap the bytes, as their representation is flipped during storage hv_segment = np.frombuffer(item, dtype=np.uint64).byteswap() hashvalues = np.append(hashvalues, hv_segment) - minhash = MinHash(hashvalues=hashvalues) - return minhash + return hashvalues def _binary_search(self, n, func): """ diff --git a/test/test_lshforest.py b/test/test_lshforest.py index fd03faed..2f33ca4c 100644 --- a/test/test_lshforest.py +++ b/test/test_lshforest.py @@ -62,11 +62,12 @@ def test_query(self): results = forest.query(data[key], 10) self.assertIn(key, results) - def test_get_minhash(self): + def get_minhash_hashvalues(self): forest, data = self._setup() for key in data: minhash_ori = data[key] - minhash_retrieved = forest.get_minhash_from_key(key) + hashvalues = forest.get_minhash_from_key(key) + minhash_retrieved = MinHash(hashvalues=hashvalues) self.assertEqual(minhash_retrieved.jaccard(minhash_ori), 1.0) def test_pickle(self): From 8820f79e95dd760cebb6eb88ac1bc9739ba7f529 Mon Sep 17 00:00:00 2001 From: Arham Khan Date: Sun, 10 Mar 2024 00:31:07 +0000 Subject: [PATCH 5/6] preallocate hashvalue buffer --- datasketch/lshforest.py | 8 +++++--- test/test_lshforest.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/datasketch/lshforest.py b/datasketch/lshforest.py index cbd58b0c..a02569d8 100644 --- a/datasketch/lshforest.py +++ b/datasketch/lshforest.py @@ -144,11 +144,13 @@ def get_minhash_hashvalues(self, key: Hashable) -> np.ndarray: byteslist = self.keys.get(key, None) if byteslist is None: raise KeyError(f"The provided key does not exist in the LSHForest: {key}") - hashvalues = np.array([], dtype=np.uint64) - for item in byteslist: + hashvalue_byte_size = len(byteslist[0])//8 + hashvalues = np.empty(len(byteslist)*hashvalue_byte_size, dtype=np.uint64) + for index, item in enumerate(byteslist): # unswap the bytes, as their representation is flipped during storage hv_segment = np.frombuffer(item, dtype=np.uint64).byteswap() - hashvalues = np.append(hashvalues, hv_segment) + curr_index = index*hashvalue_byte_size + hashvalues[curr_index:curr_index+hashvalue_byte_size] = hv_segment return hashvalues def _binary_search(self, n, func): diff --git a/test/test_lshforest.py b/test/test_lshforest.py index 2f33ca4c..dee5a0e5 100644 --- a/test/test_lshforest.py +++ b/test/test_lshforest.py @@ -62,11 +62,11 @@ def test_query(self): results = forest.query(data[key], 10) self.assertIn(key, results) - def get_minhash_hashvalues(self): + def test_get_minhash_hashvalues(self): forest, data = self._setup() for key in data: minhash_ori = data[key] - hashvalues = forest.get_minhash_from_key(key) + hashvalues = forest.get_minhash_hashvalues(key) minhash_retrieved = MinHash(hashvalues=hashvalues) self.assertEqual(minhash_retrieved.jaccard(minhash_ori), 1.0) From f980603bef7fdd124423a8c9178185a0042185f8 Mon Sep 17 00:00:00 2001 From: Arham Khan Date: Sun, 10 Mar 2024 21:58:04 +0000 Subject: [PATCH 6/6] add direct hashvalue check to test --- test/test_lshforest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_lshforest.py b/test/test_lshforest.py index dee5a0e5..400a9af8 100644 --- a/test/test_lshforest.py +++ b/test/test_lshforest.py @@ -68,7 +68,11 @@ def test_get_minhash_hashvalues(self): minhash_ori = data[key] hashvalues = forest.get_minhash_hashvalues(key) minhash_retrieved = MinHash(hashvalues=hashvalues) + retrieved_hashvalues = minhash_retrieved.hashvalues + self.assertEqual(len(hashvalues), len(retrieved_hashvalues)) self.assertEqual(minhash_retrieved.jaccard(minhash_ori), 1.0) + for i in range(len(retrieved_hashvalues)): + self.assertEqual(hashvalues[i], retrieved_hashvalues[i]) def test_pickle(self): forest, _ = self._setup()