
Commit

upgrade reloaded w2v models; proper total_examples/total_words for more training
gojomo committed Jul 19, 2015
1 parent 7bbe02d commit aa8a9cd
Showing 2 changed files with 84 additions and 21 deletions.
8 changes: 6 additions & 2 deletions gensim/models/doc2vec.py
@@ -555,7 +555,6 @@ def __init__(self, documents=None, size=300, alpha=0.025, window=8, min_count=5,
"""
super(Doc2Vec, self).__init__(
# super(self.__class__, self).__init__(
size=size, alpha=alpha, window=window, min_count=min_count, max_vocab_size=max_vocab_size,
sample=sample, seed=seed, workers=workers, min_alpha=min_alpha,
sg=(1+dm) % 2, hs=hs, negative=negative, cbow_mean=dm_mean,
@@ -634,6 +633,7 @@ def scan_vocab(self, documents, progress_per=10000):
def _do_train_job(self, job, alpha, inits):
work, neu1 = inits
tally = 0
raw_tally = 0
for doc in job:
indexed_doctags = self.docvecs.indexed_doctags(doc.tags)
doctag_indexes, doctag_vectors, doctag_locks, ignored = indexed_doctags
@@ -647,8 +647,12 @@ def _do_train_job(self, job, alpha, inits):
else:
tally += train_document_dm(self, doc.words, doctag_indexes, alpha, work, neu1,
doctag_vectors=doctag_vectors, doctag_locks=doctag_locks)
raw_tally += len(doc.words)
self.docvecs.trained_item(indexed_doctags)
return tally
return (tally, raw_tally)

def _raw_word_count(self, items):
return sum(len(item.words) for item in items)

def infer_vector(self, doc_words, alpha=0.1, min_alpha=0.0001, steps=5):
"""
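For context on the new (tally, raw_tally) return and the _raw_word_count helper above — a minimal sketch, using a stand-in namedtuple rather than gensim's actual tagged-document class (an assumption, not part of this commit) — the raw tally simply counts every word pushed through a job, before any vocabulary filtering or downsampling:

from collections import namedtuple

# Illustrative stand-in for a tagged document with a .words list (not from this commit).
Doc = namedtuple('Doc', 'words tags')

job = [
    Doc(words=['human', 'interface', 'computer'], tags=[0]),
    Doc(words=['survey', 'of', 'user', 'opinion'], tags=[1]),
]

# Same computation as Doc2Vec._raw_word_count: raw words per job, here 3 + 4 = 7.
raw_words = sum(len(doc.words) for doc in job)
print(raw_words)  # 7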
97 changes: 78 additions & 19 deletions gensim/models/word2vec.py
@@ -500,7 +500,7 @@ def scan_vocab(self, sentences, progress_per=10000):
min_reduce += 1

total_words += sum(itervalues(vocab))
logger.info("collected %i word types from a corpus of %i words and %i sentences",
logger.info("collected %i word types from a corpus of %i raw words and %i sentences",
len(vocab), total_words, sentence_no + 1)
self.corpus_count = sentence_no + 1
self.raw_vocab = vocab
@@ -623,17 +623,26 @@ def reset_from(self, other_model):
def _do_train_job(self, job, alpha, inits):
work, neu1 = inits
tally = 0
raw_tally = 0
for sentence in job:
if self.sg:
tally += train_sentence_sg(self, sentence, alpha, work)
else:
tally += train_sentence_cbow(self, sentence, alpha, work, neu1)
return tally
raw_tally += len(sentence)
return (tally, raw_tally)

def train(self, sentences, total_words=None, word_count=0, chunksize=100, queue_factor=2, report_delay=1):
def _raw_word_count(self, items):
return sum(len(item) for item in items)

def train(self, sentences, total_words=None, word_count=0, chunksize=100, total_examples=None, queue_factor=2, report_delay=1):
"""
Update the model's neural weights from a sequence of sentences (can be a once-only generator stream).
Each sentence must be a list of unicode strings.
For Word2Vec, each sentence must be a list of unicode strings. (Subclasses may accept other examples.)
To support linear learning-rate decay from (initial) alpha to min_alpha, either total_examples
(count of sentences) or total_words (count of raw words in sentences) should be provided, unless the
sentences are the same as those that were used to initially build the vocabulary.
"""
if FAST_VERSION < 0:
@@ -655,8 +664,17 @@ def train(self, sentences, total_words=None, word_count=0, chunksize=100, queue_
if not hasattr(self, 'syn0'):
raise RuntimeError("you must first finalize vocabulary before training the model")

if total_words is None and total_examples is None:
if self.corpus_count:
total_examples = self.corpus_count
logger.info("expecting %i examples, matching count from corpus used for vocabulary survey", total_examples)
else:
raise ValueError("you must provide either total_words or total_examples, to enable alpha and progress calculations")

if self.iter > 1:
sentences = utils.RepeatCorpusNTimes(sentences, self.iter)
total_words = total_words and total_words * self.iter
total_examples = total_examples and total_examples * self.iter

def worker_init():
work = matutils.zeros_aligned(self.layer1_size, dtype=REAL) # per-thread private work memory
@@ -668,8 +686,8 @@ def worker_one_job(job, inits):
if items is None: # signal to finish
return False
# train & return tally
job_words = self._do_train_job(items, alpha, inits)
progress_queue.put(job_words) # report progress
tally, raw_tally = self._do_train_job(items, alpha, inits)
progress_queue.put((len(items), tally, raw_tally)) # report progress
return True

def worker_loop():
@@ -681,8 +699,7 @@ def worker_loop():
break

start, next_report = default_timer(), 1.0
total_words = total_words or int(sum(v.count * (v.sample_int/2**32) for v in itervalues(self.vocab)) *
self.iter)

# buffer ahead only a limited number of jobs.. this is the reason we can't simply use ThreadPool :(
if self.workers > 0:
job_queue = Queue(maxsize=queue_factor * self.workers)
@@ -696,6 +713,10 @@ def worker_loop():
thread.start()

pushed_words = 0
pushed_examples = 0
example_count = 0
trained_word_count = 0
raw_word_count = word_count
push_done = False
done_jobs = 0
next_alpha = self.alpha
@@ -706,23 +727,39 @@ def worker_loop():
job_no, items = next(jobs_source)
logger.debug("putting job #%i in the queue at alpha %.05f", job_no, next_alpha)
job_queue.put((items, next_alpha))
# update the learning rate before every job
pushed_words += round((chunksize / (self.corpus_count * self.iter)) * total_words)
next_alpha = self.alpha - (self.alpha - self.min_alpha) * (pushed_words / total_words)
# update the learning rate before every next job
if self.min_alpha < next_alpha:
if total_examples:
# examples-based decay
pushed_examples += len(items)
next_alpha = self.alpha - (self.alpha - self.min_alpha) * (pushed_examples / total_examples)
else:
# words-based decay
pushed_words += self._raw_word_count(items)
next_alpha = self.alpha - (self.alpha - self.min_alpha) * (pushed_words / total_words)
next_alpha = max(next_alpha, self.min_alpha)
except StopIteration:
logger.info("reached end of input; waiting to finish %i outstanding jobs" % (job_no-done_jobs+1))
for _ in xrange(self.workers):
job_queue.put((None, 0)) # give the workers heads up that they can finish -- no more work!
push_done = True
try:
while done_jobs < (job_no+1) or not push_done:
word_count += progress_queue.get(push_done) # only block after all jobs pushed
examples, trained_words, raw_words = progress_queue.get(push_done) # only block after all jobs pushed
example_count += examples
trained_word_count += trained_words # only words in vocab & sampled
raw_word_count += raw_words
done_jobs += 1
elapsed = default_timer() - start
if elapsed >= next_report:
est_alpha = self.alpha - (self.alpha - self.min_alpha) * (word_count / total_words)
logger.info("PROGRESS: at %.2f%% words, alpha estimate %.05f, %.0f words/s",
100.0 * word_count / total_words, est_alpha, word_count / elapsed)
if total_examples:
# examples-based progress %
logger.info("PROGRESS: at %.2f%% examples, %.0f words/s",
100.0 * example_count / total_examples, trained_word_count / elapsed)
else:
# words-based progress %
logger.info("PROGRESS: at %.2f%% words, %.0f words/s",
100.0 * raw_word_count / total_words, trained_word_count / elapsed)
next_report = elapsed + report_delay # don't flood log, wait report_delay seconds
else:
# loop ended by job count; really done
@@ -731,12 +768,18 @@ def worker_loop():
pass # already out of loop; continue to next push

elapsed = default_timer() - start
logger.info("training on %i words took %.1fs, %.0f words/s" %
(word_count, elapsed, word_count / elapsed if elapsed else 0.0))
self.train_count += 1
logger.info("training on %i raw words took %.1fs, %.0f trained words/s" %
(raw_word_count, elapsed, trained_word_count / elapsed if elapsed else 0.0))

if total_examples and total_examples != example_count:
logger.warn("supplied example count (%i) did not equal expected count (%i)", example_count, total_examples)
if total_words and total_words != raw_word_count:
logger.warn("supplied raw word count (%i) did not equal expected count (%i)", raw_word_count, total_words)

self.train_count += 1 # number of times train() has been called
self.total_train_time += elapsed
self.clear_sims()
return word_count
return trained_word_count

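A hedged usage sketch of the train() keyword added above (corpus and values are illustrative, not from this commit): when continuing training on sentences other than those scanned during vocabulary building, total_examples (or total_words) supplies the denominator for the linear alpha decay and the progress percentages. For example, with alpha=0.025 and min_alpha=0.0001, half-way through the examples the pushed alpha would be 0.025 - (0.025 - 0.0001) * 0.5 = 0.01255.

from gensim.models import Word2Vec

sentences = [['first', 'sentence'], ['second', 'sentence']]
model = Word2Vec(sentences, min_count=1)  # vocabulary scan + initial training

# Continuing training on a different corpus: supply total_examples (sentence
# count) or total_words (raw word count) so alpha decay and progress reporting
# are computed against the new corpus rather than the vocabulary-scan corpus.
more_sentences = [['yet', 'another', 'sentence'], ['and', 'one', 'more']]
model.train(more_sentences, total_examples=len(more_sentences))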
def _score_job_words(self, sentence, work, neu1):
if self.sg:
@@ -1314,10 +1357,26 @@ def save(self, *args, **kwargs):
@classmethod
def load(cls, *args, **kwargs):
model = super(Word2Vec, cls).load(*args, **kwargs)
# update older models
if hasattr(model, 'table'):
delattr(model, 'table') # discard in favor of cum_table
if model.negative:
model.make_cum_table() # rebuild cum_table from vocabulary
if not hasattr(model, 'corpus_count'):
model.corpus_count = None
for v in model.vocab.values():
if hasattr(v, 'sample_int'):
break # already 0.12.0+ style int probabilities
else:
v.sample_int = int(round(v.sample_probability * 2**32))
del v.sample_probability
if not hasattr(model, 'syn0_lockf'):
model.syn0_lockf = ones(len(model.syn0), dtype=REAL)
if not hasattr(model, 'random'):
model.random = random.RandomState(model.seed)
if not hasattr(model, 'train_count'):
model.train_count = 0
model.total_train_time = 0
return model
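The load() additions above are backward-compatibility shims for models pickled by older gensim versions; a minimal sketch of the intended effect (the file path is hypothetical):

from gensim.models import Word2Vec

old_model = Word2Vec.load('/tmp/pre_0_12_model')  # hypothetical path to an older pickle
# After load(), the shims above ensure newer-style attributes exist before more training:
# cum_table (rebuilt when negative sampling is in use, replacing the old 'table'),
# per-word sample_int, syn0_lockf, random, train_count and total_train_time.
old_model.train([['new', 'training', 'text']], total_examples=1)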


