Skip to content

Commit

Permalink
Fix backward compatibility problem in Phrases.load. Fix #1751 (#1758)
Browse files Browse the repository at this point in the history
* backward compatibility for Phrases models without common_terms

* Phraser also needs compatible load for versions without scoring or common_terms

* minor: simplify persistence tests in test_phrases by using a context manager for temporary file management

* using six for python compatibility in phrases

* better tests for phrases load backward compatibility (this also fixes a bug in loading Phrases models saved before scoring was introduced). Also moving the temporary_file context manager into gensim.test.utils

* pep8 fix

* fix imports, reuse datapath

* remove unused import
  • Loading branch information
alexgarel authored and menshikh-iv committed Dec 6, 2017
1 parent 09a16d1 commit a7120d7
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 98 deletions.
91 changes: 54 additions & 37 deletions gensim/models/phrases.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,46 @@ def analyze_sentence(self, sentence, threshold, common_terms, scorer):
yield (word, None)


class Phrases(SentenceAnalyzer, interfaces.TransformationABC):
class PhrasesTransformation(interfaces.TransformationABC):
    """Shared base for Phrases/Phraser providing a backward-compatible `load`."""

    @classmethod
    def load(cls, *args, **kwargs):
        """
        Load a previously saved Phrases/Phraser object and upgrade it in place if it was
        written by an older version.

        All arguments are forwarded unchanged to the parent class `load`. After loading:

        * a missing `scoring` attribute is set to `original_scorer`;
        * a string `scoring` of "default" or "npmi" is replaced by `original_scorer` or
          `npmi_scorer` respectively; any other string raises `ValueError`;
        * a missing `common_terms` attribute is set to an empty `frozenset`.

        Returns the (possibly upgraded) loaded object.
        """
        loaded = super(PhrasesTransformation, cls).load(*args, **kwargs)

        if not hasattr(loaded, 'scoring'):
            # saved before the scoring attribute existed: fall back to the default scorer
            logger.info('older version of %s loaded without scoring function', cls.__name__)
            logger.info('setting pluggable scoring method to original_scorer for compatibility')
            loaded.scoring = original_scorer
        elif isinstance(loaded.scoring, six.string_types):
            # saved when scoring was stored by name: swap the name for the matching function
            if loaded.scoring == 'default':
                logger.info('older version of %s loaded with "default" scoring parameter', cls.__name__)
                logger.info('setting scoring method to original_scorer pluggable scoring method for compatibility')
                loaded.scoring = original_scorer
            elif loaded.scoring == 'npmi':
                logger.info('older version of %s loaded with "npmi" scoring parameter', cls.__name__)
                logger.info('setting scoring method to npmi_scorer pluggable scoring method for compatibility')
                loaded.scoring = npmi_scorer
            else:
                raise ValueError(
                    'failed to load %s model with unknown scoring setting %s' % (cls.__name__, loaded.scoring))

        if not hasattr(loaded, "common_terms"):
            # saved before the common_terms attribute existed: start with no common terms
            logger.info('older version of %s loaded without common_terms attribute', cls.__name__)
            logger.info('setting common_terms to empty set')
            loaded.common_terms = frozenset()

        return loaded


