-
Notifications
You must be signed in to change notification settings - Fork 190
Stop criteria #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Stop criteria #79
Changes from 6 commits
a23650c
84bc0f8
e2b4b3b
4bb83d8
42acc60
114adfd
d8b56b0
72f4f7f
d4d9630
9e22191
9fc5caf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -211,12 +211,12 @@ def train(self): | |
| for i in xrange(len(train_bucket_sizes))] | ||
|
|
||
| # This is the training loop. | ||
| step_time, train_loss = 0.0, 0.0 | ||
| current_step, num_iter_wo_improve = 0, 0 | ||
| prev_train_losses, prev_valid_losses = [], [] | ||
| num_iter_cover_train = int(sum(train_bucket_sizes) / | ||
| self.params.batch_size / | ||
| self.params.steps_per_checkpoint) | ||
| step_time, train_loss, window_scale = 0.0, 0.0, 1.5 | ||
| current_step, iter_idx, num_epochs_last_impr, max_num_epochs = 0, 0, 0, 2 | ||
| prev_train_losses, prev_valid_losses, prev_epoch_valid_losses = [], [], [] | ||
| iter_per_epoch = max(1, int(sum(train_bucket_sizes) / | ||
| self.params.batch_size / | ||
| self.params.steps_per_checkpoint)) | ||
| while (self.params.max_steps == 0 | ||
| or self.model.global_step.eval(self.session) | ||
| <= self.params.max_steps): | ||
|
|
@@ -232,45 +232,62 @@ def train(self): | |
| # Print statistics for the previous steps. | ||
| train_ppx = math.exp(train_loss) if train_loss < 300 else float('inf') | ||
| print ("global step %d learning rate %.4f step-time %.2f perplexity " | ||
| "%.2f" % (self.model.global_step.eval(self.session), | ||
| "%.3f" % (self.model.global_step.eval(self.session), | ||
| self.model.learning_rate.eval(self.session), | ||
| step_time, train_ppx)) | ||
| eval_loss = self.__calc_eval_loss() | ||
| eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf') | ||
| print(" eval: perplexity %.2f" % (eval_ppx)) | ||
| print(" eval: perplexity %.3f" % (eval_ppx)) | ||
| # Decrease learning rate if no improvement was seen on train set | ||
| # over last 3 times. | ||
| if (len(prev_train_losses) > 2 | ||
| and train_loss > max(prev_train_losses[-3:])): | ||
| self.session.run(self.model.learning_rate_decay_op) | ||
|
|
||
| if (len(prev_valid_losses) > 0 | ||
| and eval_loss <= min(prev_valid_losses)): | ||
| # Save checkpoint and zero timer and loss. | ||
| self.model.saver.save(self.session, | ||
| os.path.join(self.model_dir, "model"), | ||
| write_meta_graph=False) | ||
|
|
||
| if (len(prev_valid_losses) > 0 | ||
| and eval_loss >= min(prev_valid_losses)): | ||
| num_iter_wo_improve += 1 | ||
| else: | ||
| num_iter_wo_improve = 0 | ||
|
|
||
| if num_iter_wo_improve > num_iter_cover_train * 2: | ||
| print("No improvement over last %d times. Training will stop after %d" | ||
| "iterations if no improvement was seen." | ||
| % (num_iter_wo_improve, | ||
| num_iter_cover_train - num_iter_wo_improve)) | ||
|
|
||
| # Stop train if no improvement was seen on validation set | ||
| # over last 3 epochs. | ||
| if num_iter_wo_improve > num_iter_cover_train * 3: | ||
| break | ||
| #if (len(prev_valid_losses) > 0 | ||
| # and eval_loss <= min(prev_valid_losses)): | ||
| # Save checkpoint and zero timer and loss. | ||
| self.model.saver.save(self.session, | ||
| os.path.join(self.model_dir, "model"), | ||
| write_meta_graph=False) | ||
|
|
||
| # After epoch pass, calculate average epoch loss | ||
| # and then make a decision to continue/stop training. | ||
| if (iter_idx > 0 | ||
| and iter_idx % iter_per_epoch == 0): | ||
| # Calculate average validation loss during the previous epoch | ||
| epoch_eval_loss = self.__calc_epoch_loss( | ||
| prev_valid_losses[-iter_per_epoch:]) | ||
| if len(prev_epoch_valid_losses) > 0: | ||
| print('Prev min epoch eval loss: %f, curr epoch eval loss: %f' % | ||
| (min(prev_epoch_valid_losses), epoch_eval_loss)) | ||
| # Check if there was an improvement during last epoch | ||
| if (epoch_eval_loss < min(prev_epoch_valid_losses)): | ||
| if num_epochs_last_impr > max_num_epochs/window_scale: | ||
| max_num_epochs = int(window_scale * num_epochs_last_impr) | ||
| print('Improved during last epoch.') | ||
| prev_min_level = prev_epoch_valid_losses[-1] | ||
| num_epochs_last_impr = 0 | ||
| else: | ||
| print('No improvement during last epoch.') | ||
| num_epochs_last_impr += 1 | ||
|
|
||
| print('Number of the epochs passed from the last improvement: %d' | ||
| % num_epochs_last_impr) | ||
| print('Max allowable number of epochs for improvement: %d' | ||
| % max_num_epochs) | ||
|
|
||
| # Stop training if no improvement was seen during last | ||
| # max_num_epochs epochs | ||
| if num_epochs_last_impr > max_num_epochs: | ||
| break | ||
|
|
||
| prev_epoch_valid_losses.append(round(epoch_eval_loss, 3)) | ||
|
|
||
| prev_train_losses.append(train_loss) | ||
| prev_valid_losses.append(eval_loss) | ||
| step_time, train_loss = 0.0, 0.0 | ||
| iter_idx += 1 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use current step instead of iter_idx |
||
|
|
||
| print('Training done.') | ||
| with tf.Graph().as_default(): | ||
|
|
@@ -299,23 +316,43 @@ def __calc_step_loss(self, train_buckets_scale): | |
| def __calc_eval_loss(self): | ||
| """Run evals on development set and print their perplexity. | ||
| """ | ||
| eval_loss, num_iter_total = 0.0, 0.0 | ||
| eval_loss, iter_total = 0.0, 0.0 | ||
| for bucket_id in xrange(len(self._BUCKETS)): | ||
| num_iter_cover_valid = int(math.ceil(len(self.valid_set[bucket_id])/ | ||
| iter_per_valid = int(math.ceil(len(self.valid_set[bucket_id])/ | ||
| self.params.batch_size)) | ||
| num_iter_total += num_iter_cover_valid | ||
| for batch_id in xrange(num_iter_cover_valid): | ||
| iter_total += iter_per_valid | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Count iter_total in inner loop |
||
| for batch_id in xrange(iter_per_valid): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use xrange with batch_size step. |
||
| encoder_inputs, decoder_inputs, target_weights =\ | ||
| self.model.get_eval_set_batch(self.valid_set, bucket_id, | ||
| batch_id * self.params.batch_size) | ||
| _, eval_batch_loss, _ = self.model.step(self.session, encoder_inputs, | ||
| decoder_inputs, target_weights, | ||
| bucket_id, True) | ||
| eval_loss += eval_batch_loss | ||
| eval_loss = eval_loss/num_iter_total if num_iter_total > 0 else float('inf') | ||
| eval_loss = eval_loss/iter_total if iter_total > 0 else float('inf') | ||
| return eval_loss | ||
|
|
||
|
|
||
| def __calc_epoch_loss(self, epoch_losses, allow_excess_min=1.5): | ||
| """Calculate an average loss without outliers during the epoch. | ||
|
|
||
| Args: | ||
| epoch_losses: list of the losses during the epoch; | ||
|
|
||
| Returns: | ||
| the average value of the losses without outliers during the period; | ||
| """ | ||
| epoch_loss_sum, loss_num = 0, 0 | ||
| for loss in epoch_losses: | ||
| if loss < min(epoch_losses) * allow_excess_min: | ||
| epoch_loss_sum += loss | ||
| loss_num += 1 | ||
| if loss_num > 0: | ||
| return epoch_loss_sum / loss_num | ||
| else: | ||
| return float(inf) | ||
|
|
||
|
|
||
| def decode_word(self, word): | ||
| """Decode input word to sequence of phonemes. | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
epochs_without_improvement