diff --git a/reach/reach.py b/reach/reach.py index 13d55a9..dca5100 100644 --- a/reach/reach.py +++ b/reach/reach.py @@ -702,8 +702,12 @@ def normalize(vectors: np.ndarray) -> np.ndarray: else: return vectors / norm[:, None] # type: ignore - def vector_similarity(self, vector: np.ndarray, items: List[str]) -> np.ndarray: + def vector_similarity( + self, vector: np.ndarray, items: Union[str, List[str]] + ) -> np.ndarray: """Compute the similarity between a vector and a set of items.""" + if isinstance(items, str): + items = [items] items_vec = np.stack([self.norm_vectors[self.items[item]] for item in items]) return self._sim(vector, items_vec) diff --git a/tests/test_similarity.py b/tests/test_similarity.py index 6f6f2f9..04bae7f 100644 --- a/tests/test_similarity.py +++ b/tests/test_similarity.py @@ -166,3 +166,20 @@ def test_nearest_neighbor(self) -> None: ] nn2 = instance.threshold(word, threshold=threshold)[0] self.assertEqual(nn1, nn2) + + def test_neighbor_similarity(self) -> None: + words, vectors = self.data() + instance = Reach(vectors, words) + + result = instance.norm_vectors[0] @ instance.norm_vectors[1:].T + result2 = instance.vector_similarity(vectors[0], words[1:]) + + self.assertTrue(np.allclose(result, result2)) + + result = instance.norm_vectors[0] @ instance.norm_vectors[1].T + result2 = instance.vector_similarity(vectors[0], words[1]) + + self.assertEqual(result, result2) + + result3 = instance.vector_similarity(vectors[0], [words[1]]) + self.assertEqual(result2, result3)