-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix datatype parameter for KeyedVectors.load_word2vec_format
. Fix #1682
#1819
Changes from 8 commits
b923043
35a8f8a
aaa7c2a
a8f44c5
37b39f4
8e095d7
310690d
de98f2e
805daf6
466f37f
a76aec6
049fb91
c157d79
164cf63
991bcb6
0904460
96d8aa5
17f6b39
6f53175
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
2 2 | ||
kangaroo.n.01 8��&�%H��.���horse.n.01 \O�($L���k�P6I? |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
2 2 | ||
kangaroo.n.01 -0.0007369244245224787 -8.269973595356034e-05 | ||
horse.n.01 -0.0008546282343595379 0.0007694142576316829 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Automated tests for checking various matutils functions. | ||
""" | ||
|
||
import logging | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a file header, like in the other test files. |
||
import unittest | ||
|
||
import numpy as np | ||
|
||
from gensim.test.utils import datapath | ||
from gensim.models.keyedvectors import KeyedVectors | ||
|
||
|
||
class TestDataType(unittest.TestCase): | ||
def test_binary(self): | ||
path = datapath('test.kv.bin') | ||
kv = KeyedVectors.load_word2vec_format(path, binary=True, | ||
datatype=np.float64) | ||
self.assertAlmostEqual(kv['horse.n.01'][0], -0.0008546282343595379) | ||
self.assertEqual(kv['horse.n.01'][0].dtype, np.float64) | ||
|
||
def test_text(self): | ||
path = datapath('test.kv.txt') | ||
kv = KeyedVectors.load_word2vec_format(path, binary=False, | ||
datatype=np.float64) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's about different datatypes? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will np.float16, np.float32, and np.float64 be enough? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes |
||
self.assertAlmostEqual(kv['horse.n.01'][0], -0.0008546282343595379) | ||
self.assertEqual(kv['horse.n.01'][0].dtype, np.float64) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another test to verify that the |
||
|
||
if __name__ == '__main__': | ||
logging.root.setLevel(logging.WARNING) | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am having problem detecting
binary_len
of vector saved with custom datatype. The only clue is that the next vector starts after a" "
but before the space comes a string(also converted in python bytes) which can be of any length. @menshikh-iv any suggestion?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried adding
\n
at the end of each vector during saving in binary but that broke many other tests.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pushpankar how it works if
datatype=REAL
here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since earlier float were only saved with only 32 bit precision, knowing the size of each vector in binary format was easy. Casting to lower precision is only done after loading vectors.
Please also note that in develop branch too, casting vector to lower precision and saving it in binary and then loading leads to some errors. This is because while loading float32 is being assumed but during saving it was saved with lower precision like float16. I am adding some code to make it more clear.
Gives
This is because float32 was assumed while reading binary vector but originally it was saved with float16. Thus more than necessary bytes was read for every vector.
Let me know if I am not clear enough.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so, probably the easiest solution for this case is read/write with
REAL
type & cast it before the end of "load" process, wdyt @jayantj?