Skip to content
109 changes: 73 additions & 36 deletions g2p_seq2seq/g2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,12 @@ def train(self):
for i in xrange(len(train_bucket_sizes))]

# This is the training loop.
step_time, train_loss = 0.0, 0.0
current_step, num_iter_wo_improve = 0, 0
prev_train_losses, prev_valid_losses = [], []
num_iter_cover_train = int(sum(train_bucket_sizes) /
self.params.batch_size /
self.params.steps_per_checkpoint)
step_time, train_loss, window_scale = 0.0, 0.0, 1.5
current_step, iter_idx, num_epochs_last_impr, max_num_epochs = 0, 0, 0, 2
prev_train_losses, prev_valid_losses, prev_epoch_valid_losses = [], [], []
iter_per_epoch = max(1, int(sum(train_bucket_sizes) /
self.params.batch_size /
self.params.steps_per_checkpoint))
while (self.params.max_steps == 0
or self.model.global_step.eval(self.session)
<= self.params.max_steps):
Expand All @@ -232,45 +232,62 @@ def train(self):
# Print statistics for the previous steps.
train_ppx = math.exp(train_loss) if train_loss < 300 else float('inf')
print ("global step %d learning rate %.4f step-time %.2f perplexity "
"%.2f" % (self.model.global_step.eval(self.session),
"%.3f" % (self.model.global_step.eval(self.session),
self.model.learning_rate.eval(self.session),
step_time, train_ppx))
eval_loss = self.__calc_eval_loss()
eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
print(" eval: perplexity %.2f" % (eval_ppx))
print(" eval: perplexity %.3f" % (eval_ppx))
# Decrease learning rate if no improvement was seen on train set
# over last 3 times.
if (len(prev_train_losses) > 2
and train_loss > max(prev_train_losses[-3:])):
self.session.run(self.model.learning_rate_decay_op)

if (len(prev_valid_losses) > 0
and eval_loss <= min(prev_valid_losses)):
# Save checkpoint and zero timer and loss.
self.model.saver.save(self.session,
os.path.join(self.model_dir, "model"),
write_meta_graph=False)

if (len(prev_valid_losses) > 0
and eval_loss >= min(prev_valid_losses)):
num_iter_wo_improve += 1
else:
num_iter_wo_improve = 0

if num_iter_wo_improve > num_iter_cover_train * 2:
print("No improvement over last %d times. Training will stop after %d"
"iterations if no improvement was seen."
% (num_iter_wo_improve,
num_iter_cover_train - num_iter_wo_improve))

# Stop train if no improvement was seen on validation set
# over last 3 epochs.
if num_iter_wo_improve > num_iter_cover_train * 3:
break
#if (len(prev_valid_losses) > 0
# and eval_loss <= min(prev_valid_losses)):
# Save checkpoint and zero timer and loss.
self.model.saver.save(self.session,
os.path.join(self.model_dir, "model"),
write_meta_graph=False)

# After epoch pass, calculate average epoch loss
# and then make a decision to continue/stop training.
if (iter_idx > 0
and iter_idx % iter_per_epoch == 0):
# Calculate average validation loss during the previous epoch
epoch_eval_loss = self.__calc_epoch_loss(
prev_valid_losses[-iter_per_epoch:])
if len(prev_epoch_valid_losses) > 0:
print('Prev min epoch eval loss: %f, curr epoch eval loss: %f' %
(min(prev_epoch_valid_losses), epoch_eval_loss))
# Check if there was an improvement during last epoch
if (epoch_eval_loss < min(prev_epoch_valid_losses)):
if num_epochs_last_impr > max_num_epochs/window_scale:
max_num_epochs = int(window_scale * num_epochs_last_impr)
print('Improved during last epoch.')
prev_min_level = prev_epoch_valid_losses[-1]
num_epochs_last_impr = 0
else:
print('No improvement during last epoch.')
num_epochs_last_impr += 1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

epochs_without_improvement


print('Number of the epochs passed from the last improvement: %d'
% num_epochs_last_impr)
print('Max allowable number of epochs for improvement: %d'
% max_num_epochs)

# Stop training if no improvement was seen during last
# max_num_epochs epochs
if num_epochs_last_impr > max_num_epochs:
break

prev_epoch_valid_losses.append(round(epoch_eval_loss, 3))

prev_train_losses.append(train_loss)
prev_valid_losses.append(eval_loss)
step_time, train_loss = 0.0, 0.0
iter_idx += 1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use current step instead of iter_idx


print('Training done.')
with tf.Graph().as_default():
Expand Down Expand Up @@ -299,23 +316,43 @@ def __calc_step_loss(self, train_buckets_scale):
def __calc_eval_loss(self):
"""Run evals on development set and print their perplexity.
"""
eval_loss, num_iter_total = 0.0, 0.0
eval_loss, iter_total = 0.0, 0.0
for bucket_id in xrange(len(self._BUCKETS)):
num_iter_cover_valid = int(math.ceil(len(self.valid_set[bucket_id])/
iter_per_valid = int(math.ceil(len(self.valid_set[bucket_id])/
self.params.batch_size))
num_iter_total += num_iter_cover_valid
for batch_id in xrange(num_iter_cover_valid):
iter_total += iter_per_valid

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Count iter_total in inner loop

for batch_id in xrange(iter_per_valid):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use xrange with batch_size step.

encoder_inputs, decoder_inputs, target_weights =\
self.model.get_eval_set_batch(self.valid_set, bucket_id,
batch_id * self.params.batch_size)
_, eval_batch_loss, _ = self.model.step(self.session, encoder_inputs,
decoder_inputs, target_weights,
bucket_id, True)
eval_loss += eval_batch_loss
eval_loss = eval_loss/num_iter_total if num_iter_total > 0 else float('inf')
eval_loss = eval_loss/iter_total if iter_total > 0 else float('inf')
return eval_loss


def __calc_epoch_loss(self, epoch_losses, allow_excess_min=1.5):
"""Calculate an average loss without outliers during the epoch.

Args:
epoch_losses: list of the losses during the epoch;

Returns:
the average value of the losses without outliers during the period;
"""
epoch_loss_sum, loss_num = 0, 0
for loss in epoch_losses:
if loss < min(epoch_losses) * allow_excess_min:
epoch_loss_sum += loss
loss_num += 1
if loss_num > 0:
return epoch_loss_sum / loss_num
else:
return float(inf)


def decode_word(self, word):
"""Decode input word to sequence of phonemes.

Expand Down