Skip to content

Commit

Permalink
updates keyedvector load tests to use actual values
Browse files Browse the repository at this point in the history
  • Loading branch information
jayantj committed Dec 16, 2016
1 parent 3777423 commit 6e20834
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions gensim/test/test_word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,30 @@ def testSyn0NormNotSaved(self):

def testLoadPreKeyedVectorModel(self):
    """Test loading pre-KeyedVectors word2vec model.

    Loads the ``word2vec_pre_kv`` fixture (model stored in one file) and the
    ``word2vec_pre_kv_sep`` fixture (model stored in multiple files) and, for
    each, checks the vocabulary size, the vector size and the shapes of
    ``wv.syn0`` and ``syn1neg`` against the known values for these fixtures.
    """
    actual_vector_size, actual_vocab_size = 10, 1750

    # Model stored in one file
    model = word2vec.Word2Vec.load(datapath('word2vec_pre_kv'))
    # Internal consistency: both matrices agree with the model's own vocab and
    # vector size. assertEqual (not assertTrue(a == b)) so that a failure
    # reports the two shapes instead of just "False is not true".
    self.assertEqual(model.wv.syn0.shape, (len(model.wv.vocab), model.vector_size))
    self.assertEqual(model.syn1neg.shape, (len(model.wv.vocab), model.vector_size))
    # Absolute values: guards against a load that is self-consistent but wrong.
    self.assertEqual(len(model.wv.vocab), actual_vocab_size)
    self.assertEqual(model.vector_size, actual_vector_size)
    self.assertEqual(model.wv.syn0.shape, (actual_vocab_size, actual_vector_size))
    self.assertEqual(model.syn1neg.shape, (actual_vocab_size, actual_vector_size))

    # Model stored in multiple files
    model = word2vec.Word2Vec.load(datapath('word2vec_pre_kv_sep'))
    self.assertEqual(model.wv.syn0.shape, (len(model.wv.vocab), model.vector_size))
    self.assertEqual(model.syn1neg.shape, (len(model.wv.vocab), model.vector_size))
    self.assertEqual(len(model.wv.vocab), actual_vocab_size)
    self.assertEqual(model.vector_size, actual_vector_size)
    self.assertEqual(model.wv.syn0.shape, (actual_vocab_size, actual_vector_size))
    self.assertEqual(model.syn1neg.shape, (actual_vocab_size, actual_vector_size))

def testLoadPreKeyedVectorModelCFormat(self):
    """Test loading pre-KeyedVectors word2vec model saved in word2vec format.

    Loads the ``word2vec_pre_kv_c`` fixture via ``load_word2vec_format`` and
    checks the vocabulary size, the vector size and the shape of ``wv.syn0``
    against the known values for this fixture.
    """
    actual_vector_size, actual_vocab_size = 10, 1750

    model = word2vec.Word2Vec.load_word2vec_format(datapath('word2vec_pre_kv_c'))
    # Internal consistency: one row of syn0 per vocabulary entry. assertEqual
    # (not assertTrue(a == b)) so that a failure reports both values.
    self.assertEqual(model.wv.syn0.shape[0], len(model.wv.vocab))
    # Absolute values: guards against a load that is self-consistent but wrong.
    self.assertEqual(len(model.wv.vocab), actual_vocab_size)
    self.assertEqual(model.vector_size, actual_vector_size)
    self.assertEqual(model.wv.syn0.shape, (actual_vocab_size, actual_vector_size))

def testPersistenceWord2VecFormat(self):
"""Test storing/loading the entire model in word2vec format."""
Expand Down

0 comments on commit 6e20834

Please sign in to comment.