mirror of https://github.com/coqui-ai/TTS.git
sync torch calls before logging training results
This commit is contained in:
parent
7505c0ba27
commit
482e725752
|
@ -186,7 +186,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
|
||||||
# forward pass model
|
# forward pass model
|
||||||
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
|
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
|
||||||
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
||||||
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
|
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
|
||||||
|
|
||||||
# compute loss
|
# compute loss
|
||||||
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
|
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
|
||||||
|
@ -273,6 +273,9 @@ def train(data_loader, model, criterion, optimizer, scheduler,
|
||||||
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
|
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
|
||||||
model_loss=loss_dict['loss'])
|
model_loss=loss_dict['loss'])
|
||||||
|
|
||||||
|
# wait all kernels to be completed
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
# direct pass on model for spec predictions
|
# direct pass on model for spec predictions
|
||||||
target_speaker = None if speaker_c is None else speaker_c[:1]
|
target_speaker = None if speaker_c is None else speaker_c[:1]
|
||||||
|
|
Loading…
Reference in New Issue