class Phrases(SentenceAnalyzer, PhrasesTransformation):
"""
Detect phrases, based on collected collocation counts. Adjacent words that appear
together more frequently than expected are joined together with the `_` character.
Expand Down Expand Up @@ -303,6 +342,19 @@ def __init__(self, sentences=None, min_count=5, threshold=10.0,
if sentences is not None:
self.add_vocab(sentences)

@classmethod
def load(cls, *args, **kwargs):
    """
    Load a previously saved Phrases class, forwarding all arguments to the parent `load`.

    Models that lack a `corpus_word_count` attribute get it set to 0, so a scoring
    function should not rely on that value for such models.
    """
    restored = super(Phrases, cls).load(*args, **kwargs)
    if hasattr(restored, 'corpus_word_count'):
        return restored

    # attribute missing on this model: back-fill a placeholder value
    logger.info('older version of %s loaded without corpus_word_count', cls.__name__)
    logger.info('Setting it to 0, do not use it in your scoring function.')
    restored.corpus_word_count = 0
    return restored

def __str__(self):
"""Get short string representation of this phrase detector."""
return "%s<%i vocab, min_count=%s, threshold=%s, max_vocab_size=%s>" % (
Expand Down Expand Up @@ -461,41 +513,6 @@ def __getitem__(self, sentence):

return [utils.to_unicode(w) for w in new_s]

@classmethod
def load(cls, *args, **kwargs):
    """
    Load a previously saved Phrases class. Handles backwards compatibility from
    older Phrases versions which did not support pluggable scoring functions.
    Otherwise, relies on the parent class `load`, to which all arguments are forwarded.

    Returns the loaded model. A missing `scoring` attribute, or a `scoring` value of
    "default", is replaced by `original_scorer`; "npmi" is replaced by `npmi_scorer`.

    Raises ValueError if `scoring` is any other string.
    """
    # Base string type for both Python 2 and 3, used to check whether model.scoring is a string.
    # The result is deliberately bound to a different name: assigning to `basestring` inside this
    # function would make it a local variable, so the probe would always fail (UnboundLocalError
    # is a NameError) and fall back to `str`, missing `unicode` values on Python 2.
    try:
        string_types = basestring
    except NameError:
        string_types = str

    model = super(Phrases, cls).load(*args, **kwargs)
    # update older models
    if not hasattr(model, 'scoring'):
        # if no scoring parameter, use default scoring
        logger.info('older version of Phrases loaded without scoring function')
        logger.info('setting pluggable scoring method to original_scorer for compatibility')
        model.scoring = original_scorer
    elif isinstance(model.scoring, string_types):
        # if there is a scoring parameter, and it's a text value, load the proper scoring function
        if model.scoring == 'default':
            logger.info('older version of Phrases loaded with "default" scoring parameter')
            logger.info('setting scoring method to original_scorer pluggable scoring method for compatibility')
            model.scoring = original_scorer
        elif model.scoring == 'npmi':
            logger.info('older version of Phrases loaded with "npmi" scoring parameter')
            logger.info('setting scoring method to npmi_scorer pluggable scoring method for compatibility')
            model.scoring = npmi_scorer
        else:
            raise ValueError('failed to load Phrases model with unknown scoring setting %s' % (model.scoring))
    return model


# these two built-in scoring methods don't cast everything to float because the casting is done in the call
# to the scoring method in __getitem__ and export_phrases.
Expand Down Expand Up @@ -530,7 +547,7 @@ def pseudocorpus(source_vocab, sep, common_terms=frozenset()):
yield components


class Phraser(SentenceAnalyzer, interfaces.TransformationABC):
class Phraser(SentenceAnalyzer, PhrasesTransformation):
"""
Minimal state & functionality to apply results of a Phrases model to tokens.
Expand Down
Binary file added gensim/test/test_data/phraser-no-common-terms.pkl
Binary file not shown.
Binary file added gensim/test/test_data/phraser-no-scoring.pkl
Binary file not shown.
Binary file added gensim/test/test_data/phraser-scoring-str.pkl
Binary file not shown.
Binary file added gensim/test/test_data/phrases-no-common-terms.pkl
Binary file not shown.
Binary file added gensim/test/test_data/phrases-no-scoring.pkl
Binary file not shown.
Binary file added gensim/test/test_data/phrases-scoring-str.pkl
Binary file not shown.
146 changes: 85 additions & 61 deletions gensim/test/test_phrases.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@

import logging
import unittest
import os
import sys

from gensim import utils
from gensim.models.phrases import SentenceAnalyzer, Phrases, Phraser, pseudocorpus
from gensim.test.utils import common_texts
import six

if sys.version_info[0] >= 3:
unicode = str
from gensim.utils import to_unicode
from gensim.models.phrases import SentenceAnalyzer, Phrases, Phraser
from gensim.models.phrases import pseudocorpus, original_scorer
from gensim.test.utils import common_texts, temporary_file, datapath


class TestUtils(unittest.TestCase):
Expand Down Expand Up @@ -138,7 +136,7 @@ class PhrasesData:
sentences = common_texts + [
['graph', 'minors', 'survey', 'human', 'interface']
]
unicode_sentences = [[utils.to_unicode(w) for w in sentence] for sentence in sentences]
unicode_sentences = [[to_unicode(w) for w in sentence] for sentence in sentences]
common_terms = frozenset()

bigram1 = u'response_time'
Expand Down Expand Up @@ -230,7 +228,7 @@ def testEncoding(self):
self.assertEqual(self.bigram_unicode[self.sentences[1]], expected)

transformed = ' '.join(self.bigram_utf8[self.sentences[1]])
self.assertTrue(isinstance(transformed, unicode))
self.assertTrue(isinstance(transformed, six.text_type))


# scorer for testCustomScorer
Expand Down Expand Up @@ -335,15 +333,15 @@ def testPruning(self):
# endclass TestPhrasesModel


class TestPhrasesScoringPersistence(PhrasesData, unittest.TestCase):
class TestPhrasesPersistence(PhrasesData, unittest.TestCase):

def testSaveLoadCustomScorer(self):
""" saving and loading a Phrases object with a custom scorer """

