Skip to content

Commit

Permalink
Merge pull request #3197 from silviatti/fix_3181
Browse files Browse the repository at this point in the history
Fix computation of topic coherence
  • Loading branch information
piskvorky authored Apr 25, 2022
2 parents 533da75 + 298880b commit 7cb443b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
11 changes: 11 additions & 0 deletions gensim/test/test_coherencemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def setUp(self):
['not a token', 'not an id', 'tests using', "this list"],
['should raise', 'an error', 'to pass', 'correctly']
]
# list of topics with unseen words in the dictionary
self.topics5 = [
['aaaaa', 'bbbbb', 'ccccc', 'eeeee'],
['ddddd', 'fffff', 'ggggh', 'hhhhh']
]
self.topicIds1 = []
for topic in self.topics1:
self.topicIds1.append([self.dictionary.token2id[token] for token in topic])
Expand All @@ -70,8 +75,14 @@ def check_coherence_measure(self, coherence):
cm2 = CoherenceModel(topics=self.topics2, **kwargs)
cm3 = CoherenceModel(topics=self.topics3, **kwargs)
cm4 = CoherenceModel(topics=self.topicIds1, **kwargs)

# check if the same topic always returns the same coherence value
cm5 = CoherenceModel(topics=[self.topics1[0]], **kwargs)

self.assertRaises(ValueError, lambda: CoherenceModel(topics=self.topics4, **kwargs))
self.assertRaises(ValueError, lambda: CoherenceModel(topics=self.topics5, **kwargs))
self.assertEqual(cm1.get_coherence(), cm4.get_coherence())
self.assertEqual(cm1.get_coherence_per_topic()[0], cm5.get_coherence())
self.assertIsInstance(cm3.get_coherence(), np.double)
self.assertGreater(cm1.get_coherence(), cm2.get_coherence())

Expand Down
17 changes: 5 additions & 12 deletions gensim/topic_coherence/text_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,18 +300,11 @@ def accumulate(self, texts, window_size):
def _iter_texts(self, texts):
dtype = np.uint16 if np.iinfo(np.uint16).max >= self._vocab_size else np.uint32
for text in texts:
if self.text_is_relevant(text):
yield np.fromiter((
self.id2contiguous[self.token2id[w]] if w in self.relevant_words
else self._none_token
for w in text), dtype=dtype, count=len(text))

def text_is_relevant(self, text):
    """Return True if *text* contains at least one relevant word.

    Parameters
    ----------
    text : iterable of str
        Tokenized text.

    Returns
    -------
    bool
        True if any token is in ``self.relevant_words``, else False.
    """
    # any() short-circuits on the first relevant word, same as the manual loop.
    return any(word in self.relevant_words for word in text)
ids = (
self.id2contiguous[self.token2id[w]] if w in self.relevant_words else self._none_token
for w in text
)
yield np.fromiter(ids, dtype=dtype, count=len(text))


class InvertedIndexAccumulator(WindowedTextsAnalyzer, InvertedIndexBased):
Expand Down

0 comments on commit 7cb443b

Please sign in to comment.