Skip to content

Commit

Permalink
Add gensim.models.BaseKeyedVectors.add_entity method for fill `Keye…
Browse files Browse the repository at this point in the history
…dVectors` in manual way. Fix #1942 (#1957)

* Introduce BaseKeyedVectors.add(...) method

* make default count=1

* add test on add_word method

* address @menshikh-iv comments

* fix test_keyedvectors after removing add_word alias

* add __setitem__, add bulk entities processing + some tests on new functionality

* addressing @menshikh-iv comments on docstrings

* addressing @gojomo comments

* adrressing nitpicks

* make self.vectors = np.zeros((0, vector_size)) by default

* fix pep8
  • Loading branch information
persiyanov authored and menshikh-iv committed Mar 20, 2018
1 parent a781b40 commit 58d560b
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 3 deletions.
66 changes: 63 additions & 3 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@
except ImportError:
PYEMD_EXT = False

from numpy import dot, zeros, float32 as REAL, empty, memmap as np_memmap, \
double, array, vstack, sqrt, newaxis, integer, \
from numpy import dot, float32 as REAL, empty, memmap as np_memmap, \
double, array, zeros, vstack, sqrt, newaxis, integer, \
ndarray, sum as np_sum, prod, argmax, divide as np_divide
import numpy as np
from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc
Expand Down Expand Up @@ -109,7 +109,7 @@ def __str__(self):
class BaseKeyedVectors(utils.SaveLoad):

def __init__(self, vector_size):
self.vectors = []
self.vectors = zeros((0, vector_size))
self.vocab = {}
self.vector_size = vector_size
self.index2entity = []
Expand Down Expand Up @@ -154,6 +154,65 @@ def get_vector(self, entity):
else:
raise KeyError("'%s' not in vocabulary" % entity)

def add(self, entities, weights, replace=False):
"""Add entities and theirs vectors in a manual way.
If some entity is already in the vocabulary, old vector is keeped unless `replace` flag is True.
Parameters
----------
entities : list of str
Entities specified by string tags.
weights: {list of numpy.ndarray, numpy.ndarray}
List of 1D np.array vectors or 2D np.array of vectors.
replace: bool, optional
Flag indicating whether to replace vectors for entities which are already in the vocabulary,
if True - replace vectors, otherwise - keep old vectors.
"""
if isinstance(entities, string_types):
entities = [entities]
weights = np.array(weights).reshape(1, -1)
elif isinstance(weights, list):
weights = np.array(weights)

in_vocab_mask = np.zeros(len(entities), dtype=np.bool)
for idx, entity in enumerate(entities):
if entity in self.vocab:
in_vocab_mask[idx] = True

# add new entities to the vocab
for idx in np.nonzero(~in_vocab_mask)[0]:
entity = entities[idx]
self.vocab[entity] = Vocab(index=len(self.vocab), count=1)
self.index2entity.append(entity)

# add vectors for new entities
self.vectors = vstack((self.vectors, weights[~in_vocab_mask]))

# change vectors for in_vocab entities if `replace` flag is specified
if replace:
in_vocab_idxs = [self.vocab[entities[idx]].index for idx in np.nonzero(in_vocab_mask)[0]]
self.vectors[in_vocab_idxs] = weights[in_vocab_mask]

def __setitem__(self, entities, weights):
"""Add entities and theirs vectors in a manual way.
If some entity is already in the vocabulary, old vector is replaced with the new one.
This method is alias for `add` with `replace=True`.
Parameters
----------
entities : {str, list of str}
Entities specified by string tags.
weights: {list of numpy.ndarray, numpy.ndarray}
List of 1D np.array vectors or 2D np.array of vectors.
"""
if not isinstance(entities, list):
entities = [entities]
weights = weights.reshape(1, -1)

self.add(entities, weights, replace=True)

def __getitem__(self, entities):
"""
Accept a single entity (string tag) or list of entities as input.
Expand All @@ -163,6 +222,7 @@ def __getitem__(self, entities):
If a list, return designated tags' vector representations as a
2D numpy array: #tags x #vector_size.
"""
if isinstance(entities, string_types):
# allow calls like trained_model['office'], as a shorthand for trained_model[['office']]
Expand Down
72 changes: 72 additions & 0 deletions gensim/test/test_keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,78 @@ def test_wv_property(self):
"""Test that the deprecated `wv` property returns `self`. To be removed in v4.0.0."""
self.assertTrue(self.vectors is self.vectors.wv)

def test_add_single(self):
"""Test that adding entity in a manual way works correctly."""
entities = ['___some_entity{}_not_present_in_keyed_vectors___'.format(i) for i in range(5)]
vectors = [np.random.randn(self.vectors.vector_size) for _ in range(5)]

# Test `add` on already filled kv.
for ent, vector in zip(entities, vectors):
self.vectors.add(ent, vector)

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(self.vectors[ent], vector))

# Test `add` on empty kv.
kv = EuclideanKeyedVectors(self.vectors.vector_size)
for ent, vector in zip(entities, vectors):
kv.add(ent, vector)

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(kv[ent], vector))

def test_add_multiple(self):
"""Test that adding a bulk of entities in a manual way works correctly."""
entities = ['___some_entity{}_not_present_in_keyed_vectors___'.format(i) for i in range(5)]
vectors = [np.random.randn(self.vectors.vector_size) for _ in range(5)]

# Test `add` on already filled kv.
vocab_size = len(self.vectors.vocab)
self.vectors.add(entities, vectors, replace=False)
self.assertEqual(vocab_size + len(entities), len(self.vectors.vocab))

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(self.vectors[ent], vector))

# Test `add` on empty kv.
kv = EuclideanKeyedVectors(self.vectors.vector_size)
kv[entities] = vectors
self.assertEqual(len(kv.vocab), len(entities))

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(kv[ent], vector))

def test_set_item(self):
"""Test that __setitem__ works correctly."""
vocab_size = len(self.vectors.vocab)

# Add new entity.
entity = '___some_new_entity___'
vector = np.random.randn(self.vectors.vector_size)
self.vectors[entity] = vector

self.assertEqual(len(self.vectors.vocab), vocab_size + 1)
self.assertTrue(np.allclose(self.vectors[entity], vector))

# Replace vector for entity in vocab.
vocab_size = len(self.vectors.vocab)
vector = np.random.randn(self.vectors.vector_size)
self.vectors['war'] = vector

self.assertEqual(len(self.vectors.vocab), vocab_size)
self.assertTrue(np.allclose(self.vectors['war'], vector))

# __setitem__ on several entities.
vocab_size = len(self.vectors.vocab)
entities = ['war', '___some_new_entity1___', '___some_new_entity2___', 'terrorism', 'conflict']
vectors = [np.random.randn(self.vectors.vector_size) for _ in range(len(entities))]

self.vectors[entities] = vectors

self.assertEqual(len(self.vectors.vocab), vocab_size + 2)
for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(self.vectors[ent], vector))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down

0 comments on commit 58d560b

Please sign in to comment.