try:
with temporary_file("test.pkl") as fpath:
bigram = Phrases(self.sentences, min_count=1, threshold=.001, scoring=dumb_scorer)
bigram.save("test_phrases_testSaveLoadCustomScorer_temp_save.pkl")
bigram_loaded = Phrases.load("test_phrases_testSaveLoadCustomScorer_temp_save.pkl")
bigram.save(fpath)
bigram_loaded = Phrases.load(fpath)
seen_scores = []
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
for phrase, score in bigram_loaded.export_phrases(test_sentences):
Expand All @@ -352,17 +350,13 @@ def testSaveLoadCustomScorer(self):
assert all(seen_scores) # all scores 1
assert len(seen_scores) == 3 # 'graph minors' and 'survey human' and 'interface system'

finally:
if os.path.exists("test_phrases_testSaveLoadCustomScorer_temp_save.pkl"):
os.remove("test_phrases_testSaveLoadCustomScorer_temp_save.pkl")

def testSaveLoad(self):
""" Saving and loading a Phrases object."""

try:
with temporary_file("test.pkl") as fpath:
bigram = Phrases(self.sentences, min_count=1, threshold=1)
bigram.save("test_phrases_testSaveLoad_temp_save.pkl")
bigram_loaded = Phrases.load("test_phrases_testSaveLoad_temp_save.pkl")
bigram.save(fpath)
bigram_loaded = Phrases.load(fpath)
seen_scores = set()
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
for phrase, score in bigram_loaded.export_phrases(test_sentences):
Expand All @@ -373,55 +367,85 @@ def testSaveLoad(self):
3.444 # score for human interface
])

finally:
if os.path.exists("test_phrases_testSaveLoad_temp_save.pkl"):
os.remove("test_phrases_testSaveLoad_temp_save.pkl")

def testSaveLoadStringScoring(self):
""" Saving and loading a Phrases object with a string scoring parameter.
This should ensure backwards compatibility with the previous version of Phrases"""
bigram_loaded = Phrases.load(datapath("phrases-scoring-str.pkl"))
seen_scores = set()
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
for phrase, score in bigram_loaded.export_phrases(test_sentences):
seen_scores.add(round(score, 3))

try:
bigram = Phrases(self.sentences, min_count=1, threshold=1)
bigram.scoring = "default"
bigram.save("test_phrases_testSaveLoadStringScoring_temp_save.pkl")
bigram_loaded = Phrases.load("test_phrases_testSaveLoadStringScoring_temp_save.pkl")
seen_scores = set()
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
for phrase, score in bigram_loaded.export_phrases(test_sentences):
seen_scores.add(round(score, 3))

assert seen_scores == set([
5.167, # score for graph minors
3.444 # score for human interface
])

finally:
if os.path.exists("test_phrases_testSaveLoadStringScoring_temp_save.pkl"):
os.remove("test_phrases_testSaveLoadStringScoring_temp_save.pkl")
assert seen_scores == set([
5.167, # score for graph minors
3.444 # score for human interface
])

def testSaveLoadNoScoring(self):
""" Saving and loading a Phrases object with no scoring parameter.
This should ensure backwards compatibility with old versions of Phrases"""

try:
bigram = Phrases(self.sentences, min_count=1, threshold=1)
del(bigram.scoring)
bigram.save("test_phrases_testSaveLoadNoScoring_temp_save.pkl")
bigram_loaded = Phrases.load("test_phrases_testSaveLoadNoScoring_temp_save.pkl")
seen_scores = set()
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
for phrase, score in bigram_loaded.export_phrases(test_sentences):
seen_scores.add(round(score, 3))
bigram_loaded = Phrases.load(datapath("phrases-no-scoring.pkl"))
seen_scores = set()
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
for phrase, score in bigram_loaded.export_phrases(test_sentences):
seen_scores.add(round(score, 3))

assert seen_scores == set([
5.167, # score for graph minors
3.444 # score for human interface
])
assert seen_scores == set([
5.167, # score for graph minors
3.444 # score for human interface
])

def testSaveLoadNoCommonTerms(self):
    """Loading the "phrases-no-common-terms.pkl" fixture must give a model with empty common_terms."""
    legacy_path = datapath("phrases-no-common-terms.pkl")
    restored = Phrases.load(legacy_path)
    self.assertEqual(restored.common_terms, frozenset())
    # cf #1751: building and applying a Phraser from the loaded model must work too
    legacy_phraser = Phraser(restored)  # does not raise
    legacy_phraser[["human", "interface", "survey"]]  # does not raise


class TestPhraserPersistence(PhrasesData, unittest.TestCase):

