From 482e725752da48d8d3757fb0ba9c4670c122c7f4 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 7 Dec 2020 11:30:19 +0100 Subject: [PATCH] sync torch calls before logging training results --- TTS/bin/train_glow_tts.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py index 45e34a5e..70d0506a 100644 --- a/TTS/bin/train_glow_tts.py +++ b/TTS/bin/train_glow_tts.py @@ -186,7 +186,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids) + text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c) # compute loss loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, @@ -273,6 +273,9 @@ def train(data_loader, model, criterion, optimizer, scheduler, save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_loss=loss_dict['loss']) + # wait all kernels to be completed + torch.cuda.synchronize() + # Diagnostic visualizations # direct pass on model for spec predictions target_speaker = None if speaker_c is None else speaker_c[:1]