Skip to content

Commit

Permalink
Implement changes in utils_any2vec.py as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
pushpankar committed Feb 5, 2018
1 parent 0256756 commit 69a4b2e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
7 changes: 4 additions & 3 deletions gensim/models/utils_any2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ def _save_word2vec_format(fname, vocab, vectors, fvocab=None, binary=False, tota
for word, vocab_ in sorted(iteritems(vocab), key=lambda item: -item[1].count):
row = vectors[vocab_.index]
if binary:
row = row.astype(REAL)
fout.write(utils.to_utf8(word) + b" " + row.tostring())
else:
fout.write(utils.to_utf8("%s %s\n" % (word, ' '.join("%f" % val for val in row))))
fout.write(utils.to_utf8("%s %s\n" % (word, ' '.join(repr(val) for val in row))))


def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', unicode_errors='strict',
Expand Down Expand Up @@ -205,7 +206,7 @@ def add_word(word, weights):
if ch != b'\n': # ignore newlines in front of words (some binary files have)
word.append(ch)
word = utils.to_unicode(b''.join(word), encoding=encoding, errors=unicode_errors)
weights = fromstring(fin.read(binary_len), dtype=REAL)
weights = fromstring(fin.read(binary_len), dtype=REAL).astype(datatype)
add_word(word, weights)
else:
for line_no in xrange(vocab_size):
Expand All @@ -215,7 +216,7 @@ def add_word(word, weights):
parts = utils.to_unicode(line.rstrip(), encoding=encoding, errors=unicode_errors).split(" ")
if len(parts) != vector_size + 1:
raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no)
word, weights = parts[0], [REAL(x) for x in parts[1:]]
word, weights = parts[0], [datatype(x) for x in parts[1:]]
add_word(word, weights)
if result.vectors.shape[0] != len(result.vocab):
logger.info(
Expand Down
10 changes: 10 additions & 0 deletions gensim/test/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def load_model(self, datatype):

def test_high_precision(self):
kv = self.load_model(np.float64)
import pdb
pdb.set_trace()
self.assertAlmostEqual(kv['horse.n.01'][0], -0.0008546282343595379)
self.assertEqual(kv['horse.n.01'][0].dtype, np.float64)

Expand All @@ -38,6 +40,14 @@ def test_low_precision(self):
self.assertAlmostEqual(kv['horse.n.01'][0], -0.00085449)
self.assertEqual(kv['horse.n.01'][0].dtype, np.float16)

def test_type_conversion(self):
path = datapath('test.kv.txt')
binary_path = datapath('test.kv.bin')
model1 = KeyedVectors.load_word2vec_format(path, datatype=np.float16)
model1.save_word2vec_format(binary_path, binary=True)
model2 = KeyedVectors.load_word2vec_format(binary_path, datatype=np.float64, binary=True)
self.assertAlmostEqual(model1["horse.n.01"][0], np.float16(model2["horse.n.01"][0]))


if __name__ == '__main__':
logging.root.setLevel(logging.WARNING)
Expand Down

0 comments on commit 69a4b2e

Please sign in to comment.