sync torch calls before logging training results

This commit is contained in:
erogol 2020-12-07 11:30:19 +01:00
parent 7505c0ba27
commit 482e725752
1 changed files with 4 additions and 1 deletions

View File

@ -186,7 +186,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# forward pass model # forward pass model
with torch.cuda.amp.autocast(enabled=c.mixed_precision): with torch.cuda.amp.autocast(enabled=c.mixed_precision):
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids) text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
# compute loss # compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
@ -273,6 +273,9 @@ def train(data_loader, model, criterion, optimizer, scheduler,
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
model_loss=loss_dict['loss']) model_loss=loss_dict['loss'])
# wait all kernels to be completed
torch.cuda.synchronize()
# Diagnostic visualizations # Diagnostic visualizations
# direct pass on model for spec predictions # direct pass on model for spec predictions
target_speaker = None if speaker_c is None else speaker_c[:1] target_speaker = None if speaker_c is None else speaker_c[:1]