Skip to content

Commit

Permalink
HNSW Dict interface (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu authored Sep 6, 2023
1 parent 4bd6800 commit 9827f2f
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 27 deletions.
3 changes: 2 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ sub-linear query time:
+---------------------------+-----------------------------+------------------------+
| `MinHash LSH Ensemble`_ | MinHash | Containment Threshold |
+---------------------------+-----------------------------+------------------------+
| HNSW | Customizable | Metric Distances |
| `HNSW`_ | Customizable | Metric Distances |
+---------------------------+-----------------------------+------------------------+

datasketch must be used with Python 3.7 or above, NumPy 1.11 or above, and Scipy.
Expand Down Expand Up @@ -77,3 +77,4 @@ To install with Cassandra dependency:
.. _`MinHash LSH Forest`: https://ekzhu.github.io/datasketch/lshforest.html
.. _`MinHash LSH Ensemble`: https://ekzhu.github.io/datasketch/lshensemble.html
.. _`Minhash LSH at Scale`: http://ekzhu.github.io/datasketch/lsh.html#minhash-lsh-at-scale
.. _`HNSW`: https://ekzhu.github.io/datasketch/documentation.html#hnsw
79 changes: 53 additions & 26 deletions datasketch/hnsw.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
import heapq
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple

import numpy as np


class _Layer(object):
"""A graph layer in the HNSW index. This is a dictionary-like object
that maps a key to a dictionary of neighbors.
Args:
key (Any): The first key to insert into the graph.
"""

def __init__(self, key: Any) -> None:
# self._graph[key] contains a {j: dist} dictionary,
# where j is a neighbor of key and dist is distance.
self._graph: Dict[Any, Dict[Any, float]] = {key: {}}
# self._reverse_edges[key] contains a set of neighbors of key.
self._reverse_edges: Dict[Any, Set] = {}

def __contains__(self, key: Any) -> bool:
return key in self._graph

def __getitem__(self, key: Any) -> Dict[Any, float]:
return self._graph[key]

def __setitem__(self, key: Any, value: Dict[Any, float]) -> None:
old_neighbors = self._graph.get(key, {})
self._graph[key] = value
for neighbor in old_neighbors:
self._reverse_edges[neighbor].discard(key)
for neighbor in value:
self._reverse_edges.setdefault(neighbor, set()).add(key)


class _Layer(object):
"""A graph layer in the HNSW index. This is a dictionary-like object
that maps a key to a dictionary of neighbors.
Expand Down Expand Up @@ -39,7 +69,7 @@ class HNSW(object):
nearest neighbor search. This implementation is based on the paper
"Efficient and robust approximate nearest neighbor search using Hierarchical
Navigable Small World graphs" by Yu. A. Malkov, D. A. Yashunin (2016),
<https://arxiv.org/abs/1603.09320>`_.
`<https://arxiv.org/abs/1603.09320>`_.
Args:
distance_func: A function that takes two vectors and returns a float
Expand All @@ -59,7 +89,7 @@ class HNSW(object):
import hnsw
import numpy as np
data = np.random.random_sample((1000, 10))
index = hnsw.HNSW(distance_func=lambda x, y: np.linalg.norm(x - y), m=5, efConstruction=200)
index = hnsw.HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
for i, d in enumerate(data):
index.insert(i, d)
index.query(data[0], k=10)
Expand All @@ -69,7 +99,7 @@ class HNSW(object):
def __init__(
self,
distance_func: Callable[[np.ndarray, np.ndarray], float],
m: int = 5,
m: int = 16,
ef_construction: int = 200,
m0: Optional[int] = None,
seed: Optional[int] = None,
Expand All @@ -84,24 +114,34 @@ def __init__(
self._entry_point = None
self._random = np.random.RandomState(seed)

def __len__(self):
def __len__(self) -> int:
return len(self._data)

def __contains__(self, key) -> bool:
return key in self._data

def __repr__(self):
return (
f"HNSW({self._distance_func}, m={self._m}, "
f"ef_construction={self._ef_construction}, m0={self._m0}, "
f"num_points={len(self._data)}, num_levels={len(self._graphs)}), "
f"random_state={self._random.get_state()}"
)

def __getitem__(self, key) -> np.ndarray:
"""Get the point associated with the key."""
return self._data[key]

def items(self) -> Iterable[Tuple[Any, np.ndarray]]:
"""Get an iterator over (key, point) pairs in the index."""
return self._data.items()

def keys(self) -> Iterable[Any]:
"""Get an iterator over keys in the index."""
return self._data.keys()

def values(self) -> Iterable[np.ndarray]:
"""Get an iterator over points in the index."""
return self._data.values()

def clear(self) -> None:
"""Clear the index of all data points."""
self._data = {}
self._graphs = []
self._entry_point = None

def insert(
self,
key: Any,
Expand Down Expand Up @@ -172,19 +212,6 @@ def insert(
# We set the entry point for each new level to be the new node.
self._entry_point = key

def clear(self) -> None:
"""Clear the index of all data points.
Args:
None
Returns:
None
"""
self._data = {}
self._graphs = []
self._entry_point = None

def _update(self, key: Any, new_point: np.ndarray, ef: int) -> None:
"""Update the point associated with the key in the index.
Expand Down
23 changes: 23 additions & 0 deletions test/test_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from datasketch.hnsw import HNSW


def l2_distance(x, y):
return np.linalg.norm(x - y)


class TestHNSW(unittest.TestCase):
def test_search_l2(self):
data = np.random.rand(100, 10)
Expand Down Expand Up @@ -91,3 +95,22 @@ def test_update_jaccard(self):
jaccard_func(hnsw[results[j][0]], data[i]),
jaccard_func(hnsw[results[j + 1][0]], data[i]),
)

def test_pickle(self):
data = np.random.rand(100, 10)
hnsw = HNSW(
distance_func=l2_distance,
m=16,
ef_construction=100,
)
for i in range(len(data)):
hnsw.insert(i, data[i])

import pickle

hnsw2 = pickle.loads(pickle.dumps(hnsw))

for i in range(len(data)):
results1 = hnsw.query(data[i], 10)
results2 = hnsw2.query(data[i], 10)
self.assertEqual(results1, results2)

0 comments on commit 9827f2f

Please sign in to comment.