From 4d223279985f811c0bbfc7f1451c7209c920c859 Mon Sep 17 00:00:00 2001 From: lopusz Date: Thu, 23 Jan 2020 08:25:20 +0100 Subject: [PATCH] Implement saving to Facebook format (#2712) * Add writing header for binary FB format (#2611) * Adding writing vocabulary, vectors, output layer for FB format (#2611) * Clean up writing to binary FB format (#2611) * Adding tests for saving FastText models to binary FB format (#2611) * Extending tests for saving FastText models to binary FB format (#2611) * Clean up (flake8) writing to binary FB format (#2611) * Word count bug fix + including additional test (#2611) * Removing f-strings for Python 3.5 compatibility + clean-up(#2611) * Clean up the comments (#2611) * Removing forgotten f-string for Python 3.5 compatibility (#2611) * Correct tests failing @ CI (#2611) * Another attempt to correct tests failing @ CI (#2611) * Yet another attempt to correct tests failing @ CI (#2611) * New attempt to correct tests failing @ CI (#2611) * Fix accidentally broken test (#2611) * Include Radim remarks to saving models in binary FB format (#2611) * Correcting loss bug (#2611) * Completed correcting loss bug (#2611) * Correcting breaking doc building bug (#2611) * Include first batch of Michael remarks * Refactoring SaveFacebookFormatRoundtripModelToModelTest according to Michael remarks (#2611) * Refactoring remaining tests according to Michael remarks (#2611) * Cleaning up the test refactoring (#2611) * Refactoring handling tuple result from struct.unpack (#2611) * Removing unused import (#2611) * Refactoring variable name according to Michael review (#2611) * Removing redundant saving in test for Facebook binary saving (#2611) * Minimizing context manager blocks span (#2611) * Remove obsolete comment (#2611) * Shortening method name (#2611) * Moving model parameters to _check_roundtrip function (#2611) * Finished moving model parameters to _check_roundtrip function (#2611) * Clean-up FT_HOME behaviour (#2611) * Simplifying vectors 
equality check (#2611) * Unifying testing method names (#2611) * Refactoring _create_and_save_fb_model method name (#2611) * Refactoring test names (#2611) * Refactoring flake8 errors (#2611) * Correcting fasttext invocation handling (#2611) * Removing _parse_wordvectors function (#2611) * Correcting whitespace and simplifying test assertion (#2611) * Removing redundant anonymous variable (#2611) * Moving assertion outside of a context manager (#2611) * Function rename (#2611) * Cleaning doc strings and comments in FB binary format saving functionality (#2611) * Cleaning doc strings in end-user API for FB binary format saving (#2611) * Correcting FT_CMD execution in SaveFacebookByteIdentityTest (#2611) --- gensim/models/_fasttext_bin.py | 322 ++++++++++++++++++++++++++++++++- gensim/models/fasttext.py | 52 +++++- gensim/test/test_fasttext.py | 226 +++++++++++++++++++++-- 3 files changed, 571 insertions(+), 29 deletions(-) diff --git a/gensim/models/_fasttext_bin.py b/gensim/models/_fasttext_bin.py index f00b049ee4..3b7af85f9e 100644 --- a/gensim/models/_fasttext_bin.py +++ b/gensim/models/_fasttext_bin.py @@ -41,9 +41,25 @@ _END_OF_WORD_MARKER = b'\x00' +# FastText dictionary data structure holds elements of type `entry` which can have `entry_type` +# either `word` (0 :: int8) or `label` (1 :: int8). Here we deal with unsupervised case only +# so we want `word` type. 
+# See https://github.com/facebookresearch/fastText/blob/master/src/dictionary.h + +_DICT_WORD_ENTRY_TYPE_MARKER = b'\x00' + + logger = logging.getLogger(__name__) -_FASTTEXT_FILEFORMAT_MAGIC = 793712314 +# Constants for FastText version and FastText file format magic (both int32) +# https://github.com/facebookresearch/fastText/blob/master/src/fasttext.cc#L25 + +_FASTTEXT_VERSION = np.int32(12) +_FASTTEXT_FILEFORMAT_MAGIC = np.int32(793712314) + + +# _NEW_HEADER_FORMAT is constructed on the basis of args::save method, see +# https://github.com/facebookresearch/fastText/blob/master/src/args.cc _NEW_HEADER_FORMAT = [ ('dim', 'i'), @@ -51,13 +67,13 @@ ('epoch', 'i'), ('min_count', 'i'), ('neg', 'i'), - ('_', 'i'), + ('word_ngrams', 'i'), # Unused in loading ('loss', 'i'), ('model', 'i'), ('bucket', 'i'), ('minn', 'i'), ('maxn', 'i'), - ('_', 'i'), + ('lr_update_rate', 'i'), # Unused in loading ('t', 'd'), ] @@ -65,13 +81,13 @@ ('epoch', 'i'), ('min_count', 'i'), ('neg', 'i'), - ('_', 'i'), + ('word_ngrams', 'i'), # Unused in loading ('loss', 'i'), ('model', 'i'), ('bucket', 'i'), ('minn', 'i'), ('maxn', 'i'), - ('_', 'i'), + ('lr_update_rate', 'i'), # Unused in loading ('t', 'd'), ] @@ -93,6 +109,7 @@ def _yield_field_names(): yield 'nwords' yield 'vectors_ngrams' yield 'hidden_output' + yield 'ntokens' _FIELD_NAMES = sorted(set(_yield_field_names())) @@ -168,6 +185,7 @@ def _load_vocab(fin, new_format, encoding='utf-8'): The loaded vocabulary. Keys are words, values are counts. The vocabulary size. The number of words. + The number of tokens. 
""" vocab_size, nwords, nlabels = _struct_unpack(fin, '@3i') @@ -176,7 +194,8 @@ def _load_vocab(fin, new_format, encoding='utf-8'): raise NotImplementedError("Supervised fastText models are not supported") logger.info("loading %s words for fastText model from %s", vocab_size, fin.name) - _struct_unpack(fin, '@1q') # number of tokens + ntokens = _struct_unpack(fin, '@q')[0] # number of tokens + if new_format: pruneidx_size, = _struct_unpack(fin, '@q') @@ -205,7 +224,7 @@ def _load_vocab(fin, new_format, encoding='utf-8'): for j in range(pruneidx_size): _struct_unpack(fin, '@2i') - return raw_vocab, vocab_size, nwords + return raw_vocab, vocab_size, nwords, ntokens def _load_matrix(fin, new_format=True): @@ -315,11 +334,12 @@ def load(fin, encoding='utf-8', full_model=True): header_spec = _NEW_HEADER_FORMAT if new_format else _OLD_HEADER_FORMAT model = {name: _struct_unpack(fin, fmt)[0] for (name, fmt) in header_spec} + if not new_format: model.update(dim=magic, ws=version) - raw_vocab, vocab_size, nwords = _load_vocab(fin, new_format, encoding=encoding) - model.update(raw_vocab=raw_vocab, vocab_size=vocab_size, nwords=nwords) + raw_vocab, vocab_size, nwords, ntokens = _load_vocab(fin, new_format, encoding=encoding) + model.update(raw_vocab=raw_vocab, vocab_size=vocab_size, nwords=nwords, ntokens=ntokens) vectors_ngrams = _load_matrix(fin, new_format=new_format) @@ -366,5 +386,289 @@ def _backslashreplace_backport(ex): return text, end +def _sign_model(fout): + """ + Write signature of the file in Facebook's native fastText `.bin` format + to the binary output stream `fout`. Signature includes magic bytes and version. 
+ + Name mimics original C++ implementation, see + [FastText::signModel](https://github.com/facebookresearch/fastText/blob/master/src/fasttext.cc) + + Parameters + ---------- + fout: writeable binary stream + """ + fout.write(_FASTTEXT_FILEFORMAT_MAGIC.tobytes()) + fout.write(_FASTTEXT_VERSION.tobytes()) + + +def _conv_field_to_bytes(field_value, field_type): + """ + Auxiliary function that converts `field_value` to bytes based on the requested `field_type`, + for saving to the binary file. + + Parameters + ---------- + field_value: numerical + the value to be converted to bytes. + + field_type: str + currently supported `field_types` are `i` for 32-bit integer and `d` for 64-bit float + """ + if field_type == 'i': + return (np.int32(field_value).tobytes()) + elif field_type == 'd': + return (np.float64(field_value).tobytes()) + else: + raise NotImplementedError('Currently conversion to "%s" type is not implemmented.' % field_type) + + +def _get_field_from_model(model, field): + """ + Extract `field` from `model`. 
+ + Parameters + ---------- + model: gensim.models.fasttext.FastText + model from which `field` is extracted + field: str + requested field name, fields are listed in the `_NEW_HEADER_FORMAT` list + """ + if field == 'bucket': + return model.trainables.bucket + elif field == 'dim': + return model.vector_size + elif field == 'epoch': + return model.epochs + elif field == 'loss': + # `loss` => hs: 1, ns: 2, softmax: 3, one-vs-all: 4 + # ns = negative sampling loss (default) + # hs = hierarchical softmax loss + # softmax = softmax loss + # one-vs-all = one vs all loss (supervised) + if model.hs == 1: + return 1 + elif model.hs == 0: + return 2 + elif model.hs == 0 and model.negative == 0: + return 1 + elif field == 'maxn': + return model.wv.max_n + elif field == 'minn': + return model.wv.min_n + elif field == 'min_count': + return model.vocabulary.min_count + elif field == 'model': + # `model` => cbow:1, sg:2, sup:3 + # cbow = continuous bag of words (default) + # sg = skip-gram + # sup = supervised + return 2 if model.sg == 1 else 1 + elif field == 'neg': + return model.negative + elif field == 't': + return model.vocabulary.sample + elif field == 'word_ngrams': + # This is skipped in gensim loading setting, using the default from FB C++ code + return 1 + elif field == 'ws': + return model.window + elif field == 'lr_update_rate': + # This is skipped in gensim loading setting, using the default from FB C++ code + return 100 + else: + msg = 'Extraction of header field "' + field + '" from Gensim FastText object not implemmented.' + raise NotImplementedError(msg) + + +def _args_save(fout, model, fb_fasttext_parameters): + """ + Saves header with `model` parameters to the binary stream `fout` containing a model in the Facebook's + native fastText `.bin` format. 
+ + Name mimics original C++ implementation, see + [Args::save](https://github.com/facebookresearch/fastText/blob/master/src/args.cc) + + Parameters + ---------- + fout: writeable binary stream + stream to which model is saved + model: gensim.models.fasttext.FastText + saved model + fb_fasttext_parameters: dictionary + dictionary containing the parameters `lr_update_rate`, `word_ngrams` + unused by gensim implementation, so they have to be provided externally + """ + for field, field_type in _NEW_HEADER_FORMAT: + if field in fb_fasttext_parameters: + field_value = fb_fasttext_parameters[field] + else: + field_value = _get_field_from_model(model, field) + fout.write(_conv_field_to_bytes(field_value, field_type)) + + +def _dict_save(fout, model, encoding): + """ + Saves the dictionary from `model` to the binary stream `fout` containing a model in the Facebook's + native fastText `.bin` format. + + Name mimics the original C++ implementation + [Dictionary::save](https://github.com/facebookresearch/fastText/blob/master/src/dictionary.cc) + + Parameters + ---------- + fout: writeable binary stream + stream to which the dictionary from the model is saved + model: gensim.models.fasttext.FastText + the model that contains the dictionary to save + encoding: str + string encoding used in the output + """ + + # In the FB format the dictionary can contain two types of entries, i.e. + # words and labels. The first two fields of the dictionary contain + # the dictionary size (size_) and the number of words (nwords_). + # In the unsupervised case we have only words (no labels). Hence both fields + # are equal. 
+ + fout.write(np.int32(len(model.wv.vocab)).tobytes()) + + fout.write(np.int32(len(model.wv.vocab)).tobytes()) + + # nlabels=0 <- no labels, we are in unsupervised mode + fout.write(np.int32(0).tobytes()) + + fout.write(np.int64(model.corpus_total_words).tobytes()) + + # pruneidx_size_=-1, -1 value denotes no pruning index (pruning is only supported in supervised mode) + fout.write(np.int64(-1)) + + for word in model.wv.index2word: + word_count = model.wv.vocab[word].count + fout.write(word.encode(encoding)) + fout.write(_END_OF_WORD_MARKER) + fout.write(np.int64(word_count).tobytes()) + fout.write(_DICT_WORD_ENTRY_TYPE_MARKER) + + # We are in unsupervised case, therefore pruned_idx is empty, so we do not need to write anything else + + +def _input_save(fout, model): + """ + Saves word and ngram vectors from `model` to the binary stream `fout` containing a model in + the Facebook's native fastText `.bin` format. + + Corresponding C++ fastText code: + [DenseMatrix::save](https://github.com/facebookresearch/fastText/blob/master/src/densematrix.cc) + + Parameters + ---------- + fout: writeable binary stream + stream to which the vectors are saved + model: gensim.models.fasttext.FastText + the model that contains the vectors to save + """ + vocab_n, vocab_dim = model.wv.vectors_vocab.shape + ngrams_n, ngrams_dim = model.wv.vectors_ngrams.shape + + assert vocab_dim == ngrams_dim + assert vocab_n == len(model.wv.vocab) + assert ngrams_n == model.wv.bucket + + fout.write(struct.pack('@2q', vocab_n + ngrams_n, vocab_dim)) + fout.write(model.wv.vectors_vocab.tobytes()) + fout.write(model.wv.vectors_ngrams.tobytes()) + + +def _output_save(fout, model): + """ + Saves output layer of `model` to the binary stream `fout` containing a model in + the Facebook's native fastText `.bin` format. 
+ + Corresponding C++ fastText code: + [DenseMatrix::save](https://github.com/facebookresearch/fastText/blob/master/src/densematrix.cc) + + Parameters + ---------- + fout: writeable binary stream + stream to which the output layer is saved + model: gensim.models.fasttext.FastText + the model that contains the output layer to save + """ + if model.hs: + hidden_output = model.trainables.syn1 + if model.negative: + hidden_output = model.trainables.syn1neg + + hidden_n, hidden_dim = hidden_output.shape + fout.write(struct.pack('@2q', hidden_n, hidden_dim)) + fout.write(hidden_output.tobytes()) + + +def _save_to_stream(model, fout, fb_fasttext_parameters, encoding): + """ + Saves word embeddings to binary stream `fout` using the Facebook's native fasttext `.bin` format. + + Parameters + ---------- + fout: writeable binary stream + stream to which the word embeddings are saved + model: gensim.models.fasttext.FastText + the model that contains the word embeddings to save + fb_fasttext_parameters: dictionary + dictionary containing the parameters `lr_update_rate`, `word_ngrams` + unused by gensim implementation, so they have to be provided externally + encoding: str + encoding used in the output file + """ + + _sign_model(fout) + _args_save(fout, model, fb_fasttext_parameters) + _dict_save(fout, model, encoding) + fout.write(struct.pack('@?', False)) # Save 'quant_', which is False for unsupervised models + + # Save words and ngrams vectors + _input_save(fout, model) + fout.write(struct.pack('@?', False)) # Save 'qout', which is False for unsupervised models + + # Save output layers of the model + _output_save(fout, model) + + +def save(model, fout, fb_fasttext_parameters, encoding): + """ + Saves word embeddings to the Facebook's native fasttext `.bin` format. 
+ + Parameters + ---------- + fout: file name or writeable binary stream + stream to which model is saved + model: gensim.models.fasttext.FastText + saved model + fb_fasttext_parameters: dictionary + dictionary contain parameters containing `lr_update_rate`, `word_ngrams` + unused by gensim implementation, so they have to be provided externally + encoding: str + encoding used in the output file + + Notes + ----- + Unfortunately, there is no documentation of the Facebook's native fasttext `.bin` format + + This is just reimplementation of + [FastText::saveModel](https://github.com/facebookresearch/fastText/blob/master/src/fasttext.cc) + + Based on v0.9.1, more precisely commit da2745fcccb848c7a225a7d558218ee4c64d5333 + + Code follows the original C++ code naming. + """ + + if isinstance(fout, str): + with open(fout, "wb") as fout_stream: + _save_to_stream(model, fout_stream, fb_fasttext_parameters, encoding) + else: + _save_to_stream(model, fout, fb_fasttext_parameters, encoding) + + if six.PY2: codecs.register_error('backslashreplace', _backslashreplace_backport) diff --git a/gensim/models/fasttext.py b/gensim/models/fasttext.py index 82af368508..2e4ad5fa64 100644 --- a/gensim/models/fasttext.py +++ b/gensim/models/fasttext.py @@ -1216,15 +1216,15 @@ def _load_fasttext_format(model_file, encoding='utf-8', full_model=True): window=m.ws, iter=m.epoch, negative=m.neg, - hs=(m.loss == 1), - sg=(m.model == 2), + hs=int(m.loss == 1), + sg=int(m.model == 2), bucket=m.bucket, min_count=m.min_count, sample=m.t, min_n=m.minn, max_n=m.maxn, ) - + model.corpus_total_words = m.ntokens model.vocabulary.raw_vocab = m.raw_vocab model.vocabulary.nwords = m.nwords model.vocabulary.vocab_size = m.vocab_size @@ -1250,7 +1250,6 @@ def _load_fasttext_format(model_file, encoding='utf-8', full_model=True): model.wv.init_post_load(m.vectors_ngrams) model.trainables.init_post_load(model, m.hidden_output) - _check_model(model) logger.info("loaded %s weight matrix for fastText model from %s", 
m.vectors_ngrams.shape, fin.name) @@ -1265,7 +1264,13 @@ def _check_model(m): 'mismatch between vector size in model params ({}) and model vectors ({})' .format(m.wv.vector_size, m.wv.vectors_ngrams) ) - if m.trainables.syn1neg is not None: + + try: + syn1neg = m.trainables.syn1neg + except AttributeError: + syn1neg = None + + if syn1neg is not None: assert m.wv.vector_size == m.trainables.syn1neg.shape[1], ( 'mismatch between vector size in model params ({}) and trainables ({})' .format(m.wv.vector_size, m.wv.vectors_ngrams) @@ -1282,3 +1287,40 @@ def _check_model(m): "mismatch between final vocab size (%s words), and expected vocab size (%s words)", len(m.wv.vocab), m.vocabulary.vocab_size ) + + +def save_facebook_model(model, path, encoding="utf-8", lr_update_rate=100, word_ngrams=1): + """Saves word embeddings to the Facebook's native fasttext `.bin` format. + + Notes + ------ + Facebook provides both `.vec` and `.bin` files with their modules. + The former contains human-readable vectors. + The latter contains machine-readable vectors along with other model parameters. + **This function saves only the .bin file**. + + Parameters + ---------- + model : gensim.models.fasttext.FastText + FastText model to be saved. + path : str + Output path and filename (including `.bin` extension) + encoding : str, optional + Specifies the file encoding. Defaults to utf-8. + + lr_update_rate : int + This parameter is used by Facebook fasttext tool, unused by Gensim. + It defaults to Facebook fasttext default value `100`. + In very rare circumstances you might wish to fiddle with it. + + word_ngrams : int + This parameter is used by Facebook fasttext tool, unused by Gensim. + It defaults to Facebook fasttext default value `1`. + In very rare circumstances you might wish to fiddle with it. 
+ + Returns + ------- + None + """ + fb_fasttext_parameters = {"lr_update_rate": lr_update_rate, "word_ngrams": word_ngrams} + gensim.models._fasttext_bin.save(model, path, fb_fasttext_parameters, encoding) diff --git a/gensim/test/test_fasttext.py b/gensim/test/test_fasttext.py index 4a5626eb30..b8de819baa 100644 --- a/gensim/test/test_fasttext.py +++ b/gensim/test/test_fasttext.py @@ -7,6 +7,7 @@ import logging import unittest import os +import subprocess import struct import numpy as np @@ -33,6 +34,11 @@ IS_WIN32 = (os.name == "nt") and (struct.calcsize('P') * 8 == 32) +MAX_WORDVEC_COMPONENT_DIFFERENCE = 1.0e-10 + +FT_HOME = os.environ.get("FT_HOME") +FT_CMD = os.path.join(FT_HOME, "fasttext") if FT_HOME else None + class LeeCorpus(object): def __iter__(self): @@ -56,8 +62,6 @@ def __iter__(self): class TestFastTextModel(unittest.TestCase): def setUp(self): - ft_home = os.environ.get('FT_HOME', None) - self.ft_path = os.path.join(ft_home, 'fasttext') if ft_home else None self.test_model_file = datapath('lee_fasttext.bin') self.test_model = gensim.models.fasttext.load_facebook_model(self.test_model_file) self.test_new_model_file = datapath('lee_fasttext_new.bin') @@ -811,13 +815,10 @@ def compare_with_wrapper(self, model_gensim, model_wrapper): # this limit can be increased when using Cython code self.assertGreaterEqual(overlap_count, 2) + @unittest.skipIf(not FT_HOME, "FT_HOME env variable not set, skipping test") def test_cbow_hs_against_wrapper(self): - if self.ft_path is None: - logger.info("FT_HOME env variable not set, skipping test") - return - tmpf = get_tmpfile('gensim_fasttext.tst') - model_wrapper = FT_wrapper.train(ft_path=self.ft_path, corpus_file=datapath('lee_background.cor'), + model_wrapper = FT_wrapper.train(ft_path=FT_CMD, corpus_file=datapath('lee_background.cor'), output_file=tmpf, model='cbow', size=50, alpha=0.05, window=5, min_count=5, word_ngrams=1, loss='hs', sample=1e-3, negative=0, iter=5, min_n=3, max_n=6, sorted_vocab=1, @@ -834,13 
+835,11 @@ def test_cbow_hs_against_wrapper(self): self.assertFalse((orig0 == model_gensim.wv.vectors[0]).all()) # vector should vary after training self.compare_with_wrapper(model_gensim, model_wrapper) + @unittest.skipIf(not FT_HOME, "FT_HOME env variable not set, skipping test") def test_sg_hs_against_wrapper(self): - if self.ft_path is None: - logger.info("FT_HOME env variable not set, skipping test") - return tmpf = get_tmpfile('gensim_fasttext.tst') - model_wrapper = FT_wrapper.train(ft_path=self.ft_path, corpus_file=datapath('lee_background.cor'), + model_wrapper = FT_wrapper.train(ft_path=FT_CMD, corpus_file=datapath('lee_background.cor'), output_file=tmpf, model='skipgram', size=50, alpha=0.025, window=5, min_count=5, word_ngrams=1, loss='hs', sample=1e-3, negative=0, iter=5, min_n=3, max_n=6, sorted_vocab=1, @@ -1218,7 +1217,7 @@ def test_ascii(self): buf = io.BytesIO() buf.name = 'dummy name to keep fasttext happy' buf.write(struct.pack('@3i', 2, -1, -1)) # vocab_size, nwords, nlabels - buf.write(struct.pack('@1q', -1)) + buf.write(struct.pack('@1q', 10)) # ntokens buf.write(b'hello') buf.write(b'\x00') buf.write(struct.pack('@qb', 1, -1)) @@ -1227,18 +1226,20 @@ def test_ascii(self): buf.write(struct.pack('@qb', 2, -1)) buf.seek(0) - raw_vocab, vocab_size, nlabels = gensim.models._fasttext_bin._load_vocab(buf, False) + raw_vocab, vocab_size, nlabels, ntokens = gensim.models._fasttext_bin._load_vocab(buf, False) expected = {'hello': 1, 'world': 2} self.assertEqual(expected, dict(raw_vocab)) self.assertEqual(vocab_size, 2) self.assertEqual(nlabels, -1) + self.assertEqual(ntokens, 10) + def test_bad_unicode(self): buf = io.BytesIO() buf.name = 'dummy name to keep fasttext happy' buf.write(struct.pack('@3i', 2, -1, -1)) # vocab_size, nwords, nlabels - buf.write(struct.pack('@1q', -1)) + buf.write(struct.pack('@1q', 10)) # ntokens # # encountered in https://github.com/RaRe-Technologies/gensim/issues/2378 # The model from downloaded from @@ -1265,7 +1266,7 @@ 
def test_bad_unicode(self): buf.write(struct.pack('@qb', 2, -1)) buf.seek(0) - raw_vocab, vocab_size, nlabels = gensim.models._fasttext_bin._load_vocab(buf, False) + raw_vocab, vocab_size, nlabels, ntokens = gensim.models._fasttext_bin._load_vocab(buf, False) expected = { u'英語版ウィキペディアへの投稿はいつでも\\xe6': 1, @@ -1276,6 +1277,7 @@ def test_bad_unicode(self): self.assertEqual(vocab_size, 2) self.assertEqual(nlabels, -1) + self.assertEqual(ntokens, 10) _BYTES = b'the quick brown fox jumps over the lazy dog' @@ -1300,6 +1302,200 @@ def _run(self, fin): self.assertTrue(np.allclose(_ARRAY, array)) +def _create_and_save_fb_model(fname, model_params): + model = FT_gensim(**model_params) + lee_data = LineSentence(datapath('lee_background.cor')) + model.build_vocab(lee_data) + model.train(lee_data, total_examples=model.corpus_count, epochs=model.epochs) + gensim.models.fasttext.save_facebook_model(model, fname) + return model + + +def calc_max_diff(v1, v2): + return np.max(np.abs(v1 - v2)) + + +class SaveFacebookFormatModelTest(unittest.TestCase): + + def _check_roundtrip(self, sg): + model_params = { + "sg": sg, + "size": 10, + "min_count": 1, + "hs": 1, + "negative": 5, + "seed": 42, + "workers": 1} + + with temporary_file("roundtrip_model_to_model.bin") as fpath: + model_trained = _create_and_save_fb_model(fpath, model_params) + model_loaded = gensim.models.fasttext.load_facebook_model(fpath) + + self.assertEqual(model_trained.vector_size, model_loaded.vector_size) + self.assertEqual(model_trained.window, model_loaded.window) + self.assertEqual(model_trained.epochs, model_loaded.epochs) + self.assertEqual(model_trained.negative, model_loaded.negative) + self.assertEqual(model_trained.hs, model_loaded.hs) + self.assertEqual(model_trained.sg, model_loaded.sg) + self.assertEqual(model_trained.trainables.bucket, model_loaded.trainables.bucket) + self.assertEqual(model_trained.wv.min_n, model_loaded.wv.min_n) + self.assertEqual(model_trained.wv.max_n, model_loaded.wv.max_n) + 
self.assertEqual(model_trained.vocabulary.sample, model_loaded.vocabulary.sample) + self.assertEqual(set(model_trained.wv.index2word), set(model_loaded.wv.index2word)) + + for w in model_trained.wv.index2word: + v_orig = model_trained.wv[w] + v_loaded = model_loaded.wv[w] + self.assertLess(calc_max_diff(v_orig, v_loaded), MAX_WORDVEC_COMPONENT_DIFFERENCE) + + def test_skipgram(self): + self._check_roundtrip(sg=1) + + def test_cbow(self): + self._check_roundtrip(sg=0) + + +def _read_binary_file(fname): + with open(fname, "rb") as f: + data = f.read() + return data + + +class SaveGensimByteIdentityTest(unittest.TestCase): + """ + This class contains tests that check the following scenario: + + + create binary fastText file model1.bin using gensim + + load file model1.bin to variable `model` + + save `model` to model2.bin + + check if files model1.bin and model2.bin are byte identical + """ + + def _check_roundtrip_file_file(self, sg): + model_params = { + "sg": sg, + "size": 10, + "min_count": 1, + "hs": 1, + "negative": 0, + "seed": 42, + "workers": 1} + + with temporary_file("roundtrip_file_to_file1.bin") as fpath1, \ + temporary_file("roundtrip_file_to_file2.bin") as fpath2: + _create_and_save_fb_model(fpath1, model_params) + model = gensim.models.fasttext.load_facebook_model(fpath1) + gensim.models.fasttext.save_facebook_model(model, fpath2) + bin1 = _read_binary_file(fpath1) + bin2 = _read_binary_file(fpath2) + + self.assertEqual(bin1, bin2) + + def test_skipgram(self): + self._check_roundtrip_file_file(sg=1) + + def test_cbow(self): + self._check_roundtrip_file_file(sg=0) + + +def _save_test_model(out_base_fname, model_params): + inp_fname = datapath('lee_background.cor') + + model_type = "cbow" if model_params["sg"] == 0 else "skipgram" + size = str(model_params["size"]) + seed = str(model_params["seed"]) + + cmd = [ + FT_CMD, model_type, "-input", inp_fname, "-output", + out_base_fname, "-dim", size, "-seed", seed] + + subprocess.check_call(cmd) + + 
+@unittest.skipIf(not FT_HOME, "FT_HOME env variable not set, skipping test") +class SaveFacebookByteIdentityTest(unittest.TestCase): + """ + This class contains tests that check the following scenario: + + + create binary fastText file model1.bin using facebook_binary (FT) + + load file model1.bin to variable `model` + + save `model` to model2.bin using gensim + + check if files model1.bin and model2.bin are byte-identical + """ + + def _check_roundtrip_file_file(self, sg): + model_params = {"size": 10, "sg": sg, "seed": 42} + + # fasttext tool creates both *vec and *bin files, so we have to remove both, even though *vec is unused + + with temporary_file("m1.bin") as m1, temporary_file("m2.bin") as m2, temporary_file("m1.vec"): + + m1_basename = m1[:-4] + _save_test_model(m1_basename, model_params) + model = gensim.models.fasttext.load_facebook_model(m1) + gensim.models.fasttext.save_facebook_model(model, m2) + bin1 = _read_binary_file(m1) + bin2 = _read_binary_file(m2) + + self.assertEqual(bin1, bin2) + + def test_skipgram(self): + self._check_roundtrip_file_file(sg=1) + + def test_cbow(self): + self._check_roundtrip_file_file(sg=0) + + +def _read_wordvectors_using_fasttext(fasttext_fname, words): + def line_to_array(line): + return np.array([float(s) for s in line.split()[1:]], dtype=np.float32) + + cmd = [FT_CMD, "print-word-vectors", fasttext_fname] + process = subprocess.Popen( + cmd, stdin=subprocess.PIPE, + stdout=subprocess.PIPE) + words_str = '\n'.join(words) + out, _ = process.communicate(input=words_str.encode("utf-8")) + return np.array([line_to_array(l) for l in out.splitlines()], dtype=np.float32) + + +@unittest.skipIf(not os.environ.get("FT_HOME", None), "FT_HOME env variable not set, skipping test") +class SaveFacebookFormatReadingTest(unittest.TestCase): + """ + This class contains tests that check the following scenario: + + + create fastText model using gensim + + save file to model.bin + + retrieve word vectors from model.bin using fasttext 
Facebook utility + + compare vectors retrieved by Facebook utility with those obtained directly from gensim model + """ + + def _check_load_fasttext_format(self, sg): + model_params = { + "sg": sg, + "size": 10, + "min_count": 1, + "hs": 1, + "negative": 5, + "seed": 42, + "workers": 1} + + with temporary_file("load_fasttext.bin") as fpath: + model = _create_and_save_fb_model(fpath, model_params) + wv = _read_wordvectors_using_fasttext(fpath, model.wv.index2word) + + for i, w in enumerate(model.wv.index2word): + diff = calc_max_diff(wv[i, :], model.wv[w]) + # Because fasttext command line prints vectors with limited accuracy + self.assertLess(diff, 1.0e-4) + + def test_skipgram(self): + self._check_load_fasttext_format(sg=1) + + def test_cbow(self): + self._check_load_fasttext_format(sg=0) + + if __name__ == '__main__': logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG) unittest.main()