def testSaveLoadCustomScorer(self):
    """Round-trip a Phraser built with a custom scorer through save/load."""
    with temporary_file("test.pkl") as fpath:
        source_model = Phrases(self.sentences, min_count=1, threshold=.001, scoring=dumb_scorer)
        phraser = Phraser(source_model)
        phraser.save(fpath)
        restored = Phraser.load(fpath)
    # the scorer is not exercised here; only check that it is the expected one
    self.assertEqual(restored.scoring, dumb_scorer)

def testSaveLoad(self):
    """Round-trip a Phraser through save/load and check it still joins the expected bigrams."""
    tokens = ['graph', 'minors', 'survey', 'human', 'interface', 'system']
    expected = ['graph_minors', 'survey', 'human_interface', 'system']
    with temporary_file("test.pkl") as fpath:
        phraser = Phraser(Phrases(self.sentences, min_count=1, threshold=1))
        phraser.save(fpath)
        restored = Phraser.load(fpath)
        self.assertEqual(restored[tokens], expected)

def testSaveLoadStringScoring(self):
    """Load the "phraser-scoring-str.pkl" fixture and check that scoring becomes original_scorer.
    This should ensure backwards compatibility with the previous version of Phraser"""
    fixture = datapath("phraser-scoring-str.pkl")
    restored = Phraser.load(fixture)
    # the scorer is not exercised here; only check that it is the expected one
    self.assertEqual(restored.scoring, original_scorer)

def testSaveLoadNoScoring(self):
    """Load the "phraser-no-scoring.pkl" fixture and check that scoring becomes original_scorer.
    This should ensure backwards compatibility with old versions of Phraser"""
    fixture = datapath("phraser-no-scoring.pkl")
    restored = Phraser.load(fixture)
    # the scorer is not exercised here; only check that it is the expected one
    self.assertEqual(restored.scoring, original_scorer)

finally:
if os.path.exists("test_phrases_testSaveLoadNoScoring_temp_save.pkl"):
os.remove("test_phrases_testSaveLoadNoScoring_temp_save.pkl")
def testSaveLoadNoCommonTerms(self):
    """Loading the "phraser-no-common-terms.pkl" fixture must give a Phraser with empty common_terms."""
    fixture = datapath("phraser-no-common-terms.pkl")
    restored = Phraser.load(fixture)
    self.assertEqual(restored.common_terms, frozenset())


class TestPhraserModel(PhrasesData, PhrasesCommon, unittest.TestCase):
Expand Down Expand Up @@ -461,7 +485,7 @@ class CommonTermsPhrasesData:
['data', 'and', 'graph', 'survey'],
['data', 'and', 'graph', 'survey', 'for', 'human', 'interface'] # test bigrams within same sentence
]
unicode_sentences = [[utils.to_unicode(w) for w in sentence] for sentence in sentences]
unicode_sentences = [[to_unicode(w) for w in sentence] for sentence in sentences]
common_terms = ['of', 'and', 'for']

bigram1 = u'lack_of_interest'
Expand All @@ -486,7 +510,7 @@ def testEncoding(self):
self.assertEqual(self.bigram_unicode[self.sentences[1]], expected)

transformed = ' '.join(self.bigram_utf8[self.sentences[1]])
self.assertTrue(isinstance(transformed, unicode))
self.assertTrue(isinstance(transformed, six.text_type))

def testMultipleBigramsSingleEntry(self):
""" a single entry should produce multiple bigrams. """
Expand Down Expand Up @@ -593,7 +617,7 @@ def testEncoding(self):
self.assertEqual(self.bigram_unicode[self.sentences[1]], expected)

transformed = ' '.join(self.bigram_utf8[self.sentences[1]])
self.assertTrue(isinstance(transformed, unicode))
self.assertTrue(isinstance(transformed, six.text_type))


if __name__ == '__main__':
Expand Down
18 changes: 18 additions & 0 deletions gensim/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
"""
Common utils for tests
"""
import contextlib
import tempfile
import os
import shutil

from gensim.corpora import Dictionary

Expand All @@ -27,6 +29,22 @@ def get_tmpfile(suffix):
return os.path.join(tempfile.gettempdir(), suffix)


@contextlib.contextmanager
def temporary_file(name=""):
    """Yield a path to `name` inside a freshly created temporary directory.

    The file itself is not created; only its parent directory exists.
    When the context exits (normally or through an exception) the whole
    directory is removed, ignoring any errors during removal.
    """
    # note : when dropping python2.7 support, we can use tempfile.TemporaryDirectory
    scratch_dir = tempfile.mkdtemp()
    try:
        target = os.path.join(scratch_dir, name)
        yield target
    finally:
        shutil.rmtree(scratch_dir, ignore_errors=True)


# set up vars used in testing ("Deerwester" from the web tutorial)
common_texts = [
['human', 'interface', 'computer'],
Expand Down

0 comments on commit a7120d7

Please sign in to comment.