-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Loading fastText models using only bin file #1341
Changes from 9 commits
7759a95
c12b4fa
8025710
7ee83d9
041a6e9
22c6710
61be613
e11ac44
a63a3bc
f80410f
454d74e
e6b0d8b
9b03ea3
c496be9
2c4a8dd
d2ab903
82507d1
c44b958
0fc1159
f421b05
68ec73b
f7b372e
5f7fe02
8bd56cf
b916187
1a0bfc0
98e0287
f3d2032
bd7e7f6
800cd01
a15233a
431aebf
e52fee4
cebb3fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,11 +35,15 @@ | |
import numpy as np | ||
from numpy import float32 as REAL, sqrt, newaxis | ||
from gensim import utils | ||
from gensim.models.keyedvectors import KeyedVectors | ||
from gensim.models.keyedvectors import KeyedVectors, Vocab | ||
from gensim.models.word2vec import Word2Vec | ||
|
||
from six import string_types | ||
|
||
from numpy import exp, log, dot, zeros, outer, random, dtype, float32 as REAL,\ | ||
double, uint32, seterr, array, uint8, vstack, fromstring, sqrt, newaxis,\ | ||
ndarray, empty, sum as np_sum, prod, ones, ascontiguousarray | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
FASTTEXT_FILEFORMAT_MAGIC = 793712314 | ||
|
@@ -224,7 +228,7 @@ def load_word2vec_format(cls, *args, **kwargs): | |
return FastTextKeyedVectors.load_word2vec_format(*args, **kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe that a load using this method only learns the full-word vectors as in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, this method is not used now for loading using bin only. I removed this unused code, but got a strange flake8 error for python 3+, therefore re-added this for this PR. I'll try removing these unused codes later maybe in a different PR. @gojomo There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is an odd error! I suspect it's not really the presence/absence of that method that triggered it, but something else either random or hidden in the whitespace. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @gojomo ok, test passed this time after removing this code 😄 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For reference, this was a bug in the flake8 script, fixed in cebb3fc |
||
|
||
@classmethod | ||
def load_fasttext_format(cls, model_file, encoding='utf8'): | ||
def load_fasttext_format(cls, model_file, bin_only = False, encoding='utf8'): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add |
||
""" | ||
Load the input-hidden weight matrix from the fast text output files. | ||
|
||
|
@@ -237,8 +241,11 @@ def load_fasttext_format(cls, model_file, encoding='utf8'): | |
|
||
""" | ||
model = cls() | ||
model.wv = cls.load_word2vec_format('%s.vec' % model_file, encoding=encoding) | ||
model.load_binary_data('%s.bin' % model_file, encoding=encoding) | ||
if bin_only: | ||
model.load_binary_data('%s.bin' % model_file, bin_only, encoding=encoding) | ||
else: | ||
model.wv = cls.load_word2vec_format('%s.vec' % model_file, encoding=encoding) | ||
model.load_binary_data('%s.bin' % model_file, encoding=encoding) | ||
return model | ||
|
||
@classmethod | ||
|
@@ -251,12 +258,12 @@ def delete_training_files(cls, model_file): | |
logger.debug('Training files %s not found when attempting to delete', model_file) | ||
pass | ||
|
||
def load_binary_data(self, model_binary_file, encoding='utf8'): | ||
def load_binary_data(self, model_binary_file, bin_only = False, encoding='utf8'): | ||
"""Loads data from the output binary file created by FastText training""" | ||
with utils.smart_open(model_binary_file, 'rb') as f: | ||
self.load_model_params(f) | ||
self.load_dict(f, encoding=encoding) | ||
self.load_vectors(f) | ||
self.load_dict(f, bin_only, encoding=encoding) | ||
self.load_vectors(f, bin_only) | ||
|
||
def load_model_params(self, file_handle): | ||
magic, version = self.struct_unpack(file_handle, '@2i') | ||
|
@@ -281,15 +288,21 @@ def load_model_params(self, file_handle): | |
self.wv.max_n = maxn | ||
self.sample = t | ||
|
||
def load_dict(self, file_handle, encoding='utf8'): | ||
def load_dict(self, file_handle, bin_only = False, encoding='utf8'): | ||
vocab_size, nwords, _ = self.struct_unpack(file_handle, '@3i') | ||
# Vocab stored by [Dictionary::save](https://github.com/facebookresearch/fastText/blob/master/src/dictionary.cc) | ||
assert len(self.wv.vocab) == nwords, 'mismatch between vocab sizes' | ||
assert len(self.wv.vocab) == vocab_size, 'mismatch between vocab sizes' | ||
if not bin_only: | ||
assert len(self.wv.vocab) == nwords, 'mismatch between vocab sizes' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's also log |
||
if len(self.wv.vocab) != vocab_size: | ||
logger.warnings("If you are loading any model other than pretrained vector wiki.fr, ") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
logger.warnings("Please report to gensim or fastText.") | ||
#else: | ||
#self.wv.syn0 = zeros((vocab_size, self.vector_size), dtype=REAL) | ||
# TO-DO : how to update this | ||
self.struct_unpack(file_handle, '@1q') # number of tokens | ||
if self.new_format: | ||
pruneidx_size, = self.struct_unpack(file_handle, '@q') | ||
for i in range(nwords): | ||
for i in range(vocab_size): | ||
word_bytes = b'' | ||
char_byte = file_handle.read(1) | ||
# Read vocab word | ||
|
@@ -298,14 +311,31 @@ def load_dict(self, file_handle, encoding='utf8'): | |
char_byte = file_handle.read(1) | ||
word = word_bytes.decode(encoding) | ||
count, _ = self.struct_unpack(file_handle, '@qb') | ||
assert self.wv.vocab[word].index == i, 'mismatch between gensim word index and fastText word index' | ||
self.wv.vocab[word].count = count | ||
if bin_only: | ||
self.wv.vocab[word] = Vocab(index=i, count=count) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this correct? The word There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, you were right about it, this is the last word, but we can't skip reading it otherwise there will be an error in further byte reading. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I understand we have to read the bytes, agreed. if i == nwords and i < vocab_size:
assert word == "__label__"
continue # don't add word to vocab
|
||
elif not bin_only: | ||
assert self.wv.vocab[word].index == i, 'mismatch between gensim word index and fastText word index' | ||
self.wv.vocab[word].count = count | ||
|
||
if bin_only: | ||
#self.wv.syn0[i] = weight # How to get weight vector for each word ? | ||
self.wv.index2word.append(word) | ||
|
||
"""if bin_only: | ||
if self.wv.syn0.shape[0] != len(self.wv.vocab): | ||
logger.info( | ||
"duplicate words detected, shrinking matrix size from %i to %i", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure I understand the need for this. Can duplicate words exist? |
||
self.wv.syn0.shape[0], len(self.wv.vocab) | ||
) | ||
self.wv.syn0 = ascontiguousarray(result.syn0[: len(self.wv.vocab)]) | ||
assert (len(self.wv.vocab), self.vector_size) == self.wv.syn0.shape""" | ||
|
||
if self.new_format: | ||
for j in range(pruneidx_size): | ||
self.struct_unpack(file_handle, '@2i') | ||
|
||
def load_vectors(self, file_handle): | ||
def load_vectors(self, file_handle, bin_only = False): | ||
logger.info("here??") | ||
if self.new_format: | ||
self.struct_unpack(file_handle, '@?') # bool quant_input in fasttext.cc | ||
num_vectors, dim = self.struct_unpack(file_handle, '@2q') | ||
|
@@ -322,13 +352,13 @@ def load_vectors(self, file_handle): | |
self.wv.syn0_all = self.wv.syn0_all.reshape((num_vectors, dim)) | ||
assert self.wv.syn0_all.shape == (self.bucket + len(self.wv.vocab), self.vector_size), \ | ||
'mismatch between weight matrix shape and vocab/model size' | ||
self.init_ngrams() | ||
self.init_ngrams(bin_only) | ||
|
||
def struct_unpack(self, file_handle, fmt): | ||
num_bytes = struct.calcsize(fmt) | ||
return struct.unpack(fmt, file_handle.read(num_bytes)) | ||
|
||
def init_ngrams(self): | ||
def init_ngrams(self, bin_only = False): | ||
""" | ||
Computes ngrams of all words present in vocabulary and stores vectors for only those ngrams. | ||
Vectors for other ngrams are initialized with a random uniform distribution in FastText. These | ||
|
@@ -337,8 +367,26 @@ def init_ngrams(self): | |
""" | ||
self.wv.ngrams = {} | ||
all_ngrams = [] | ||
if bin_only: | ||
self.wv.syn0 = zeros((len(self.wv.vocab), self.vector_size), dtype=REAL) | ||
for w, v in self.wv.vocab.items(): | ||
all_ngrams += self.compute_ngrams(w, self.wv.min_n, self.wv.max_n) | ||
word_ngrams = self.compute_ngrams(w, self.wv.min_n, self.wv.max_n) | ||
all_ngrams += word_ngrams | ||
|
||
|
||
if bin_only: | ||
#self.wv.syn0 = zeros((len(self.wv.vocab), self.vector_size), dtype=REAL) | ||
word_vec = np.zeros(self.wv.syn0.shape[1]) | ||
|
||
num_word_ngram_vectors = len(word_ngrams) | ||
for word_ngram in word_ngrams: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems redundant and possibly error prone. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, so from looking at the FastText code, it seems this is more complicated than what we originally thought.
The large weight matrix in the
This will require a change in logic then. Possibly useful reference - https://github.com/facebookresearch/fastText/blob/master/src/dictionary.cc#L71 |
||
ngram_hash = self.ft_hash(word_ngram) | ||
word_vec += np.array(self.wv.syn0_all[(len(self.wv.vocab) + ngram_hash) % self.bucket]) | ||
|
||
self.wv.syn0[self.wv.vocab[w].index] = word_vec / num_word_ngram_vectors | ||
# Still not working | ||
|
||
|
||
all_ngrams = set(all_ngrams) | ||
self.num_ngram_vectors = len(all_ngrams) | ||
ngram_indices = [] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -171,6 +171,37 @@ def testLoadFastTextNewFormat(self): | |
self.assertEquals(self.test_new_model.wv.min_n, 3) | ||
self.model_sanity(new_model) | ||
|
||
def testLoadBinOnly(self): | ||
""" Test model succesfully loaded from fastText (new format) .bin files only """ | ||
new_model = fasttext.FastText.load_fasttext_format(self.test_new_model_file, bin_only = True) | ||
vocab_size, model_size = 1763, 10 | ||
self.assertEqual(self.test_new_model.wv.syn0.shape, (vocab_size, model_size)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't we be testing |
||
self.assertEqual(len(self.test_new_model.wv.vocab), vocab_size, model_size) | ||
self.assertEqual(self.test_new_model.wv.syn0_all.shape, (self.test_new_model.num_ngram_vectors, model_size)) | ||
|
||
expected_vec_new = [-0.025627, | ||
-0.11448, | ||
0.18116, | ||
-0.96779, | ||
0.2532, | ||
-0.93224, | ||
0.3929, | ||
0.12679, | ||
-0.19685, | ||
-0.13179] # obtained using ./fasttext print-word-vectors lee_fasttext_new.bin < queries.txt | ||
|
||
self.assertTrue(numpy.allclose(self.test_new_model["hundred"], expected_vec_new, 0.001)) | ||
self.assertEquals(self.test_new_model.min_count, 5) | ||
self.assertEquals(self.test_new_model.window, 5) | ||
self.assertEquals(self.test_new_model.iter, 5) | ||
self.assertEquals(self.test_new_model.negative, 5) | ||
self.assertEquals(self.test_new_model.sample, 0.0001) | ||
self.assertEquals(self.test_new_model.bucket, 1000) | ||
self.assertEquals(self.test_new_model.wv.max_n, 6) | ||
self.assertEquals(self.test_new_model.wv.min_n, 3) | ||
self.model_sanity(new_model) | ||
|
||
|
||
def testLoadModelWithNonAsciiVocab(self): | ||
"""Test loading model with non-ascii words in vocab""" | ||
model = fasttext.FastText.load_fasttext_format(datapath('non_ascii_fasttext')) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need all these imports?