diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py index f56dfb5e..60be8fa0 100644 --- a/TTS/bin/train_glow_tts.py +++ b/TTS/bin/train_glow_tts.py @@ -280,7 +280,12 @@ def train(data_loader, model, criterion, optimizer, scheduler, # Diagnostic visualizations # direct pass on model for spec predictions target_speaker = None if speaker_c is None else speaker_c[:1] - spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker) + + if hasattr(model, 'module'): + spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker) + else: + spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker) + spec_pred = spec_pred.permute(0, 2, 1) gt_spec = mel_input.permute(0, 2, 1) const_spec = spec_pred[0].data.cpu().numpy()