mirror of https://github.com/coqui-ai/TTS.git
change the computation of the global step
parent 713b3df792
commit bea9701d93
train.py (53 changed lines)
@@ -82,7 +82,7 @@ def setup_loader(ap, is_val=False, verbose=False):
 
 
 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
-          ap, epoch):
+          ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     if c.use_speaker_embedding:
         speaker_mapping = load_speaker_mapping(OUT_PATH)
@@ -123,8 +123,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
                                          stop_targets.size(1) // c.r, -1)
         stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
 
-        current_step = num_iter + args.restore_step + \
-            epoch * len(data_loader) + 1
+        global_step += 1
 
         # setup lr
         if c.lr_decay:
@@ -183,13 +182,13 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         step_time = time.time() - start_time
         epoch_time += step_time
 
-        if current_step % c.print_step == 0:
+        if global_step % c.print_step == 0:
             print(
                 " | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} PostnetLoss:{:.5f} "
                 "DecoderLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
                 "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} "
                 "LoaderTime:{:.2f} LR:{:.6f}".format(
-                    num_iter, batch_n_iter, current_step, loss.item(),
+                    num_iter, batch_n_iter, global_step, loss.item(),
                     postnet_loss.item(), decoder_loss.item(), stop_loss.item(),
                     grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time,
                     loader_time, current_lr),
@@ -216,13 +215,13 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
                           "grad_norm": grad_norm,
                           "grad_norm_st": grad_norm_st,
                           "step_time": step_time}
-            tb_logger.tb_train_iter_stats(current_step, iter_stats)
+            tb_logger.tb_train_iter_stats(global_step, iter_stats)
 
-        if current_step % c.save_step == 0:
+        if global_step % c.save_step == 0:
             if c.checkpoint:
                 # save model
                 save_checkpoint(model, optimizer, optimizer_st,
-                                postnet_loss.item(), OUT_PATH, current_step,
+                                postnet_loss.item(), OUT_PATH, global_step,
                                 epoch)
 
             # Diagnostic visualizations
@@ -235,14 +234,14 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
                        "ground_truth": plot_spectrogram(gt_spec, ap),
                        "alignment": plot_alignment(align_img)
                        }
-            tb_logger.tb_train_figures(current_step, figures)
+            tb_logger.tb_train_figures(global_step, figures)
 
             # Sample audio
             if c.model in ["Tacotron", "TacotronGST"]:
                 train_audio = ap.inv_spectrogram(const_spec.T)
             else:
                 train_audio = ap.inv_mel_spectrogram(const_spec.T)
-            tb_logger.tb_train_audios(current_step,
+            tb_logger.tb_train_audios(global_step,
                                       {'TrainAudio': train_audio},
                                       c.audio["sample_rate"])
         end_time = time.time()
@@ -259,7 +258,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         " | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
         "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
         "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
-        "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
+        "AvgStepTime:{:.2f}".format(global_step, avg_total_loss,
                                     avg_postnet_loss, avg_decoder_loss,
                                     avg_stop_loss, epoch_time, avg_step_time,
                                     avg_loader_time),
@@ -272,13 +271,13 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
                       "loss_decoder": avg_decoder_loss,
                       "stop_loss": avg_stop_loss,
                       "epoch_time": epoch_time}
-        tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
+        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
         if c.tb_model_param_stats:
-            tb_logger.tb_model_weights(model, current_step)
-    return avg_postnet_loss, current_step
+            tb_logger.tb_model_weights(model, global_step)
+    return avg_postnet_loss, global_step
 
 
-def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
+def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=True)
     if c.use_speaker_embedding:
         speaker_mapping = load_speaker_mapping(OUT_PATH)
@@ -391,14 +390,14 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
                            "ground_truth": plot_spectrogram(gt_spec, ap),
                            "alignment": plot_alignment(align_img)
                            }
-            tb_logger.tb_eval_figures(current_step, eval_figures)
+            tb_logger.tb_eval_figures(global_step, eval_figures)
 
             # Sample audio
             if c.model in ["Tacotron", "TacotronGST"]:
                 eval_audio = ap.inv_spectrogram(const_spec.T)
             else:
                 eval_audio = ap.inv_mel_spectrogram(const_spec.T)
-            tb_logger.tb_eval_audios(current_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
+            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
 
         # compute average losses
         avg_postnet_loss /= (num_iter + 1)
@@ -409,7 +408,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
         epoch_stats = {"loss_postnet": avg_postnet_loss,
                        "loss_decoder": avg_decoder_loss,
                        "stop_loss": avg_stop_loss}
-        tb_logger.tb_eval_stats(current_step, epoch_stats)
+        tb_logger.tb_eval_stats(global_step, epoch_stats)
 
     if args.rank == 0 and epoch > c.test_delay_epochs:
         # test sentences
@@ -422,7 +421,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
                 wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
                     model, test_sentence, c, use_cuda, ap,
                     speaker_id=speaker_id)
-                file_path = os.path.join(AUDIO_PATH, str(current_step))
+                file_path = os.path.join(AUDIO_PATH, str(global_step))
                 os.makedirs(file_path, exist_ok=True)
                 file_path = os.path.join(file_path,
                                          "TestSentence_{}.wav".format(idx))
@@ -433,8 +432,8 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
            except:
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
-        tb_logger.tb_test_audios(current_step, test_audios, c.audio['sample_rate'])
-        tb_logger.tb_test_figures(current_step, test_figures)
+        tb_logger.tb_test_audios(global_step, test_audios, c.audio['sample_rate'])
+        tb_logger.tb_test_figures(global_step, test_figures)
     return avg_postnet_loss
 
 
@@ -532,19 +531,19 @@ def main(args): #pylint: disable=redefined-outer-name
     if 'best_loss' not in locals():
         best_loss = float('inf')
 
-    current_step = 0
+    global_step = args.restore_step
     for epoch in range(0, c.epochs):
         # set gradual training
         if c.gradual_training is not None:
-            r, c.batch_size = gradual_training_scheduler(current_step, c)
+            r, c.batch_size = gradual_training_scheduler(global_step, c)
             c.r = r
             model.decoder._set_r(r)
             print(" > Number of outputs per iteration:", model.decoder.r)
 
-        train_loss, current_step = train(model, criterion, criterion_st,
+        train_loss, global_step = train(model, criterion, criterion_st,
                                          optimizer, optimizer_st, scheduler,
-                                         ap, epoch)
-        val_loss = evaluate(model, criterion, criterion_st, ap, current_step, epoch)
+                                         ap, global_step, epoch)
+        val_loss = evaluate(model, criterion, criterion_st, ap, global_step, epoch)
         print(
             " | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(
                 train_loss, val_loss),
@@ -553,7 +552,7 @@ def main(args): #pylint: disable=redefined-outer-name
         if c.run_eval:
            target_loss = val_loss
         best_loss = save_best_model(model, optimizer, target_loss, best_loss,
-                                    OUT_PATH, current_step, epoch)
+                                    OUT_PATH, global_step, epoch)
 
 
 if __name__ == '__main__':
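Taken together, the hunks replace the per-iteration recomputation current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1 with one explicit counter: main() seeds global_step from args.restore_step, train() increments it once per batch and returns it, and evaluate(), the TensorBoard loggers and the checkpoint/audio paths all receive that same value. The sketch below is a minimal, self-contained illustration of this pattern, not code from train.py; names such as restore_step, num_epochs and batches_per_epoch are assumptions made for the example.

# Minimal sketch of the step-counting pattern adopted in this commit:
# a single global_step counter is seeded from the restore step, incremented
# once per batch inside the training function, returned to the caller, and
# passed on to evaluation/logging under the same value.

def train_one_epoch(global_step, batches_per_epoch, print_step=10):
    """Run one dummy epoch and return the updated global step."""
    for num_iter in range(batches_per_epoch):
        # replaces: current_step = num_iter + restore_step + epoch * batches_per_epoch + 1
        global_step += 1
        if global_step % print_step == 0:
            print(" | > Step:{}/{} GlobalStep:{}".format(
                num_iter, batches_per_epoch, global_step))
    return global_step


def evaluate(global_step):
    """Log evaluation results under the same step counter."""
    print(" | > eval at GlobalStep:{}".format(global_step))


def main(restore_step=0, num_epochs=3):
    global_step = restore_step  # resuming continues the count instead of restarting at 0
    for epoch in range(num_epochs):
        # The number of batches may change between epochs (e.g. gradual training
        # changes the batch size), so the counter is carried, not recomputed.
        batches_per_epoch = 5 + epoch
        global_step = train_one_epoch(global_step, batches_per_epoch)
        evaluate(global_step)


if __name__ == "__main__":
    main()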