glow-tts distributed fix

This commit is contained in:
erogol 2020-12-09 23:39:09 +01:00
parent 62bc171db5
commit 53679b706d
1 changed files with 6 additions and 1 deletions

View File

@ -280,7 +280,12 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# Diagnostic visualizations # Diagnostic visualizations
# direct pass on model for spec predictions # direct pass on model for spec predictions
target_speaker = None if speaker_c is None else speaker_c[:1] target_speaker = None if speaker_c is None else speaker_c[:1]
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
if hasattr(model, 'module'):
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
else:
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
spec_pred = spec_pred.permute(0, 2, 1) spec_pred = spec_pred.permute(0, 2, 1)
gt_spec = mel_input.permute(0, 2, 1) gt_spec = mel_input.permute(0, 2, 1)
const_spec = spec_pred[0].data.cpu().numpy() const_spec = spec_pred[0].data.cpu().numpy()