Skip to content

Commit

Permalink
Clean-up FT_HOME behaviour (piskvorky#2611)
Browse files Browse the repository at this point in the history
  • Loading branch information
lopusz committed Jan 8, 2020
1 parent 6bebcef commit 1eaea78
Showing 1 changed file with 17 additions and 32 deletions.
49 changes: 17 additions & 32 deletions gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@

MAX_WORDVEC_COMPONENT_DIFFERENCE = 1.0e-10

FT_HOME = os.environ.get("FT_HOME")
FT_CMD = os.path.join(FT_HOME, "fasttext") if FT_HOME else None


class LeeCorpus(object):
def __iter__(self):
Expand All @@ -59,8 +62,6 @@ def __iter__(self):
class TestFastTextModel(unittest.TestCase):

def setUp(self):
ft_home = os.environ.get('FT_HOME', None)
self.ft_path = os.path.join(ft_home, 'fasttext') if ft_home else None
self.test_model_file = datapath('lee_fasttext.bin')
self.test_model = gensim.models.fasttext.load_facebook_model(self.test_model_file)
self.test_new_model_file = datapath('lee_fasttext_new.bin')
Expand Down Expand Up @@ -814,13 +815,10 @@ def compare_with_wrapper(self, model_gensim, model_wrapper):
# this limit can be increased when using Cython code
self.assertGreaterEqual(overlap_count, 2)

@unittest.skipIf(not FT_HOME, "FT_HOME env variable not set, skipping test")
def test_cbow_hs_against_wrapper(self):
if self.ft_path is None:
logger.info("FT_HOME env variable not set, skipping test")
return

tmpf = get_tmpfile('gensim_fasttext.tst')
model_wrapper = FT_wrapper.train(ft_path=self.ft_path, corpus_file=datapath('lee_background.cor'),
model_wrapper = FT_wrapper.train(ft_path=FT_CMD, corpus_file=datapath('lee_background.cor'),
output_file=tmpf, model='cbow', size=50, alpha=0.05, window=5, min_count=5,
word_ngrams=1,
loss='hs', sample=1e-3, negative=0, iter=5, min_n=3, max_n=6, sorted_vocab=1,
Expand All @@ -837,13 +835,11 @@ def test_cbow_hs_against_wrapper(self):
self.assertFalse((orig0 == model_gensim.wv.vectors[0]).all()) # vector should vary after training
self.compare_with_wrapper(model_gensim, model_wrapper)

@unittest.skipIf(not FT_HOME, "FT_HOME env variable not set, skipping test")
def test_sg_hs_against_wrapper(self):
if self.ft_path is None:
logger.info("FT_HOME env variable not set, skipping test")
return

tmpf = get_tmpfile('gensim_fasttext.tst')
model_wrapper = FT_wrapper.train(ft_path=self.ft_path, corpus_file=datapath('lee_background.cor'),
model_wrapper = FT_wrapper.train(ft_path=FT_CMD, corpus_file=datapath('lee_background.cor'),
output_file=tmpf, model='skipgram', size=50, alpha=0.025, window=5,
min_count=5, word_ngrams=1,
loss='hs', sample=1e-3, negative=0, iter=5, min_n=3, max_n=6, sorted_vocab=1,
Expand Down Expand Up @@ -1322,7 +1318,7 @@ def calc_max_diff(v1, v2):
class SaveFacebookFormatModelTest(unittest.TestCase):

def _check_roundtrip(self, sg):
model_params = {
model_params = {
"sg": sg,
"size": 10,
"min_count": 1,
Expand Down Expand Up @@ -1409,18 +1405,19 @@ def test_roundtrip_file_file_cbow(self):
self._check_roundtrip_file_file(sg=0)


def _save_test_model(out_base_fname, model_params, fasttext_cmd):
def _save_test_model(out_base_fname, model_params):
inp_fname = datapath('lee_background.cor')

model_type = "cbow" if model_params["sg"] == 0 else "skipgram"
size = str(model_params["size"])
seed = str(model_params["seed"])

cmd = fasttext_cmd + " " + model_type + " -input " + inp_fname + \
" -output " + out_base_fname + " -dim " + size + " -seed " + seed
cmd = FT_CMD + " " + model_type + " -input " + inp_fname + \
" -output " + out_base_fname + " -dim " + size + " -seed " + seed
subprocess.run(cmd, shell=True)


@unittest.skipIf(not FT_HOME, "FT_HOME env variable not set, skipping test")
class SaveFacebookFormatFileFastTextTest(unittest.TestCase):
"""
This class containts tests that check the following scenario:
Expand All @@ -1429,14 +1426,10 @@ class SaveFacebookFormatFileFastTextTest(unittest.TestCase):
+ load file model1.bin to variable `model`
+ save `model` to model2.bin using gensim
+ check if files model1.bin and model2.bin are byte-identical
Requires env. variable FT_HOME to point to location of Facebook fastText binary
"""

def _check_roundtrip_file_file(self, sg):
model_params = {"size": 10, "sg": sg, "seed": 42}
ft_home = os.environ.get("FT_HOME", None)
fasttext_cmd = os.path.join(ft_home, "fasttext")

# fasttext tool creates both *vec and *bin files, so we have to remove both, even thought *vec is unused

Expand All @@ -1445,16 +1438,14 @@ def _check_roundtrip_file_file(self, sg):
temporary_file("roundtrip_file_to_file1.vec") as fpath1vec: # noqa:F841

fpath1base = fpath1bin[:-4]
_save_test_model(fpath1base, model_params, fasttext_cmd)
_save_test_model(fpath1base, model_params)
model = gensim.models.fasttext.load_facebook_model(fpath1bin)
gensim.models.fasttext.save_facebook_model(model, fpath2bin)
self.assertEqual(_read_binary_file(fpath1bin), _read_binary_file(fpath2bin))

@unittest.skipIf(not os.environ.get("FT_HOME", None), "FT_HOME env variable not set, skipping test")
def test_roundtrip_file_file_skipgram(self):
self._check_roundtrip_file_file(sg=1)

@unittest.skipIf(not os.environ.get("FT_HOME", None), "FT_HOME env variable not set, skipping test")
def test_roundtrip_file_file_cbow(self):
self._check_roundtrip_file_file(sg=0)

Expand All @@ -1466,8 +1457,8 @@ def _conv_line_to_array(line):
return np.array([_conv_line_to_array(l) for l in text.splitlines()], dtype=np.float32)


def _read_wordvectors_using_fasttext(fasttext_cmd, fasttext_fname, words):
cmd = fasttext_cmd + " print-word-vectors " + fasttext_fname
def _read_wordvectors_using_fasttext(fasttext_fname, words):
cmd = FT_CMD + " print-word-vectors " + fasttext_fname
process = subprocess.Popen(
cmd, stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
Expand All @@ -1478,6 +1469,7 @@ def _read_wordvectors_using_fasttext(fasttext_cmd, fasttext_fname, words):
return _parse_wordvectors(out.decode("utf-8"))


@unittest.skipIf(not os.environ.get("FT_HOME", None), "FT_HOME env variable not set, skipping test")
class SaveFacebookFormatReadingTest(unittest.TestCase):
"""
This class containts tests that check the following scenario:
Expand All @@ -1486,8 +1478,6 @@ class SaveFacebookFormatReadingTest(unittest.TestCase):
+ save file to model.bin
+ retrieve word vectors from model.bin using fasttext Facebook utility
+ compare vectors retrieved by Facebook utility with those obtained directly from gensim model
Requires env. variable FT_HOME to point to location of Facebook fastText binary
"""

def _check_load_fasttext_format(self, sg):
Expand All @@ -1500,23 +1490,18 @@ def _check_load_fasttext_format(self, sg):
"seed": 42,
"workers": 1}

ft_home = os.environ.get("FT_HOME", None)
fasttext_cmd = os.path.join(ft_home, "fasttext")

with temporary_file("load_fasttext.bin") as fpath:
model = _create_and_save_test_model(fpath, model_params)
wv = _read_wordvectors_using_fasttext(fasttext_cmd, fpath, model.wv.index2word)
wv = _read_wordvectors_using_fasttext(fpath, model.wv.index2word)

for i, w in enumerate(model.wv.index2word):
diff = calc_max_diff(wv[i, :], model.wv[w])
# Because fasttext command line prints vectors with limited accuracy
self.assertLess(diff, 1.0e-4)

@unittest.skipIf(not os.environ.get("FT_HOME", None), "FT_HOME env variable not set, skipping test")
def test_load_fasttext_format_cbow(self):
self._check_load_fasttext_format(sg=0)

@unittest.skipIf(not os.environ.get("FT_HOME", None), "FT_HOME env variable not set, skipping test")
def test_load_fasttext_format_skipgram(self):
self._check_load_fasttext_format(sg=1)

Expand Down

0 comments on commit 1eaea78

Please sign in to comment.