loss: always tally; split to epoch_loss/minibatch_loss; use wider float
gojomo committed Aug 28, 2020
1 parent 817cac9 commit 33ef202
Showing 5 changed files with 81 additions and 117 deletions.
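At a glance, the user-visible change is that loss tallying is no longer opt-in: the compute_loss argument disappears from __init__() and train(), the get_latest_training_loss() getter is removed, and the tally lands in an epoch_loss attribute. A minimal before/after sketch, assuming the attribute name introduced in this diff (the non-corpusfile training path lives in files not shown here):

from gensim.models import Word2Vec

sentences = [["first", "sentence"], ["second", "sentence"]]

# Before this commit: loss tracking was opt-in and read via a getter.
#   model = Word2Vec(sentences, min_count=1, compute_loss=True)
#   print(model.get_latest_training_loss())

# After this commit (assumed usage): loss is always tallied, and the
# per-epoch total is exposed as a plain attribute.
model = Word2Vec(sentences, min_count=1)
print(model.epoch_loss)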
41 changes: 11 additions & 30 deletions gensim/models/word2vec.py
@@ -170,19 +170,19 @@
CORPUSFILE_VERSION = -1

def train_epoch_sg(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expected_examples, _expected_words,
_work, _neu1, compute_loss):
_work, _neu1):
raise RuntimeError("Training with corpus_file argument is not supported")

def train_epoch_cbow(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expected_examples, _expected_words,
_work, _neu1, compute_loss):
_work, _neu1):
raise RuntimeError("Training with corpus_file argument is not supported")


class Word2Vec(utils.SaveLoad):
def __init__(self, sentences=None, corpus_file=None, vector_size=100, alpha=0.025, window=5, min_count=5,
max_vocab_size=None, sample=1e-3, seed=1, workers=3, min_alpha=0.0001,
sg=0, hs=0, negative=5, ns_exponent=0.75, cbow_mean=1, hashfxn=hash, epochs=5, null_word=0,
trim_rule=None, sorted_vocab=1, batch_words=MAX_WORDS_IN_BATCH, compute_loss=False, callbacks=(),
trim_rule=None, sorted_vocab=1, batch_words=MAX_WORDS_IN_BATCH, callbacks=(),
comment=None, max_final_vocab=None):
"""Train, use and evaluate neural networks described in https://code.google.com/p/word2vec/.
@@ -282,9 +282,6 @@ def __init__(self, sentences=None, corpus_file=None, vector_size=100, alpha=0.02
Target size (in words) for batches of examples passed to worker threads (and
thus cython routines).(Larger batches will be passed if individual
texts are longer than 10000 words, but the standard cython code truncates to that maximum.)
compute_loss: bool, optional
If True, computes and stores loss value which can be retrieved using
:meth:`~gensim.models.word2vec.Word2Vec.get_latest_training_loss`.
callbacks : iterable of :class:`~gensim.models.callbacks.CallbackAny2Vec`, optional
Sequence of callbacks to be executed at specific stages during training.
@@ -325,8 +322,7 @@ def __init__(self, sentences=None, corpus_file=None, vector_size=100, alpha=0.02
self.negative = int(negative)
self.ns_exponent = ns_exponent
self.cbow_mean = int(cbow_mean)
self.compute_loss = bool(compute_loss)
self.running_training_loss = 0
self.epoch_loss = 0.0
self.min_alpha_yet_reached = float(alpha)
self.corpus_count = 0
self.corpus_total_words = 0
@@ -380,7 +376,7 @@ def build_vocab_and_train(self, corpus_iterable=None, corpus_file=None, trim_rul
self.train(
corpus_iterable=corpus_iterable, corpus_file=corpus_file, total_examples=self.corpus_count,
total_words=self.corpus_total_words, epochs=self.epochs, start_alpha=self.alpha,
end_alpha=self.min_alpha, compute_loss=self.compute_loss, callbacks=callbacks)
end_alpha=self.min_alpha, callbacks=callbacks)

def build_vocab(self, corpus_iterable=None, corpus_file=None, update=False, progress_per=10000,
keep_raw_vocab=False, trim_rule=None, **kwargs):
@@ -838,10 +834,10 @@ def _do_train_epoch(self, corpus_file, thread_id, offset, cython_vocab, thread_p

if self.sg:
examples, tally, raw_tally = train_epoch_sg(self, corpus_file, offset, cython_vocab, cur_epoch,
total_examples, total_words, work, neu1, self.compute_loss)
total_examples, total_words, work, neu1)
else:
examples, tally, raw_tally = train_epoch_cbow(self, corpus_file, offset, cython_vocab, cur_epoch,
total_examples, total_words, work, neu1, self.compute_loss)
total_examples, total_words, work, neu1)

return examples, tally, raw_tally

@@ -866,9 +862,9 @@ def _do_train_job(self, sentences, alpha, inits):
work, neu1 = inits
tally = 0
if self.sg:
tally += train_batch_sg(self, sentences, alpha, work, self.compute_loss)
tally += train_batch_sg(self, sentences, alpha, work)
else:
tally += train_batch_cbow(self, sentences, alpha, work, neu1, self.compute_loss)
tally += train_batch_cbow(self, sentences, alpha, work, neu1)
return tally, self._raw_word_count(sentences)

def _clear_post_train(self):
@@ -877,7 +873,7 @@ def _clear_post_train(self):

def train(self, corpus_iterable=None, corpus_file=None, total_examples=None, total_words=None,
epochs=None, start_alpha=None, end_alpha=None, word_count=0,
queue_factor=2, report_delay=1.0, compute_loss=False, callbacks=(),
queue_factor=2, report_delay=1.0, callbacks=(),
**kwargs):
"""Update the model's neural weights from a sequence of sentences.
@@ -931,9 +927,6 @@ def train(self, corpus_iterable=None, corpus_file=None, total_examples=None, tot
Multiplier for size of queue (number of workers * queue_factor).
report_delay : float, optional
Seconds to wait before reporting progress.
compute_loss: bool, optional
If True, computes and stores loss value which can be retrieved using
:meth:`~gensim.models.word2vec.Word2Vec.get_latest_training_loss`.
callbacks : iterable of :class:`~gensim.models.callbacks.CallbackAny2Vec`, optional
Sequence of callbacks to be executed at specific stages during training.
@@ -959,8 +952,7 @@ def train(self, corpus_iterable=None, corpus_file=None, total_examples=None, tot
total_examples=total_examples,
total_words=total_words)

self.compute_loss = compute_loss
self.running_training_loss = 0.0
self.epoch_loss = 0.0

for callback in callbacks:
callback.on_train_begin(self)
@@ -1820,17 +1812,6 @@ def save(self, *args, **kwargs):
kwargs['ignore'] = kwargs.get('ignore', []) + ['cum_table', ]
super(Word2Vec, self).save(*args, **kwargs)

def get_latest_training_loss(self):
"""Get current value of the training loss.
Returns
-------
float
Current training loss.
"""
return self.running_training_loss

@classmethod
def load(cls, *args, rethrow=False, **kwargs):
"""Load a previously saved :class:`~gensim.models.word2vec.Word2Vec` model.
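Since get_latest_training_loss() is removed above, the documented hook left for observing training is the callbacks parameter. A hedged sketch of logging the new epoch_loss from an epoch-end callback; that epoch_loss holds the current tally at on_epoch_end time is an assumption, as the iterable-corpus worker path is in a file not shown here:

from gensim.models import Word2Vec
from gensim.models.callbacks import CallbackAny2Vec

class EpochLossLogger(CallbackAny2Vec):
    """Print the tallied training loss at the end of each epoch."""

    def __init__(self):
        self.epoch = 0

    def on_epoch_end(self, model):
        # epoch_loss replaces running_training_loss in this commit;
        # reading it here is an assumption about when it is populated.
        print("epoch %d: loss=%f" % (self.epoch, model.epoch_loss))
        self.epoch += 1

sentences = [["machine", "learning"], ["deep", "learning"]]
model = Word2Vec(sentences, min_count=1, epochs=3, callbacks=[EpochLossLogger()])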
30 changes: 13 additions & 17 deletions gensim/models/word2vec_corpusfile.pyx
@@ -251,7 +251,7 @@ cdef REAL_t get_next_alpha(


def train_epoch_sg(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expected_examples, _expected_words, _work,
_neu1, compute_loss):
_neu1):
"""Train Skipgram model for one epoch by training on an input stream. This function is used only in multistream mode.
Called internally from :meth:`~gensim.models.word2vec.Word2Vec.train`.
@@ -268,8 +266,6 @@ def train_epoch_sg(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expec
Private working memory for each worker.
_neu1 : np.ndarray
Private working memory for each worker.
compute_loss : bool
Whether or not the training loss should be computed in this batch.
Returns
-------
@@ -297,7 +295,7 @@ def train_epoch_sg(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expec
cdef long long total_effective_words = 0, total_words = 0
cdef int sent_idx, idx_start, idx_end

init_w2v_config(&c, model, _alpha, compute_loss, _work)
init_w2v_config(&c, model, _alpha, _work)

cdef vector[vector[string]] sentences

@@ -330,14 +328,14 @@ def train_epoch_sg(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expec
if c.hs:
w2v_fast_sentence_sg_hs(
c.points[i], c.codes[i], c.codelens[i], c.syn0, c.syn1, c.size, c.indexes[j],
c.alpha, c.work, c.words_lockf, c.words_lockf_len, c.compute_loss,
&c.running_training_loss)
c.alpha, c.work, c.words_lockf, c.words_lockf_len,
&c.minibatch_loss)
if c.negative:
c.next_random = w2v_fast_sentence_sg_neg(
c.negative, c.cum_table, c.cum_table_len, c.syn0, c.syn1neg, c.size,
c.indexes[i], c.indexes[j], c.alpha, c.work, c.next_random,
c.words_lockf, c.words_lockf_len,
c.compute_loss, &c.running_training_loss)
&c.minibatch_loss)

total_sentences += sentences.size()
total_effective_words += effective_words
@@ -346,12 +344,12 @@ def train_epoch_sg(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expec
start_alpha, end_alpha, total_sentences, total_words,
expected_examples, expected_words, cur_epoch, num_epochs)

model.running_training_loss = c.running_training_loss
model.epoch_loss += c.minibatch_loss
return total_sentences, total_effective_words, total_words


def train_epoch_cbow(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expected_examples, _expected_words, _work,
_neu1, compute_loss):
_neu1):
"""Train CBOW model for one epoch by training on an input stream. This function is used only in multistream mode.
Called internally from :meth:`~gensim.models.word2vec.Word2Vec.train`.
@@ -368,8 +366,6 @@ def train_epoch_cbow(model, corpus_file, offset, _cython_vocab, _cur_epoch, _exp
Private working memory for each worker.
_neu1 : np.ndarray
Private working memory for each worker.
compute_loss : bool
Whether or not the training loss should be computed in this batch.
Returns
-------
@@ -397,7 +393,7 @@ def train_epoch_cbow(model, corpus_file, offset, _cython_vocab, _cur_epoch, _exp
cdef long long total_effective_words = 0, total_words = 0
cdef int sent_idx, idx_start, idx_end

init_w2v_config(&c, model, _alpha, compute_loss, _work, _neu1)
init_w2v_config(&c, model, _alpha, _work, _neu1)

cdef vector[vector[string]] sentences

@@ -427,15 +423,15 @@ def train_epoch_cbow(model, corpus_file, offset, _cython_vocab, _cur_epoch, _exp
if c.hs:
w2v_fast_sentence_cbow_hs(
c.points[i], c.codes[i], c.codelens, c.neu1, c.syn0, c.syn1, c.size, c.indexes, c.alpha,
c.work, i, j, k, c.cbow_mean, c.words_lockf, c.words_lockf_len, c.compute_loss,
&c.running_training_loss)
c.work, i, j, k, c.cbow_mean, c.words_lockf, c.words_lockf_len,
&c.minibatch_loss)

if c.negative:
c.next_random = w2v_fast_sentence_cbow_neg(
c.negative, c.cum_table, c.cum_table_len, c.codelens, c.neu1, c.syn0,
c.syn1neg, c.size, c.indexes, c.alpha, c.work, i, j, k, c.cbow_mean,
c.next_random, c.words_lockf, c.words_lockf_len, c.compute_loss,
&c.running_training_loss)
c.next_random, c.words_lockf, c.words_lockf_len,
&c.minibatch_loss)

total_sentences += sentences.size()
total_effective_words += effective_words
@@ -444,7 +440,7 @@ def train_epoch_cbow(model, corpus_file, offset, _cython_vocab, _cur_epoch, _exp
start_alpha, end_alpha, total_sentences, total_words,
expected_examples, expected_words, cur_epoch, num_epochs)

model.running_training_loss = c.running_training_loss
model.epoch_loss += c.minibatch_loss
return total_sentences, total_effective_words, total_words


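The "wider float" part of the commit title is visible below in word2vec_inner.pxd: the per-call tally becomes an np.float64_t minibatch_loss, where the old running_training_loss was a single-precision REAL_t. A standalone illustration (not gensim code) of why the wider accumulator matters: once a float32 total grows large enough, small per-example loss increments are rounded away entirely.

import numpy as np

big = np.float32(2.0 ** 25)   # a large accumulated loss total
small = np.float32(0.25)      # one example's loss contribution

# In float32 the spacing between representable values at 2**25 is 4.0,
# so adding 0.25 has no effect; in float64 the increment survives.
print(big + small == big)                           # True
print(np.float64(big) + np.float64(small) == big)   # False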
15 changes: 8 additions & 7 deletions gensim/models/word2vec_inner.pxd
@@ -49,8 +49,9 @@ cdef our_saxpy_ptr our_saxpy


cdef struct Word2VecConfig:
int hs, negative, sample, compute_loss, size, window, cbow_mean, workers
REAL_t running_training_loss, alpha
int hs, negative, sample, size, window, cbow_mean, workers
REAL_t alpha
np.float64_t minibatch_loss

REAL_t *syn0
REAL_t *words_lockf
@@ -96,31 +97,31 @@ cdef void w2v_fast_sentence_sg_hs(
const np.uint32_t *word_point, const np.uint8_t *word_code, const int codelen,
REAL_t *syn0, REAL_t *syn1, const int size,
const np.uint32_t word2_index, const REAL_t alpha, REAL_t *work, REAL_t *words_lockf,
const np.uint32_t lockf_len, const int _compute_loss, REAL_t *_running_training_loss_param) nogil
const np.uint32_t lockf_len, np.float64_t *minibatch_loss_ptr) nogil


cdef unsigned long long w2v_fast_sentence_sg_neg(
const int negative, np.uint32_t *cum_table, unsigned long long cum_table_len,
REAL_t *syn0, REAL_t *syn1neg, const int size, const np.uint32_t word_index,
const np.uint32_t word2_index, const REAL_t alpha, REAL_t *work,
unsigned long long next_random, REAL_t *words_lockf,
const np.uint32_t lockf_len, const int _compute_loss, REAL_t *_running_training_loss_param) nogil
const np.uint32_t lockf_len, np.float64_t *minibatch_loss_ptr) nogil


cdef void w2v_fast_sentence_cbow_hs(
const np.uint32_t *word_point, const np.uint8_t *word_code, int codelens[MAX_SENTENCE_LEN],
REAL_t *neu1, REAL_t *syn0, REAL_t *syn1, const int size,
const np.uint32_t indexes[MAX_SENTENCE_LEN], const REAL_t alpha, REAL_t *work,
int i, int j, int k, int cbow_mean, REAL_t *words_lockf,
const np.uint32_t lockf_len, const int _compute_loss, REAL_t *_running_training_loss_param) nogil
const np.uint32_t lockf_len, np.float64_t *minibatch_loss_ptr) nogil


cdef unsigned long long w2v_fast_sentence_cbow_neg(
const int negative, np.uint32_t *cum_table, unsigned long long cum_table_len, int codelens[MAX_SENTENCE_LEN],
REAL_t *neu1, REAL_t *syn0, REAL_t *syn1neg, const int size,
const np.uint32_t indexes[MAX_SENTENCE_LEN], const REAL_t alpha, REAL_t *work,
int i, int j, int k, int cbow_mean, unsigned long long next_random, REAL_t *words_lockf,
const np.uint32_t lockf_len, const int _compute_loss, REAL_t *_running_training_loss_param) nogil
const np.uint32_t lockf_len, np.float64_t *minibatch_loss_ptr) nogil


cdef init_w2v_config(Word2VecConfig *c, model, alpha, compute_loss, _work, _neu1=*)
cdef init_w2v_config(Word2VecConfig *c, model, alpha, _work, _neu1=*)
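Taken together with the word2vec_corpusfile.pyx hunks above, the split named in the commit title works on two levels: each Cython training call tallies into the double-precision minibatch_loss field of Word2VecConfig, and the Python layer accumulates that into model.epoch_loss (+=) instead of overwriting a single running_training_loss (=) as before. A minimal pure-Python sketch of that aggregation pattern, with hypothetical helper names:

def example_loss(example):
    # Hypothetical stand-in for the per-example loss the Cython
    # kernels (sg/cbow, hs/negative-sampling) would tally.
    return 0.1

def train_one_call(minibatch):
    # Mirrors the Cython side: a local, wide-precision tally per call.
    minibatch_loss = 0.0
    for example in minibatch:
        minibatch_loss += example_loss(example)
    return minibatch_loss

epoch_loss = 0.0  # mirrors model.epoch_loss, reset at the start of train()
for minibatch in ([1, 2, 3], [4, 5]):
    # Accumulate (+=) rather than overwrite (=), as in the new code above.
    epoch_loss += train_one_call(minibatch)
print(epoch_loss)  # 0.5 (up to float rounding)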
(Diffs for the remaining 2 changed files not shown.)