From 13c6665c92668e54d6ad10cda3e87e61fb8def39 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 28 Dec 2020 13:53:31 +0100 Subject: [PATCH] inference for SS --- TTS/tts/utils/synthesis.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index d75ac64b..7e71df64 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -71,6 +71,17 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel # these only belong to tacotron models. decoder_output = None stop_tokens = None + elif 'speedy_speech' in CONFIG.model.lower(): + inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable + if hasattr(model, 'module'): + # distributed model + postnet_output, alignments= model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings) + else: + postnet_output, alignments= model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings) + postnet_output = postnet_output.permute(0, 2, 1) + # these only belong to tacotron models. + decoder_output = None + stop_tokens = None return decoder_output, postnet_output, alignments, stop_tokens