-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix phrases load, for backward compatibility #1758
Changes from 4 commits
523ab11
aecc95d
6297bf4
f4fce12
198fdf7
e023ac7
cc623ca
63fe8c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,17 +8,29 @@ | |
""" | ||
|
||
|
||
import contextlib | ||
import logging | ||
import unittest | ||
import os | ||
import sys | ||
import shutil | ||
import tempfile | ||
import unittest | ||
|
||
import six | ||
|
||
from gensim import utils | ||
from gensim.models.phrases import SentenceAnalyzer, Phrases, Phraser, pseudocorpus | ||
from gensim.models.phrases import SentenceAnalyzer, Phrases, Phraser | ||
from gensim.models.phrases import pseudocorpus, original_scorer | ||
from gensim.test.utils import common_texts | ||
|
||
if sys.version_info[0] >= 3: | ||
unicode = str | ||
|
||
@contextlib.contextmanager | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks useful, please move it to |
||
def temporary_file(name): | ||
# note : when dropping python2.7 support, we can use tempfile.TemporaryDirectory | ||
tmp = tempfile.mkdtemp() | ||
try: | ||
yield os.path.join(tmp, name) | ||
finally: | ||
shutil.rmtree(tmp, ignore_errors=True) | ||
|
||
|
||
class TestUtils(unittest.TestCase): | ||
|
@@ -230,7 +242,7 @@ def testEncoding(self): | |
self.assertEqual(self.bigram_unicode[self.sentences[1]], expected) | ||
|
||
transformed = ' '.join(self.bigram_utf8[self.sentences[1]]) | ||
self.assertTrue(isinstance(transformed, unicode)) | ||
self.assertTrue(isinstance(transformed, six.text_type)) | ||
|
||
|
||
# scorer for testCustomScorer | ||
|
@@ -335,15 +347,15 @@ def testPruning(self): | |
# endclass TestPhrasesModel | ||
|
||
|
||
class TestPhrasesScoringPersistence(PhrasesData, unittest.TestCase): | ||
class TestPhrasesPersistence(PhrasesData, unittest.TestCase): | ||
|
||
def testSaveLoadCustomScorer(self): | ||
""" saving and loading a Phrases object with a custom scorer """ | ||
|
||
try: | ||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phrases(self.sentences, min_count=1, threshold=.001, scoring=dumb_scorer) | ||
bigram.save("test_phrases_testSaveLoadCustomScorer_temp_save.pkl") | ||
bigram_loaded = Phrases.load("test_phrases_testSaveLoadCustomScorer_temp_save.pkl") | ||
bigram.save(fpath) | ||
bigram_loaded = Phrases.load(fpath) | ||
seen_scores = [] | ||
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']] | ||
for phrase, score in bigram_loaded.export_phrases(test_sentences): | ||
|
@@ -352,17 +364,13 @@ def testSaveLoadCustomScorer(self): | |
assert all(seen_scores) # all scores 1 | ||
assert len(seen_scores) == 3 # 'graph minors' and 'survey human' and 'interface system' | ||
|
||
finally: | ||
if os.path.exists("test_phrases_testSaveLoadCustomScorer_temp_save.pkl"): | ||
os.remove("test_phrases_testSaveLoadCustomScorer_temp_save.pkl") | ||
|
||
def testSaveLoad(self): | ||
""" Saving and loading a Phrases object.""" | ||
|
||
try: | ||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phrases(self.sentences, min_count=1, threshold=1) | ||
bigram.save("test_phrases_testSaveLoad_temp_save.pkl") | ||
bigram_loaded = Phrases.load("test_phrases_testSaveLoad_temp_save.pkl") | ||
bigram.save(fpath) | ||
bigram_loaded = Phrases.load(fpath) | ||
seen_scores = set() | ||
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']] | ||
for phrase, score in bigram_loaded.export_phrases(test_sentences): | ||
|
@@ -373,19 +381,15 @@ def testSaveLoad(self): | |
3.444 # score for human interface | ||
]) | ||
|
||
finally: | ||
if os.path.exists("test_phrases_testSaveLoad_temp_save.pkl"): | ||
os.remove("test_phrases_testSaveLoad_temp_save.pkl") | ||
|
||
def testSaveLoadStringScoring(self): | ||
""" Saving and loading a Phrases object with a string scoring parameter. | ||
This should ensure backwards compatibility with the previous version of Phrases""" | ||
|
||
try: | ||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phrases(self.sentences, min_count=1, threshold=1) | ||
bigram.scoring = "default" | ||
bigram.save("test_phrases_testSaveLoadStringScoring_temp_save.pkl") | ||
bigram_loaded = Phrases.load("test_phrases_testSaveLoadStringScoring_temp_save.pkl") | ||
bigram.save(fpath) | ||
bigram_loaded = Phrases.load(fpath) | ||
seen_scores = set() | ||
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']] | ||
for phrase, score in bigram_loaded.export_phrases(test_sentences): | ||
|
@@ -396,19 +400,15 @@ def testSaveLoadStringScoring(self): | |
3.444 # score for human interface | ||
]) | ||
|
||
finally: | ||
if os.path.exists("test_phrases_testSaveLoadStringScoring_temp_save.pkl"): | ||
os.remove("test_phrases_testSaveLoadStringScoring_temp_save.pkl") | ||
|
||
def testSaveLoadNoScoring(self): | ||
""" Saving and loading a Phrases object with no scoring parameter. | ||
This should ensure backwards compatibility with old versions of Phrases""" | ||
|
||
try: | ||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phrases(self.sentences, min_count=1, threshold=1) | ||
del(bigram.scoring) | ||
bigram.save("test_phrases_testSaveLoadNoScoring_temp_save.pkl") | ||
bigram_loaded = Phrases.load("test_phrases_testSaveLoadNoScoring_temp_save.pkl") | ||
bigram.save(fpath) | ||
bigram_loaded = Phrases.load(fpath) | ||
seen_scores = set() | ||
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']] | ||
for phrase, score in bigram_loaded.export_phrases(test_sentences): | ||
|
@@ -419,9 +419,78 @@ def testSaveLoadNoScoring(self): | |
3.444 # score for human interface | ||
]) | ||
|
||
finally: | ||
if os.path.exists("test_phrases_testSaveLoadNoScoring_temp_save.pkl"): | ||
os.remove("test_phrases_testSaveLoadNoScoring_temp_save.pkl") | ||
def testSaveLoadNoCommonTerms(self): | ||
""" Saving and loading a Phrases objects without common_terms | ||
This should ensure backwards compatibility with old versions of Phrases""" | ||
|
||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phrases(self.sentences, min_count=1, threshold=1) | ||
del(bigram.common_terms) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. better solution - checkout to the previous version, tran+save model and try to load old model here (in your current implementation, we possibly skip other bugs, that can be hidden due to the current bug) FYI - |
||
bigram.save(fpath) | ||
bigram_loaded = Phrases.load(fpath) | ||
self.assertEqual(bigram_loaded.common_terms, frozenset()) | ||
# can make a phraser, cf #1751 | ||
phraser = Phraser(bigram_loaded) # does not raise | ||
phraser["some terms"] # does not raise | ||
|
||
|
||
class TestPhraserPersistence(PhrasesData, unittest.TestCase): | ||
|
||
def testSaveLoadCustomScorer(self): | ||
"""Saving and loading a Phraser object with a custom scorer """ | ||
|
||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phraser( | ||
Phrases(self.sentences, min_count=1, threshold=.001, scoring=dumb_scorer)) | ||
bigram.save(fpath) | ||
bigram_loaded = Phraser.load(fpath) | ||
# we do not much with scoring, just verify its the one expected | ||
self.assertEqual(bigram_loaded.scoring, dumb_scorer) | ||
|
||
def testSaveLoad(self): | ||
""" Saving and loading a Phraser object.""" | ||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phraser(Phrases(self.sentences, min_count=1, threshold=1)) | ||
bigram.save(fpath) | ||
bigram_loaded = Phraser.load(fpath) | ||
self.assertEqual( | ||
bigram_loaded[['graph', 'minors', 'survey', 'human', 'interface', 'system']], | ||
['graph_minors', 'survey', 'human_interface', 'system']) | ||
|
||
def testSaveLoadStringScoring(self): | ||
""" Saving and loading a Phraser object with a string scoring parameter. | ||
This should ensure backwards compatibility with the previous version of Phraser""" | ||
|
||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phraser(Phrases(self.sentences, min_count=1, threshold=1)) | ||
bigram.scoring = "default" | ||
bigram.save(fpath) | ||
bigram_loaded = Phraser.load(fpath) | ||
# we do not much with scoring, just verify its the one expected | ||
self.assertEqual(bigram_loaded.scoring, original_scorer) | ||
|
||
def testSaveLoadNoScoring(self): | ||
""" Saving and loading a Phraser object with no scoring parameter. | ||
This should ensure backwards compatibility with old versions of Phraser""" | ||
|
||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phraser(Phrases(self.sentences, min_count=1, threshold=1)) | ||
del(bigram.scoring) | ||
bigram.save(fpath) | ||
bigram_loaded = Phraser.load(fpath) | ||
# we do not much with scoring, just verify its the one expected | ||
self.assertEqual(bigram_loaded.scoring, original_scorer) | ||
|
||
def testSaveLoadNoCommonTerms(self): | ||
""" Saving and loading a Phraser objects without common_terms | ||
This should ensure backwards compatibility with old versions of Phraser""" | ||
|
||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phraser(Phrases(self.sentences, min_count=1, threshold=1)) | ||
del(bigram.common_terms) | ||
bigram.save(fpath) | ||
bigram_loaded = Phraser.load(fpath) | ||
self.assertEqual(bigram_loaded.common_terms, frozenset()) | ||
|
||
|
||
class TestPhraserModel(PhrasesData, PhrasesCommon, unittest.TestCase): | ||
|
@@ -486,7 +555,7 @@ def testEncoding(self): | |
self.assertEqual(self.bigram_unicode[self.sentences[1]], expected) | ||
|
||
transformed = ' '.join(self.bigram_utf8[self.sentences[1]]) | ||
self.assertTrue(isinstance(transformed, unicode)) | ||
self.assertTrue(isinstance(transformed, six.text_type)) | ||
|
||
def testMultipleBigramsSingleEntry(self): | ||
""" a single entry should produce multiple bigrams. """ | ||
|
@@ -593,7 +662,7 @@ def testEncoding(self): | |
self.assertEqual(self.bigram_unicode[self.sentences[1]], expected) | ||
|
||
transformed = ' '.join(self.bigram_utf8[self.sentences[1]]) | ||
self.assertTrue(isinstance(transformed, unicode)) | ||
self.assertTrue(isinstance(transformed, six.text_type)) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpick: we use 140 char limits for code, no need to split this call into several lines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
argh travis failed, it was 120 not 140 @menshikh-iv, you joker ! ;-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@alexgarel oh, sorry, my mistake (120 instead of 80, not 140)