From d5c6d608848fab1643605a635cac28ab0cd930d7 Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 4 Aug 2020 22:22:35 +0200 Subject: [PATCH] synthesis update for glow tts --- TTS/tts/utils/synthesis.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 76ac7909..85eeec66 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -46,16 +46,24 @@ def compute_style_mel(style_wav, ap, cuda=False): def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None): - if CONFIG.use_gst: - decoder_output, postnet_output, alignments, stop_tokens = model.inference( - inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) - else: - if truncated: - decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( - inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) - else: + if 'tacotron' in CONFIG.model.lower(): + if CONFIG.use_gst: decoder_output, postnet_output, alignments, stop_tokens = model.inference( - inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) + inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) + else: + if truncated: + decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( + inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) + else: + decoder_output, postnet_output, alignments, stop_tokens = model.inference( + inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) + elif 'glow' in CONFIG.model.lower(): + inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) + postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths) + postnet_output = postnet_output.permute(0, 2, 1) + # these only belong to tacotron models. + decoder_output = None + stop_tokens = None return decoder_output, postnet_output, alignments, stop_tokens