Skip to content

Commit

Permalink
Refactoring SaveFacebookFormatRoundtripModelToModelTest according to …
Browse files Browse the repository at this point in the history
…Michael remarks (piskvorky#2611)
  • Loading branch information
lopusz committed Jan 5, 2020
1 parent 2ed3115 commit 8e8ca1e
Showing 1 changed file with 167 additions and 170 deletions.
337 changes: 167 additions & 170 deletions gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,21 +1305,23 @@ def _run(self, fin):
self.assertTrue(np.allclose(_ARRAY, array))


class SaveFacebookFormatTest(unittest.TestCase):
    """Common base for test cases that need a trained model saved via `save_facebook_model`."""

    def _create_and_save_test_model(self, fname, model_params):
        """Train a FastText model on `lee_background.cor`, save it to `fname` and return it.

        Parameters
        ----------
        fname : str
            Destination path handed to `gensim.models.fasttext.save_facebook_model`.
        model_params : dict
            Keyword arguments forwarded unchanged to the `FT_gensim` constructor.

        Returns
        -------
        The trained `FT_gensim` model that was written to `fname`.
        """
        trained = FT_gensim(**model_params)
        corpus = LineSentence(datapath('lee_background.cor'))
        trained.build_vocab(corpus)
        trained.train(
            corpus,
            total_examples=trained.corpus_count,
            epochs=trained.epochs,
        )
        gensim.models.fasttext.save_facebook_model(trained, fname)
        return trained
# Upper bound on the absolute difference between corresponding word-vector components.
# NOTE(review): presumably the tolerance applied by the model-to-model round-trip checks;
# the use site is not visible here - confirm.
MAX_WORDVEC_COMPONENT_DIFFERENCE = 1.0e-10


def _create_and_save_test_model(fname, model_params):
    """Train a FastText model on `lee_background.cor`, save it to `fname` and return it.

    Parameters
    ----------
    fname : str
        Destination path handed to `gensim.models.fasttext.save_facebook_model`.
    model_params : dict
        Keyword arguments forwarded unchanged to the `FT_gensim` constructor.

    Returns
    -------
    The trained `FT_gensim` model that was written to `fname`.
    """
    trained = FT_gensim(**model_params)
    corpus = LineSentence(datapath('lee_background.cor'))
    trained.build_vocab(corpus)
    trained.train(
        corpus,
        total_examples=trained.corpus_count,
        epochs=trained.epochs,
    )
    gensim.models.fasttext.save_facebook_model(trained, fname)
    return trained


def calc_max_diff(v1, v2):
    """Return the largest absolute element-wise difference between `v1` and `v2`.

    Parameters
    ----------
    v1, v2 : array-like
        Values to compare. Both are passed through `np.asarray`, so plain lists and
        tuples are accepted as well as arrays; arrays are used as they are.

    Returns
    -------
    The maximum of ``|v1 - v2|`` over all elements.
    """
    # Without the coercion, `v1 - v2` raises TypeError for plain Python sequences.
    return np.max(np.abs(np.asarray(v1) - np.asarray(v2)))


class SaveFacebookFormatRoundtripModelToModelTest(SaveFacebookFormatTest):
class SaveFacebookFormatRoundtripModelToModelTest(unittest.TestCase):
"""
This class contains tests that check the following scenario:
Expand All @@ -1331,11 +1333,9 @@ class SaveFacebookFormatRoundtripModelToModelTest(SaveFacebookFormatTest):

def _check_roundtrip_model_model(self, model_params):

MAX_WORDVEC_COMPONENT_DIFFERENCE = 1.0e-10

with temporary_file("roundtrip_model_to_model.bin") as fpath:

model_orig = self._create_and_save_test_model(fpath, model_params)
model_orig = _create_and_save_test_model(fpath, model_params)

gensim.models.fasttext.save_facebook_model(model_orig, fpath)
model_loaded = gensim.models.fasttext.load_facebook_model(fpath)
Expand Down Expand Up @@ -1366,17 +1366,19 @@ def _check_roundtrip_model_model(self, model_params):
raise e

def test_round_trip_model_model_skipgram(self):
    """Run the model-to-model round-trip check with skip-gram parameters (sg=1)."""
    # A single assignment is enough: the dict is only read by the call below.
    model_params = {
        "size": 10, "min_count": 1, "hs": 1, "sg": 1,
        "negative": 5, "seed": 42, "workers": 1}
    self._check_roundtrip_model_model(model_params)

def test_round_trip_model_model_cbow(self):
    """Run the model-to-model round-trip check with CBOW parameters (sg=0)."""
    # A single assignment is enough: the dict is only read by the call below.
    model_params = {
        "size": 10, "min_count": 1, "hs": 1, "sg": 0, "negative": 5,
        "seed": 42, "workers": 1}
    self._check_roundtrip_model_model(model_params)


class SaveFacebookFormatRoundtripFileToFileTest(SaveFacebookFormatTest):
class SaveFacebookFormatRoundtripFileToFileTest(unittest.TestCase):
"""
Base class for file-to-file round-trip tests; contains the functionality for comparing FB binary files
"""
Expand Down Expand Up @@ -1474,163 +1476,158 @@ def _check_roundtrip_file_file(self, model_params):
self._compare_fasttext_files(fpath1, fpath2)


class SaveFacebookFormatRoundtripFileToFileGensimTest(SaveFacebookFormatRoundtripFileToFileTest):
    """
    This class contains tests that check the following scenario:

    + create binary fastText file model1.bin using Gensim
    + load file model1.bin to model
    + save model to model2.bin
    + check if files model1.bin and model2.bin are identical
    """

    def test_roundtrip_file_file_skipgram(self):
        """Run the file-to-file round trip with skip-gram parameters (sg=1)."""
        self._check_roundtrip_file_file({
            "size": 10, "min_count": 1, "hs": 1, "sg": 1,
            "negative": 0, "seed": 42, "workers": 1,
        })

    def test_roundtrip_file_file_cbow(self):
        """Run the file-to-file round trip with CBOW parameters (sg=0)."""
        self._check_roundtrip_file_file({
            "size": 10, "min_count": 1, "hs": 1, "sg": 0,
            "negative": 0, "seed": 42, "workers": 1,
        })


class SaveFacebookFormatRoundtripFileToFileFacebookTest(SaveFacebookFormatRoundtripFileToFileTest):
    """
    This class contains tests that check the following scenario:

    + create binary fastText file model1.bin using facebook_binary
    + load file model1.bin to model
    + save model to model2.bin
    + check if files model1.bin and model2.bin are identical

    Requires env. variable FT_HOME to point to location of Facebook fastText binary
    """

    def _create_and_save_test_model(self, out_base_fname, model_params, fasttext_cmd):
        """Train a model on `lee_background.cor` by running the fastText executable.

        Parameters
        ----------
        out_base_fname : str
            Output path without extension, passed to fastText as ``-output``.
        model_params : dict
            Must contain "sg" (0 selects ``cbow``, any other value ``skipgram``),
            "size" (passed as ``-dim``) and "seed" (passed as ``-seed``).
        fasttext_cmd : str
            Path of the fastText executable to run.

        Raises
        ------
        subprocess.CalledProcessError
            If the fastText process exits with a non-zero status.
        """
        inp_fname = datapath('lee_background.cor')
        model_type = "cbow" if model_params["sg"] == 0 else "skipgram"

        # An argument list (no shell) hands every path to fastText verbatim, so
        # locations containing spaces or shell metacharacters work too.
        cmd = [
            fasttext_cmd, model_type,
            "-input", inp_fname,
            "-output", out_base_fname,
            "-dim", str(model_params["size"]),
            "-seed", str(model_params["seed"]),
        ]
        # check=True makes a failed training run fail right here instead of
        # surfacing later as an error while loading a missing model file.
        subprocess.run(cmd, check=True)

    def _check_roundtrip_file_file(self, model_params):
        """Create model1.bin with fastText, load and re-save it with Gensim, compare both files.

        Parameters
        ----------
        model_params : dict
            Parameters forwarded to `_create_and_save_test_model`.
        """
        ft_home = os.environ.get("FT_HOME", None)
        fasttext_cmd = os.path.join(ft_home, "fasttext")

        # fasttext tool creates both *vec and *bin files, so we have to remove both, even though *vec is unused

        with temporary_file("roundtrip_file_to_file1.bin") as fpath1bin, \
                temporary_file("roundtrip_file_to_file2.bin") as fpath2bin, \
                temporary_file("roundtrip_file_to_file1.vec") as fpath1vec:  # noqa:F841

            # fastText appends the extensions itself, so strip the ".bin" suffix.
            fpath1base = fpath1bin[:-4]
            self._create_and_save_test_model(fpath1base, model_params, fasttext_cmd)
            model = gensim.models.fasttext.load_facebook_model(fpath1bin)
            gensim.models.fasttext.save_facebook_model(model, fpath2bin)

            self._compare_fasttext_files(fpath1bin, fpath2bin)
# class SaveFacebookFormatRoundtripFileToFileGensimTest(SaveFacebookFormatRoundtripFileToFileTest):
# """
# This class contains tests that check the following scenario:

def test_roundtrip_file_file_skipgram(self):
    """Run the file-to-file round trip for skip-gram (sg=1); skipped when FT_HOME is unset or empty."""
    ft_home_configured = bool(os.environ.get("FT_HOME", None))
    if ft_home_configured:
        self._check_roundtrip_file_file({"size": 10, "sg": 1, "seed": 42})
    else:
        self.skipTest("FT_HOME env variable not set")
# + create binary fastText file model1.bin using Gensim
# + load file model1.bin to model
# + save model to model2.bin
# + check if files model1.bin and model2.bin are identical
# """

# def test_roundtrip_file_file_skipgram(self):
# model_params = {
# "size": 10,
# "min_count": 1,
# "hs": 1,
# "sg": 1,
# "negative": 0,
# "seed": 42,
# "workers": 1}
# self._check_roundtrip_file_file(model_params)

def test_roundtrip_file_file_cbow(self):
    """Run the file-to-file round trip for CBOW (sg=0); skipped when FT_HOME is unset or empty."""
    ft_home_configured = bool(os.environ.get("FT_HOME", None))
    if ft_home_configured:
        self._check_roundtrip_file_file({"size": 10, "sg": 0, "seed": 42})
    else:
        self.skipTest("FT_HOME env variable not set")
# def test_roundtrip_file_file_cbow(self):
# model_params = {
# "size": 10,
# "min_count": 1,
# "hs": 1,
# "sg": 0,
# "negative": 0,
# "seed": 42,
# "workers": 1}
# self._check_roundtrip_file_file(model_params)


class SaveFacebookFormatReadingTest(SaveFacebookFormatTest):
    """
    This class contains tests that check the following scenario:

    + create fastText model
    + save the model to file model.bin
    + retrieve word vectors from model.bin to stdout using fasttext Facebook utility
    + compare vectors retrieved by Facebook utility with those obtained directly from gensim model

    Requires env. variable FT_HOME to point to location of Facebook fastText binary
    """

    def _parse_wordvectors(self, text):
        """Convert the text printed by ``fasttext print-word-vectors`` to a float32 array.

        Every line of `text` is split on whitespace; the first token is dropped and
        the remaining tokens are parsed as floats. One array row is produced per line.
        """
        rows = [
            [float(token) for token in line.split()[1:]]
            for line in text.splitlines()
        ]
        return np.array(rows, dtype=np.float32)

    def _get_wordvectors_from_fb_fastttext(self, fasttext_cmd, fasttext_fname, words):
        """Query the fastText executable for the vectors of `words`.

        Parameters
        ----------
        fasttext_cmd : str
            Path of the fastText executable.
        fasttext_fname : str
            Model file passed to ``print-word-vectors``.
        words : iterable of str
            Words written to the process' stdin, one per line.

        Returns
        -------
        numpy.ndarray
            The parsed stdout of the process, see `_parse_wordvectors`.

        Raises
        ------
        RuntimeError
            If the fastText process exits with a non-zero status.
        """
        # An argument list (no shell) passes both paths verbatim, so locations
        # containing spaces or shell metacharacters work too.
        process = subprocess.Popen(
            [fasttext_cmd, "print-word-vectors", fasttext_fname],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        words_str = '\n'.join(words)
        out, err = process.communicate(input=words_str.encode("utf-8"))
        if process.returncode != 0:
            # Report the tool's own diagnostics instead of failing later on empty output.
            raise RuntimeError(
                "fasttext print-word-vectors exited with status %d: %s"
                % (process.returncode, err.decode("utf-8", errors="replace"))
            )
        return self._parse_wordvectors(out.decode("utf-8"))

    def _check_load_fasttext_format(self, model_params):
        """Save a Gensim-trained model and compare its vectors with those fastText prints for it."""
        ft_home = os.environ.get("FT_HOME", None)
        fasttext_cmd = os.path.join(ft_home, "fasttext")

        with temporary_file("load_fasttext.bin") as fpath:
            model = self._create_and_save_test_model(fpath, model_params)
            wv = self._get_wordvectors_from_fb_fastttext(fasttext_cmd, fpath, model.wv.index2word)

            for i, w in enumerate(model.wv.index2word):
                diff = calc_max_diff(wv[i, :], model.wv[w])
                # Because fasttext command line prints vectors with limited accuracy
                self.assertLess(diff, 1.0e-4)

    def test_load_fasttext_format_cbow(self):
        """Compare Gensim and fastText vectors for a CBOW model (sg=0); skipped without FT_HOME."""
        if not os.environ.get("FT_HOME", None):
            self.skipTest("FT_HOME env variable not set")
        model_params = {
            "size": 10,
            "min_count": 1,
            "hs": 1,
            "sg": 0,
            "negative": 5,
            "seed": 42,
            "workers": 1}
        self._check_load_fasttext_format(model_params)

    def test_load_fasttext_format_skipgram(self):
        """Compare Gensim and fastText vectors for a skip-gram model (sg=1); skipped without FT_HOME."""
        if not os.environ.get("FT_HOME", None):
            self.skipTest("FT_HOME env variable not set")
        model_params = {
            "size": 10,
            "min_count": 1,
            "hs": 1,
            "sg": 1,
            "negative": 5,
            "seed": 42,
            "workers": 1}
        self._check_load_fasttext_format(model_params)
# class SaveFacebookFormatRoundtripFileToFileFacebookTest(unittest.TestCase):
# """
# This class contains tests that check the following scenario:

# + create binary fastText file model1.bin using facebook_binary
# + load file model1.bin to model
# + save model to model2.bin
# + check if files model1.bin and model2.bin are identical

# Requires env. variable FT_HOME to point to location of Facebook fastText binary
# """

# def _save_test_model(out_base_fname, model_params, fasttext_cmd):
# inp_fname = datapath('lee_background.cor')

# model_type = "cbow" if model_params["sg"] == 0 else "skipgram"
# size = str(model_params["size"])
# seed = str(model_params["seed"])

# cmd = fasttext_cmd + " " + model_type + " -input " + inp_fname + \
# " -output " + out_base_fname + " -dim " + size + " -seed " + seed
# subprocess.run(cmd, shell=True)

# def _check_roundtrip_file_file(self, model_params):
# ft_home = os.environ.get("FT_HOME", None)
# fasttext_cmd = os.path.join(ft_home, "fasttext")

# # fasttext tool creates both *vec and *bin files, so we have to remove both, even though *vec is unused

# with temporary_file("roundtrip_file_to_file1.bin") as fpath1bin, \
# temporary_file("roundtrip_file_to_file2.bin") as fpath2bin, \
# temporary_file("roundtrip_file_to_file1.vec") as fpath1vec: # noqa:F841

# fpath1base = fpath1bin[:-4]
# self._create_and_save_test_model(fpath1base, model_params, fasttext_cmd)
# model = gensim.models.fasttext.load_facebook_model(fpath1bin)
# gensim.models.fasttext.save_facebook_model(model, fpath2bin)

# self._compare_fasttext_files(fpath1bin, fpath2bin)

# @unittest.skipIf(not os.environ.get("FT_HOME", None), "FT_HOME env variable not set, skipping test")
# def test_roundtrip_file_file_skipgram(self):
# model_params = {"size": 10, "sg": 1, "seed": 42}
# self._check_roundtrip_file_file(model_params)

# @unittest.skipIf(not os.environ.get("FT_HOME", None), "FT_HOME env variable not set, skipping test")
# def test_roundtrip_file_file_cbow(self):
# model_params = {"size": 10, "sg": 0, "seed": 42}
# self._check_roundtrip_file_file(model_params)


# def _parse_wordvectors(text):
# def _conv_line_to_array(line):
# return np.array([float(s) for s in line.split()[1:]], dtype=np.float32)

# return np.array([_conv_line_to_array(l) for l in text.splitlines()], dtype=np.float32)


# def _get_wordvectors_from_fb_fastttext(fasttext_cmd, fasttext_fname, words):
# cmd = fasttext_cmd + " print-word-vectors " + fasttext_fname
# process = subprocess.Popen(
# cmd, stdin=subprocess.PIPE,
# stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,
# shell=True)
# words_str = '\n'.join(words)
# out, err = process.communicate(input=words_str.encode("utf-8"))
# return _parse_wordvectors(out.decode("utf-8"))


# class SaveFacebookFormatReadingTest(SaveFacebookFormatTest):
# """
# This class contains tests that check the following scenario:

# + create fastText model
# + save the model to file model.bin
# + retrieve word vectors from model.bin to stdout using fasttext Facebook utility
# + compare vectors retrieved by Facebook utility with those obtained directly from gensim model

# Requires env. variable FT_HOME to point to location of Facebook fastText binary
# """

# def _check_load_fasttext_format(self, model_params):

# ft_home = os.environ.get("FT_HOME", None)
# fasttext_cmd = os.path.join(ft_home, "fasttext")

# with temporary_file("load_fasttext.bin") as fpath:
# model = self._create_and_save_test_model(fpath, model_params)
# wv = _get_wordvectors_from_fb_fastttext(fasttext_cmd, fpath, model.wv.index2word)

# for i, w in enumerate(model.wv.index2word):
# diff = calc_max_diff(wv[i, :], model.wv[w])
# # Because fasttext command line prints vectors with limited accuracy
# self.assertLess(diff, 1.0e-4)

# @unittest.skipIf(not os.environ.get("FT_HOME", None), "FT_HOME env variable not set, skipping test")
# def test_load_fasttext_format_cbow(self):
# model_params = {
# "size": 10,
# "min_count": 1,
# "hs": 1,
# "sg": 0,
# "negative": 5,
# "seed": 42,
# "workers": 1}
# self._check_load_fasttext_format(model_params)

# @unittest.skipIf(not os.environ.get("FT_HOME", None), "FT_HOME env variable not set, skipping test")
# def test_load_fasttext_format_skipgram(self):
# model_params = {
# "size": 10,
# "min_count": 1,
# "hs": 1,
# "sg": 1,
# "negative": 5,
# "seed": 42,
# "workers": 1}
# self._check_load_fasttext_format(model_params)


if __name__ == '__main__':
Expand Down

0 comments on commit 8e8ca1e

Please sign in to comment.