Skip to content

Commit

Permalink
Implement saving to Facebook format (#2712)
Browse files Browse the repository at this point in the history
* Add writing header for binary FB format (#2611)

* Adding writing vocabulary, vectors, output layer for FB format (#2611)

* Clean up writing to binary FB format (#2611)

* Adding tests for saving FastText models to binary FB format (#2611)

* Extending tests for saving FastText models to binary FB format (#2611)

* Clean up (flake8) writing to binary FB format (#2611)

* Word count bug fix + including additional test (#2611)

* Removing f-strings for Python 3.5 compatibility + clean-up(#2611)

* Clean up the comments (#2611)

* Removing forgotten f-string for Python 3.5 compatibility (#2611)

* Correct tests failing @ CI (#2611)

* Another attempt to correct tests failing @ CI (#2611)

* Yet another attempt to correct tests failing @ CI (#2611)

* New attempt to correct tests failing @ CI (#2611)

* Fix accidentally broken test (#2611)

* Include Radim remarks to saving models in binary FB format (#2611)

* Correcting loss bug (#2611)

* Completed correcting loss bug (#2611)

* Correcting breaking doc building bug (#2611)

* Include first batch of Michael remarks

* Refactoring SaveFacebookFormatRoundtripModelToModelTest according to Michael remarks (#2611)

* Refactoring remaining tests according to Michael remarks (#2611)

* Cleaning up the test refactoring (#2611)

* Refactoring handling tuple result from struct.unpack (#2611)

* Removing unused import (#2611)

* Refactoring variable name according to Michael review (#2611)

* Removing redundant saving in test for Facebook binary saving (#2611)

* Minimizing context manager blocks span (#2611)

* Remove obsolete comment (#2611)

* Shortening method name (#2611)

* Moving model parameters to _check_roundtrip function (#2611)

* Finished moving model parameters to _check_roundtrip function (#2611)

* Clean-up FT_HOME behaviour (#2611)

* Simplifying vectors equality check (#2611)

* Unifying testing method names (#2611)

* Refactoring _create_and_save_fb_model method name (#2611)

* Refactoring test names (#2611)

* Refactoring flake8 errors (#2611)

* Correcting fasttext invocation handling (#2611)

* Removing _parse_wordvectors function (#2611)

* Correcting whitespace and simplifying test assertion (#2611)

* Removing redundant anonymous variable (#2611)

* Moving assertion outside of a context manager (#2611)

* Function rename (#2611)

* Cleaning doc strings and comments in FB binary format saving functionality (#2611)

* Cleaning doc strings in end-user API for FB binary format saving (#2611)

* Correcting FT_CMD execution in SaveFacebookByteIdentityTest (#2611)
  • Loading branch information
lopusz authored and mpenkov committed Jan 23, 2020
1 parent fbc7d09 commit 4d22327
Show file tree
Hide file tree
Showing 3 changed files with 571 additions and 29 deletions.
322 changes: 313 additions & 9 deletions gensim/models/_fasttext_bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,37 +41,53 @@

# Byte written right after every encoded word when the dictionary is saved.
_END_OF_WORD_MARKER = b'\x00'

# FastText dictionary data structure holds elements of type `entry`, whose `entry_type` is
# either `word` (0 :: int8) or `label` (1 :: int8). Only the unsupervised case is handled here,
# so the `word` type is the one we want.
# See https://github.com/facebookresearch/fastText/blob/master/src/dictionary.h

_DICT_WORD_ENTRY_TYPE_MARKER = b'\x00'


logger = logging.getLogger(__name__)

# Constants for FastText version and FastText file format magic (both int32)
# https://github.com/facebookresearch/fastText/blob/master/src/fasttext.cc#L25

_FASTTEXT_VERSION = np.int32(12)
_FASTTEXT_FILEFORMAT_MAGIC = np.int32(793712314)


# _NEW_HEADER_FORMAT is constructed on the basis of args::save method, see
# https://github.com/facebookresearch/fastText/blob/master/src/args.cc

_NEW_HEADER_FORMAT = [
    ('dim', 'i'),
    ('ws', 'i'),
    ('epoch', 'i'),
    ('min_count', 'i'),
    ('neg', 'i'),
    ('word_ngrams', 'i'),  # Unused in loading
    ('loss', 'i'),
    ('model', 'i'),
    ('bucket', 'i'),
    ('minn', 'i'),
    ('maxn', 'i'),
    ('lr_update_rate', 'i'),  # Unused in loading
    ('t', 'd'),
]

_OLD_HEADER_FORMAT = [
    ('epoch', 'i'),
    ('min_count', 'i'),
    ('neg', 'i'),
    ('word_ngrams', 'i'),  # Unused in loading
    ('loss', 'i'),
    ('model', 'i'),
    ('bucket', 'i'),
    ('minn', 'i'),
    ('maxn', 'i'),
    ('lr_update_rate', 'i'),  # Unused in loading
    ('t', 'd'),
]

Expand All @@ -93,6 +109,7 @@ def _yield_field_names():
yield 'nwords'
yield 'vectors_ngrams'
yield 'hidden_output'
yield 'ntokens'


_FIELD_NAMES = sorted(set(_yield_field_names()))
Expand Down Expand Up @@ -168,6 +185,7 @@ def _load_vocab(fin, new_format, encoding='utf-8'):
The loaded vocabulary. Keys are words, values are counts.
The vocabulary size.
The number of words.
The number of tokens.
"""
vocab_size, nwords, nlabels = _struct_unpack(fin, '@3i')

Expand All @@ -176,7 +194,8 @@ def _load_vocab(fin, new_format, encoding='utf-8'):
raise NotImplementedError("Supervised fastText models are not supported")
logger.info("loading %s words for fastText model from %s", vocab_size, fin.name)

_struct_unpack(fin, '@1q') # number of tokens
ntokens = _struct_unpack(fin, '@q')[0] # number of tokens

if new_format:
pruneidx_size, = _struct_unpack(fin, '@q')

Expand Down Expand Up @@ -205,7 +224,7 @@ def _load_vocab(fin, new_format, encoding='utf-8'):
for j in range(pruneidx_size):
_struct_unpack(fin, '@2i')

return raw_vocab, vocab_size, nwords
return raw_vocab, vocab_size, nwords, ntokens


def _load_matrix(fin, new_format=True):
Expand Down Expand Up @@ -315,11 +334,12 @@ def load(fin, encoding='utf-8', full_model=True):

header_spec = _NEW_HEADER_FORMAT if new_format else _OLD_HEADER_FORMAT
model = {name: _struct_unpack(fin, fmt)[0] for (name, fmt) in header_spec}

if not new_format:
model.update(dim=magic, ws=version)

raw_vocab, vocab_size, nwords = _load_vocab(fin, new_format, encoding=encoding)
model.update(raw_vocab=raw_vocab, vocab_size=vocab_size, nwords=nwords)
raw_vocab, vocab_size, nwords, ntokens = _load_vocab(fin, new_format, encoding=encoding)
model.update(raw_vocab=raw_vocab, vocab_size=vocab_size, nwords=nwords, ntokens=ntokens)

vectors_ngrams = _load_matrix(fin, new_format=new_format)

Expand Down Expand Up @@ -366,5 +386,289 @@ def _backslashreplace_backport(ex):
return text, end


def _sign_model(fout):
    """Write the signature of a Facebook native fastText `.bin` file to `fout`.

    The signature is the file format magic followed by the format version.
    The name mimics the original C++ implementation, see
    [FastText::signModel](https://github.com/facebookresearch/fastText/blob/master/src/fasttext.cc)

    Parameters
    ----------
    fout: writeable binary stream
        Stream that receives the signature.

    """
    # Order matters: the magic goes first, the version second.
    for signature_part in (_FASTTEXT_FILEFORMAT_MAGIC, _FASTTEXT_VERSION):
        fout.write(signature_part.tobytes())


def _conv_field_to_bytes(field_value, field_type):
    """Convert `field_value` to the bytes stored for it in the binary file.

    Parameters
    ----------
    field_value: numerical
        Value of the header field to convert.
    field_type: str
        Requested binary type. Currently supported `field_types` are `i` for
        32-bit integer and `d` for 64-bit float.

    Returns
    -------
    bytes
        Binary representation of `field_value` in the requested type.

    Raises
    ------
    NotImplementedError
        If `field_type` is not one of the supported type codes.

    """
    if field_type == 'i':
        to_scalar = np.int32
    elif field_type == 'd':
        to_scalar = np.float64
    else:
        raise NotImplementedError('Currently conversion to "%s" type is not implemmented.' % field_type)

    return to_scalar(field_value).tobytes()


def _get_field_from_model(model, field):
    """Extract the value of the header `field` from `model`.

    Parameters
    ----------
    model: gensim.models.fasttext.FastText
        Model from which `field` is extracted.
    field: str
        Requested field name; the fields are listed in the `_NEW_HEADER_FORMAT` list.

    Returns
    -------
    numerical
        Value to be stored for `field` in the header of the Facebook `.bin` file.

    Raises
    ------
    NotImplementedError
        If `field` is not a supported header field.
    ValueError
        If `field` is 'loss' and `model.hs` is neither 0 nor 1.

    """
    if field == 'bucket':
        return model.trainables.bucket
    elif field == 'dim':
        return model.vector_size
    elif field == 'epoch':
        return model.epochs
    elif field == 'loss':
        # `loss` => hs: 1, ns: 2, softmax: 3, one-vs-all: 4
        # ns = negative sampling loss (default)
        # hs = hierarchical softmax loss
        # softmax = softmax loss
        # one-vs-all = one vs all loss (supervised)
        if model.hs == 1:
            return 1
        if model.hs == 0:
            # This also covers `negative == 0`: a former dedicated branch for that case was
            # placed after this test and therefore could never run.
            return 2
        # Fail loudly instead of silently returning None for an unexpected `hs`.
        raise ValueError('Unexpected value of "hs" in Gensim FastText object: %r' % (model.hs,))
    elif field == 'maxn':
        return model.wv.max_n
    elif field == 'minn':
        return model.wv.min_n
    elif field == 'min_count':
        return model.vocabulary.min_count
    elif field == 'model':
        # `model` => cbow:1, sg:2, sup:3
        # cbow = continuous bag of words (default)
        # sg = skip-gram
        # sup = supervised
        return 2 if model.sg == 1 else 1
    elif field == 'neg':
        return model.negative
    elif field == 't':
        return model.vocabulary.sample
    elif field == 'word_ngrams':
        # This is skipped in gensim loading setting, using the default from FB C++ code
        return 1
    elif field == 'ws':
        return model.window
    elif field == 'lr_update_rate':
        # This is skipped in gensim loading setting, using the default from FB C++ code
        return 100
    else:
        msg = 'Extraction of header field "' + field + '" from Gensim FastText object not implemmented.'
        raise NotImplementedError(msg)


def _args_save(fout, model, fb_fasttext_parameters):
    """Write the header holding the `model` parameters to `fout`.

    The header is written in the layout of the Facebook's native fastText `.bin` format,
    one value per entry of `_NEW_HEADER_FORMAT`, in that order.
    The name mimics the original C++ implementation, see
    [Args::save](https://github.com/facebookresearch/fastText/blob/master/src/args.cc)

    Parameters
    ----------
    fout: writeable binary stream
        Stream to which the header is written.
    model: gensim.models.fasttext.FastText
        Model being saved.
    fb_fasttext_parameters: dictionary
        Externally supplied header values (such as `lr_update_rate`, `word_ngrams`)
        which are unused by the gensim implementation. A value present here takes
        precedence over the value derived from `model`.

    """
    for name, type_code in _NEW_HEADER_FORMAT:
        # Only consult the model when the caller did not supply the value explicitly.
        value = (
            fb_fasttext_parameters[name]
            if name in fb_fasttext_parameters
            else _get_field_from_model(model, name)
        )
        fout.write(_conv_field_to_bytes(value, type_code))


def _dict_save(fout, model, encoding):
    """Write the dictionary of `model` to `fout` in Facebook's native fastText `.bin` layout.

    The name mimics the original C++ implementation
    [Dictionary::save](https://github.com/facebookresearch/fastText/blob/master/src/dictionary.cc)

    Parameters
    ----------
    fout: writeable binary stream
        Stream to which the dictionary from the model is saved.
    model: gensim.models.fasttext.FastText
        The model that contains the dictionary to save.
    encoding: str
        String encoding used for the words in the output.

    """
    # In the FB format the dictionary can contain two types of entries, i.e.
    # words and labels. The first two fields of the dictionary contain
    # the dictionary size (size_) and the number of words (nwords_).
    # In the unsupervised case we have only words (no labels). Hence both fields
    # are equal.
    vocab_size = np.int32(len(model.wv.vocab)).tobytes()
    fout.write(vocab_size)
    fout.write(vocab_size)

    # nlabels=0 <- no labels, we are in unsupervised mode
    fout.write(np.int32(0).tobytes())

    # ntokens
    fout.write(np.int64(model.corpus_total_words).tobytes())

    # pruneidx_size_=-1, -1 value denotes no pruning index (pruning is only supported in supervised mode).
    # Convert explicitly to bytes, like every other field, instead of handing a NumPy scalar to the stream.
    fout.write(np.int64(-1).tobytes())

    for word in model.wv.index2word:
        word_count = model.wv.vocab[word].count
        fout.write(word.encode(encoding))
        fout.write(_END_OF_WORD_MARKER)
        fout.write(np.int64(word_count).tobytes())
        fout.write(_DICT_WORD_ENTRY_TYPE_MARKER)

    # We are in unsupervised case, therefore pruned_idx is empty, so we do not need to write anything else


def _input_save(fout, model):
    """Write the word and ngram vectors of `model` to `fout`.

    The output is a two-value header (total number of rows, number of columns)
    followed by the raw bytes of the word vectors and then of the ngram vectors.
    Corresponding C++ fastText code:
    [DenseMatrix::save](https://github.com/facebookresearch/fastText/blob/master/src/densematrix.cc)

    Parameters
    ----------
    fout: writeable binary stream
        Stream to which the vectors are saved.
    model: gensim.models.fasttext.FastText
        The model that contains the vectors to save.

    """
    word_vectors = model.wv.vectors_vocab
    vocab_n, vocab_dim = word_vectors.shape
    ngram_vectors = model.wv.vectors_ngrams
    ngrams_n, ngrams_dim = ngram_vectors.shape

    # Sanity checks: both matrices share the dimensionality and match the model bookkeeping.
    assert vocab_dim == ngrams_dim
    assert vocab_n == len(model.wv.vocab)
    assert ngrams_n == model.wv.bucket

    # Both matrices are stored as a single one: word rows first, ngram rows after them.
    matrix_header = struct.pack('@2q', vocab_n + ngrams_n, vocab_dim)
    fout.write(matrix_header)
    for matrix in (word_vectors, ngram_vectors):
        fout.write(matrix.tobytes())


def _output_save(fout, model):
    """Write the output layer of `model` to `fout`.

    The output is a two-value header (number of rows, number of columns) followed by
    the raw bytes of the output matrix.
    Corresponding C++ fastText code:
    [DenseMatrix::save](https://github.com/facebookresearch/fastText/blob/master/src/densematrix.cc)

    Parameters
    ----------
    fout: writeable binary stream
        Stream to which the output layer is saved.
    model: gensim.models.fasttext.FastText
        The model that contains the output layer to save.

    Raises
    ------
    ValueError
        If the model uses neither hierarchical softmax nor negative sampling,
        so there is no output layer to write.

    """
    # When both `hs` and `negative` are set, the negative sampling layer is the one written.
    # NOTE(review): `_get_field_from_model` reports the hs loss whenever `model.hs == 1`,
    # so header and output layer may disagree for such models - confirm the intended precedence.
    if model.negative:
        hidden_output = model.trainables.syn1neg
    elif model.hs:
        hidden_output = model.trainables.syn1
    else:
        raise ValueError(
            'Cannot save the output layer: the model uses neither hierarchical softmax nor negative sampling.'
        )

    hidden_n, hidden_dim = hidden_output.shape
    fout.write(struct.pack('@2q', hidden_n, hidden_dim))
    fout.write(hidden_output.tobytes())


def _save_to_stream(model, fout, fb_fasttext_parameters, encoding):
    """Write `model` to the binary stream `fout` in the Facebook's native fasttext `.bin` format.

    The sections are written in this order: signature, header with the parameters,
    dictionary, a boolean flag, the input (word and ngram) vectors, a second boolean flag,
    and the output layer.

    Parameters
    ----------
    model: gensim.models.fasttext.FastText
        The model that contains the word embeddings to save.
    fout: writeable binary stream
        Stream to which the word embeddings are saved.
    fb_fasttext_parameters: dictionary
        Dictionary containing parameters such as `lr_update_rate`, `word_ngrams`,
        unused by the gensim implementation, so they have to be provided externally.
    encoding: str
        Encoding used for the words in the output.

    """
    # The same single-byte boolean precedes both matrices; it is False for unsupervised models.
    false_flag = struct.pack('@?', False)

    _sign_model(fout)
    _args_save(fout, model, fb_fasttext_parameters)
    _dict_save(fout, model, encoding)

    # 'quant_' flag, then the word and ngram vectors
    fout.write(false_flag)
    _input_save(fout, model)

    # Flag preceding the output matrix (`qout` in the C++ FastText::saveModel), then the output layer
    fout.write(false_flag)
    _output_save(fout, model)


def save(model, fout, fb_fasttext_parameters, encoding):
    """Save word embeddings to the Facebook's native fasttext `.bin` format.

    Parameters
    ----------
    model: gensim.models.fasttext.FastText
        Model to save.
    fout: file name or writeable binary stream
        Destination of the model. A `str` is treated as a file name that is opened
        for binary writing; anything else is used directly as the output stream.
    fb_fasttext_parameters: dictionary
        Dictionary containing parameters such as `lr_update_rate`, `word_ngrams`,
        unused by the gensim implementation, so they have to be provided externally.
    encoding: str
        Encoding used in the output file.

    Notes
    -----
    Unfortunately, there is no documentation of the Facebook's native fasttext `.bin` format.

    This is just a reimplementation of
    [FastText::saveModel](https://github.com/facebookresearch/fastText/blob/master/src/fasttext.cc)

    Based on v0.9.1, more precisely commit da2745fcccb848c7a225a7d558218ee4c64d5333

    Code follows the original C++ code naming.

    """
    if not isinstance(fout, str):
        # Already a stream: the caller owns it, so it is neither opened nor closed here.
        _save_to_stream(model, fout, fb_fasttext_parameters, encoding)
        return

    with open(fout, "wb") as fout_stream:
        _save_to_stream(model, fout_stream, fb_fasttext_parameters, encoding)


if six.PY2:
codecs.register_error('backslashreplace', _backslashreplace_backport)
Loading

0 comments on commit 4d22327

Please sign in to comment.