From 9827f2f36e17f928259da65c028e8030179fa1a3 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Wed, 6 Sep 2023 01:55:02 -0700 Subject: [PATCH] HNSW Dict interface (#221) --- README.rst | 3 +- datasketch/hnsw.py | 79 +++++++++++++++++++++++++++++++--------------- test/test_hnsw.py | 23 ++++++++++++++ 3 files changed, 78 insertions(+), 27 deletions(-) diff --git a/README.rst b/README.rst index 284a1eec..eaa1301d 100644 --- a/README.rst +++ b/README.rst @@ -37,7 +37,7 @@ sub-linear query time: +---------------------------+-----------------------------+------------------------+ | `MinHash LSH Ensemble`_ | MinHash | Containment Threshold | +---------------------------+-----------------------------+------------------------+ -| HNSW | Customizable | Metric Distances | +| `HNSW`_ | Customizable | Metric Distances | +---------------------------+-----------------------------+------------------------+ datasketch must be used with Python 3.7 or above, NumPy 1.11 or above, and Scipy. @@ -77,3 +77,4 @@ To install with Cassandra dependency: .. _`MinHash LSH Forest`: https://ekzhu.github.io/datasketch/lshforest.html .. _`MinHash LSH Ensemble`: https://ekzhu.github.io/datasketch/lshensemble.html .. _`Minhash LSH at Scale`: http://ekzhu.github.io/datasketch/lsh.html#minhash-lsh-at-scale +.. _`HNSW`: https://ekzhu.github.io/datasketch/documentation.html#hnsw diff --git a/datasketch/hnsw.py b/datasketch/hnsw.py index be27ecf8..f1e28e9f 100644 --- a/datasketch/hnsw.py +++ b/datasketch/hnsw.py @@ -1,9 +1,39 @@ import heapq -from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple import numpy as np +class _Layer(object): + """A graph layer in the HNSW index. This is a dictionary-like object + that maps a key to a dictionary of neighbors. + + Args: + key (Any): The first key to insert into the graph. 
+ """ + + def __init__(self, key: Any) -> None: + # self._graph[key] contains a {j: dist} dictionary, + # where j is a neighbor of key and dist is distance. + self._graph: Dict[Any, Dict[Any, float]] = {key: {}} + # self._reverse_edges[key] contains a set of neighbors of key. + self._reverse_edges: Dict[Any, Set] = {} + + def __contains__(self, key: Any) -> bool: + return key in self._graph + + def __getitem__(self, key: Any) -> Dict[Any, float]: + return self._graph[key] + + def __setitem__(self, key: Any, value: Dict[Any, float]) -> None: + old_neighbors = self._graph.get(key, {}) + self._graph[key] = value + for neighbor in old_neighbors: + self._reverse_edges[neighbor].discard(key) + for neighbor in value: + self._reverse_edges.setdefault(neighbor, set()).add(key) + + class _Layer(object): """A graph layer in the HNSW index. This is a dictionary-like object that maps a key to a dictionary of neighbors. @@ -39,7 +69,7 @@ class HNSW(object): nearest neighbor search. This implementation is based on the paper "Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs" by Yu. A. Malkov, D. A. Yashunin (2016), - <https://arxiv.org/abs/1603.09320>`_. + `<https://arxiv.org/abs/1603.09320>`_.
Args: distance_func: A function that takes two vectors and returns a float @@ -59,7 +89,7 @@ class HNSW(object): import hnsw import numpy as np data = np.random.random_sample((1000, 10)) - index = hnsw.HNSW(distance_func=lambda x, y: np.linalg.norm(x - y), m=5, efConstruction=200) + index = hnsw.HNSW(distance_func=lambda x, y: np.linalg.norm(x - y)) for i, d in enumerate(data): index.insert(i, d) index.query(data[0], k=10) @@ -69,7 +99,7 @@ class HNSW(object): def __init__( self, distance_func: Callable[[np.ndarray, np.ndarray], float], - m: int = 5, + m: int = 16, ef_construction: int = 200, m0: Optional[int] = None, seed: Optional[int] = None, @@ -84,24 +114,34 @@ def __init__( self._entry_point = None self._random = np.random.RandomState(seed) - def __len__(self): + def __len__(self) -> int: return len(self._data) def __contains__(self, key) -> bool: return key in self._data - def __repr__(self): - return ( - f"HNSW({self._distance_func}, m={self._m}, " - f"ef_construction={self._ef_construction}, m0={self._m0}, " - f"num_points={len(self._data)}, num_levels={len(self._graphs)}), " - f"random_state={self._random.get_state()}" - ) - def __getitem__(self, key) -> np.ndarray: """Get the point associated with the key.""" return self._data[key] + def items(self) -> Iterable[Tuple[Any, np.ndarray]]: + """Get an iterator over (key, point) pairs in the index.""" + return self._data.items() + + def keys(self) -> Iterable[Any]: + """Get an iterator over keys in the index.""" + return self._data.keys() + + def values(self) -> Iterable[np.ndarray]: + """Get an iterator over points in the index.""" + return self._data.values() + + def clear(self) -> None: + """Clear the index of all data points.""" + self._data = {} + self._graphs = [] + self._entry_point = None + def insert( self, key: Any, @@ -172,19 +212,6 @@ def insert( # We set the entry point for each new level to be the new node. 
self._entry_point = key - def clear(self) -> None: - """Clear the index of all data points. - - Args: - None - - Returns: - None - """ - self._data = {} - self._graphs = [] - self._entry_point = None - def _update(self, key: Any, new_point: np.ndarray, ef: int) -> None: """Update the point associated with the key in the index. diff --git a/test/test_hnsw.py b/test/test_hnsw.py index f8052790..b2639994 100644 --- a/test/test_hnsw.py +++ b/test/test_hnsw.py @@ -5,6 +5,10 @@ from datasketch.hnsw import HNSW +def l2_distance(x, y): + return np.linalg.norm(x - y) + + class TestHNSW(unittest.TestCase): def test_search_l2(self): data = np.random.rand(100, 10) @@ -91,3 +95,22 @@ def test_update_jaccard(self): jaccard_func(hnsw[results[j][0]], data[i]), jaccard_func(hnsw[results[j + 1][0]], data[i]), ) + + def test_pickle(self): + data = np.random.rand(100, 10) + hnsw = HNSW( + distance_func=l2_distance, + m=16, + ef_construction=100, + ) + for i in range(len(data)): + hnsw.insert(i, data[i]) + + import pickle + + hnsw2 = pickle.loads(pickle.dumps(hnsw)) + + for i in range(len(data)): + results1 = hnsw.query(data[i], 10) + results2 = hnsw2.query(data[i], 10) + self.assertEqual(results1, results2)