-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix phrases load, for backward compatibility #1758
Changes from 3 commits
523ab11
aecc95d
6297bf4
f4fce12
198fdf7
e023ac7
cc623ca
63fe8c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -178,7 +178,61 @@ def analyze_sentence(self, sentence, threshold, common_terms, scorer): | |
yield (word, None) | ||
|
||
|
||
class Phrases(SentenceAnalyzer, interfaces.TransformationABC): | ||
class PhrasesTransformation(interfaces.TransformationABC): | ||
|
||
@classmethod | ||
def load(cls, *args, **kwargs): | ||
""" | ||
Load a previously saved Phrases class. Handles backwards compatibility from | ||
older Phrases versions which did not support pluggable scoring functions. | ||
Otherwise, relies on utils.load | ||
""" | ||
|
||
# for python 2 and 3 compatibility. basestring is used to check if model.scoring is a string | ||
try: | ||
basestring | ||
except NameError: | ||
basestring = str | ||
|
||
model = super(PhrasesTransformation, cls).load(*args, **kwargs) | ||
# update older models | ||
# if no scoring parameter, use default scoring | ||
if not hasattr(model, 'scoring'): | ||
logger.info('older version of %s loaded without scoring function', cls.__name__) | ||
logger.info('setting pluggable scoring method to original_scorer for compatibility') | ||
model.scoring = original_scorer | ||
# if there is a scoring parameter, and it's a text value, load the proper scoring function | ||
if hasattr(model, 'scoring'): | ||
if isinstance(model.scoring, basestring): | ||
if model.scoring == 'default': | ||
logger.info( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nitpick: we use a 140-character line limit for code, so there is no need to split this call across several lines. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Argh, Travis failed: the limit was 120, not 140. @menshikh-iv, you joker! ;-) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @alexgarel Oh, sorry, my mistake (the limit is 120 instead of 80, not 140). |
||
'older version of %s loaded with "default" scoring parameter', | ||
cls.__name__) | ||
logger.info( | ||
'setting scoring method to original_scorer pluggable scoring method ' + | ||
'for compatibility') | ||
model.scoring = original_scorer | ||
elif model.scoring == 'npmi': | ||
logger.info( | ||
'older version of %s loaded with "npmi" scoring parameter', | ||
cls.__name__) | ||
logger.info( | ||
'setting scoring method to npmi_scorer pluggable scoring method ' + | ||
'for compatibility') | ||
model.scoring = npmi_scorer | ||
else: | ||
raise ValueError( | ||
'failed to load %s model with unknown scoring setting %s' % | ||
(cls.__name__, model.scoring)) | ||
# if there is non common_terms attribute, inizialize | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. typo - `inizialize` should be `initialize` |
||
if not hasattr(model, "common_terms"): | ||
logger.info('older version of %s loaded without common_terms attribute', cls.__name__) | ||
logger.info('setting common_terms to empty set') | ||
model.common_terms = frozenset() | ||
return model | ||
|
||
|
||
class Phrases(SentenceAnalyzer, PhrasesTransformation): | ||
""" | ||
Detect phrases, based on collected collocation counts. Adjacent words that appear | ||
together more frequently than expected are joined together with the `_` character. | ||
|
@@ -461,41 +515,6 @@ def __getitem__(self, sentence): | |
|
||
return [utils.to_unicode(w) for w in new_s] | ||
|
||
@classmethod | ||
def load(cls, *args, **kwargs): | ||
""" | ||
Load a previously saved Phrases class. Handles backwards compatibility from | ||
older Phrases versions which did not support pluggable scoring functions. Otherwise, relies on utils.load | ||
""" | ||
|
||
# for python 2 and 3 compatibility. basestring is used to check if model.scoring is a string | ||
try: | ||
basestring | ||
except NameError: | ||
basestring = str | ||
|
||
model = super(Phrases, cls).load(*args, **kwargs) | ||
# update older models | ||
# if no scoring parameter, use default scoring | ||
if not hasattr(model, 'scoring'): | ||
logger.info('older version of Phrases loaded without scoring function') | ||
logger.info('setting pluggable scoring method to original_scorer for compatibility') | ||
model.scoring = original_scorer | ||
# if there is a scoring parameter, and it's a text value, load the proper scoring function | ||
if hasattr(model, 'scoring'): | ||
if isinstance(model.scoring, basestring): | ||
if model.scoring == 'default': | ||
logger.info('older version of Phrases loaded with "default" scoring parameter') | ||
logger.info('setting scoring method to original_scorer pluggable scoring method for compatibility') | ||
model.scoring = original_scorer | ||
elif model.scoring == 'npmi': | ||
logger.info('older version of Phrases loaded with "npmi" scoring parameter') | ||
logger.info('setting scoring method to npmi_scorer pluggable scoring method for compatibility') | ||
model.scoring = npmi_scorer | ||
else: | ||
raise ValueError('failed to load Phrases model with unknown scoring setting %s' % (model.scoring)) | ||
return model | ||
|
||
|
||
# these two built-in scoring methods don't cast everything to float because the casting is done in the call | ||
# to the scoring method in __getitem__ and export_phrases. | ||
|
@@ -530,7 +549,7 @@ def pseudocorpus(source_vocab, sep, common_terms=frozenset()): | |
yield components | ||
|
||
|
||
class Phraser(SentenceAnalyzer, interfaces.TransformationABC): | ||
class Phraser(SentenceAnalyzer, PhrasesTransformation): | ||
""" | ||
Minimal state & functionality to apply results of a Phrases model to tokens. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,19 +8,33 @@ | |
""" | ||
|
||
|
||
import contextlib | ||
import logging | ||
import unittest | ||
import os | ||
import sys | ||
import shutil | ||
import tempfile | ||
import unittest | ||
|
||
from gensim import utils | ||
from gensim.models.phrases import SentenceAnalyzer, Phrases, Phraser, pseudocorpus | ||
from gensim.models.phrases import SentenceAnalyzer, Phrases, Phraser | ||
from gensim.models.phrases import pseudocorpus, original_scorer | ||
from gensim.test.utils import common_texts | ||
|
||
if sys.version_info[0] >= 3: | ||
unicode = str | ||
|
||
|
||
@contextlib.contextmanager | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks useful; please move it to |
||
def temporary_file(name): | ||
# note : when dropping python2.7 support, we can use tempfile.TemporaryDirectory | ||
tmp = tempfile.mkdtemp() | ||
try: | ||
yield os.path.join(tmp, name) | ||
finally: | ||
shutil.rmtree(tmp, ignore_errors=True) | ||
|
||
|
||
class TestUtils(unittest.TestCase): | ||
|
||
def test_pseudocorpus_no_common_terms(self): | ||
|
@@ -335,15 +349,15 @@ def testPruning(self): | |
# endclass TestPhrasesModel | ||
|
||
|
||
class TestPhrasesScoringPersistence(PhrasesData, unittest.TestCase): | ||
class TestPhrasesPersistence(PhrasesData, unittest.TestCase): | ||
|
||
def testSaveLoadCustomScorer(self): | ||
""" saving and loading a Phrases object with a custom scorer """ | ||
|
||
try: | ||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phrases(self.sentences, min_count=1, threshold=.001, scoring=dumb_scorer) | ||
bigram.save("test_phrases_testSaveLoadCustomScorer_temp_save.pkl") | ||
bigram_loaded = Phrases.load("test_phrases_testSaveLoadCustomScorer_temp_save.pkl") | ||
bigram.save(fpath) | ||
bigram_loaded = Phrases.load(fpath) | ||
seen_scores = [] | ||
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']] | ||
for phrase, score in bigram_loaded.export_phrases(test_sentences): | ||
|
@@ -352,17 +366,13 @@ def testSaveLoadCustomScorer(self): | |
assert all(seen_scores) # all scores 1 | ||
assert len(seen_scores) == 3 # 'graph minors' and 'survey human' and 'interface system' | ||
|
||
finally: | ||
if os.path.exists("test_phrases_testSaveLoadCustomScorer_temp_save.pkl"): | ||
os.remove("test_phrases_testSaveLoadCustomScorer_temp_save.pkl") | ||
|
||
def testSaveLoad(self): | ||
""" Saving and loading a Phrases object.""" | ||
|
||
try: | ||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phrases(self.sentences, min_count=1, threshold=1) | ||
bigram.save("test_phrases_testSaveLoad_temp_save.pkl") | ||
bigram_loaded = Phrases.load("test_phrases_testSaveLoad_temp_save.pkl") | ||
bigram.save(fpath) | ||
bigram_loaded = Phrases.load(fpath) | ||
seen_scores = set() | ||
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']] | ||
for phrase, score in bigram_loaded.export_phrases(test_sentences): | ||
|
@@ -373,19 +383,15 @@ def testSaveLoad(self): | |
3.444 # score for human interface | ||
]) | ||
|
||
finally: | ||
if os.path.exists("test_phrases_testSaveLoad_temp_save.pkl"): | ||
os.remove("test_phrases_testSaveLoad_temp_save.pkl") | ||
|
||
def testSaveLoadStringScoring(self): | ||
""" Saving and loading a Phrases object with a string scoring parameter. | ||
This should ensure backwards compatibility with the previous version of Phrases""" | ||
|
||
try: | ||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phrases(self.sentences, min_count=1, threshold=1) | ||
bigram.scoring = "default" | ||
bigram.save("test_phrases_testSaveLoadStringScoring_temp_save.pkl") | ||
bigram_loaded = Phrases.load("test_phrases_testSaveLoadStringScoring_temp_save.pkl") | ||
bigram.save(fpath) | ||
bigram_loaded = Phrases.load(fpath) | ||
seen_scores = set() | ||
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']] | ||
for phrase, score in bigram_loaded.export_phrases(test_sentences): | ||
|
@@ -396,19 +402,15 @@ def testSaveLoadStringScoring(self): | |
3.444 # score for human interface | ||
]) | ||
|
||
finally: | ||
if os.path.exists("test_phrases_testSaveLoadStringScoring_temp_save.pkl"): | ||
os.remove("test_phrases_testSaveLoadStringScoring_temp_save.pkl") | ||
|
||
def testSaveLoadNoScoring(self): | ||
""" Saving and loading a Phrases object with no scoring parameter. | ||
This should ensure backwards compatibility with old versions of Phrases""" | ||
|
||
try: | ||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phrases(self.sentences, min_count=1, threshold=1) | ||
del(bigram.scoring) | ||
bigram.save("test_phrases_testSaveLoadNoScoring_temp_save.pkl") | ||
bigram_loaded = Phrases.load("test_phrases_testSaveLoadNoScoring_temp_save.pkl") | ||
bigram.save(fpath) | ||
bigram_loaded = Phrases.load(fpath) | ||
seen_scores = set() | ||
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']] | ||
for phrase, score in bigram_loaded.export_phrases(test_sentences): | ||
|
@@ -419,9 +421,78 @@ def testSaveLoadNoScoring(self): | |
3.444 # score for human interface | ||
]) | ||
|
||
finally: | ||
if os.path.exists("test_phrases_testSaveLoadNoScoring_temp_save.pkl"): | ||
os.remove("test_phrases_testSaveLoadNoScoring_temp_save.pkl") | ||
def testSaveLoadNoCommonTerms(self): | ||
""" Saving and loading a Phrases objects without common_terms | ||
This should ensure backwards compatibility with old versions of Phrases""" | ||
|
||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phrases(self.sentences, min_count=1, threshold=1) | ||
del(bigram.common_terms) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A better solution: check out the previous version, train and save a model there, and try to load that old model here (with your current implementation, we may miss other bugs that are hidden by the current bug). FYI - |
||
bigram.save(fpath) | ||
bigram_loaded = Phrases.load(fpath) | ||
self.assertEqual(bigram_loaded.common_terms, frozenset()) | ||
# can make a phraser, cf #1751 | ||
phraser = Phraser(bigram_loaded) # does not raise | ||
phraser["some terms"] # does not raise | ||
|
||
|
||
class TestPhraserPersistence(PhrasesData, unittest.TestCase): | ||
|
||
def testSaveLoadCustomScorer(self): | ||
"""Saving and loading a Phraser object with a custom scorer """ | ||
|
||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phraser( | ||
Phrases(self.sentences, min_count=1, threshold=.001, scoring=dumb_scorer)) | ||
bigram.save(fpath) | ||
bigram_loaded = Phraser.load(fpath) | ||
# we do not much with scoring, just verify its the one expected | ||
self.assertEqual(bigram_loaded.scoring, dumb_scorer) | ||
|
||
def testSaveLoad(self): | ||
""" Saving and loading a Phraser object.""" | ||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phraser(Phrases(self.sentences, min_count=1, threshold=1)) | ||
bigram.save(fpath) | ||
bigram_loaded = Phraser.load(fpath) | ||
self.assertEqual( | ||
bigram_loaded[['graph', 'minors', 'survey', 'human', 'interface', 'system']], | ||
['graph_minors', 'survey', 'human_interface', 'system']) | ||
|
||
def testSaveLoadStringScoring(self): | ||
""" Saving and loading a Phraser object with a string scoring parameter. | ||
This should ensure backwards compatibility with the previous version of Phraser""" | ||
|
||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phraser(Phrases(self.sentences, min_count=1, threshold=1)) | ||
bigram.scoring = "default" | ||
bigram.save(fpath) | ||
bigram_loaded = Phraser.load(fpath) | ||
# we do not much with scoring, just verify its the one expected | ||
self.assertEqual(bigram_loaded.scoring, original_scorer) | ||
|
||
def testSaveLoadNoScoring(self): | ||
""" Saving and loading a Phraser object with no scoring parameter. | ||
This should ensure backwards compatibility with old versions of Phraser""" | ||
|
||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phraser(Phrases(self.sentences, min_count=1, threshold=1)) | ||
del(bigram.scoring) | ||
bigram.save(fpath) | ||
bigram_loaded = Phraser.load(fpath) | ||
# we do not much with scoring, just verify its the one expected | ||
self.assertEqual(bigram_loaded.scoring, original_scorer) | ||
|
||
def testSaveLoadNoCommonTerms(self): | ||
""" Saving and loading a Phraser objects without common_terms | ||
This should ensure backwards compatibility with old versions of Phraser""" | ||
|
||
with temporary_file("test.pkl") as fpath: | ||
bigram = Phraser(Phrases(self.sentences, min_count=1, threshold=1)) | ||
del(bigram.common_terms) | ||
bigram.save(fpath) | ||
bigram_loaded = Phraser.load(fpath) | ||
self.assertEqual(bigram_loaded.common_terms, frozenset()) | ||
|
||
|
||
class TestPhraserModel(PhrasesData, PhrasesCommon, unittest.TestCase): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gensim uses the `six` library to bridge py2/py3.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, note that it was from a previous commit…