Skip to content

Commit

Permalink
refactors syn0 word vector lookup into method
Browse files Browse the repository at this point in the history
  • Loading branch information
jayantj committed Sep 12, 2016
1 parent f2d13ce commit 3777423
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ def save(self, *args, **kwargs):
kwargs['ignore'] = kwargs.get('ignore', ['syn0norm'])
super(KeyedVectors, self).save(*args, **kwargs)

def word_vec(self, word, use_norm=False):
if isinstance(word, ndarray):
return word
elif word in self.vocab:
if use_norm:
return self.syn0norm[self.vocab[word].index]
else:
return self.syn0[self.vocab[word].index]
else:
raise KeyError("word '%s' not in vocabulary" % word)

def most_similar(self, positive=[], negative=[], topn=10, restrict_vocab=None, indexer=None):
"""
Find the top-N most similar words. Positive words contribute positively towards the
Expand Down Expand Up @@ -87,13 +98,9 @@ def most_similar(self, positive=[], negative=[], topn=10, restrict_vocab=None, i
# compute the weighted average of all words
all_words, mean = set(), []
for word, weight in positive + negative:
if isinstance(word, ndarray):
mean.append(weight * word)
elif word in self.vocab:
mean.append(weight * self.syn0norm[self.vocab[word].index])
mean.append(weight * self.word_vec(word))
if isinstance(word, string_types) and word in self.vocab:
all_words.add(self.vocab[word].index)
else:
raise KeyError("word '%s' not in vocabulary" % word)
if not mean:
raise ValueError("cannot compute similarity with no input")
mean = matutils.unitvec(array(mean).mean(axis=0)).astype(REAL)
Expand Down Expand Up @@ -227,17 +234,8 @@ def most_similar_cosmul(self, positive=[], negative=[], topn=10):

all_words = set()

def word_vec(word):
if isinstance(word, ndarray):
return word
elif word in self.vocab:
all_words.add(self.vocab[word].index)
return self.syn0norm[self.vocab[word].index]
else:
raise KeyError("word '%s' not in vocabulary" % word)

positive = [word_vec(word) for word in positive]
negative = [word_vec(word) for word in negative]
positive = [self.word_vec(word, use_norm=True) for word in positive]
negative = [self.word_vec(word, use_norm=True) for word in negative]
if not positive:
raise ValueError("cannot compute similarity with no input")

Expand Down Expand Up @@ -310,7 +308,7 @@ def doesnt_match(self, words):
logger.debug("using words %s" % words)
if not words:
raise ValueError("cannot select a word from an empty list")
vectors = vstack(self.syn0norm[self.vocab[word].index] for word in words).astype(REAL)
vectors = vstack(self.word_vec(word) for word in words).astype(REAL)
mean = matutils.unitvec(vectors.mean(axis=0)).astype(REAL)
dists = dot(vectors, mean)
return sorted(zip(dists, words))[0][1]
Expand Down Expand Up @@ -340,9 +338,9 @@ def __getitem__(self, words):
"""
if isinstance(words, string_types):
# allow calls like trained_model['office'], as a shorthand for trained_model[['office']]
return self.syn0[self.vocab[words].index]
return self.word_vec(words)

return vstack([self.syn0[self.vocab[word].index] for word in words])
return vstack([self.word_vec(word) for word in words])

def __contains__(self, word):
return word in self.vocab
Expand Down

0 comments on commit 3777423

Please sign in to comment.