From c392fa4288db7975b998eaf1134b2b624b205e66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 27 May 2021 15:18:36 +0200 Subject: [PATCH] update `extract_tts_spectrograms` for the new model API --- TTS/bin/extract_tts_spectrograms.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 934055e4..e162bf4f 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -146,20 +146,22 @@ def inference( elif speaker_embeddings is not None: speaker_c = speaker_embeddings - model_output, *_ = model.inference_with_MAS( + outputs = model.inference_with_MAS( text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c ) + model_output = outputs['model_outputs'] model_output = model_output.transpose(1, 2).detach().cpu().numpy() elif "tacotron" in model_name: - _, postnet_outputs, *_ = model( + cond_input = {'speaker_ids': speaker_ids, 'x_vectors': speaker_embeddings} + outputs = model( text_input, text_lengths, mel_input, mel_lengths, - speaker_ids=speaker_ids, - speaker_embeddings=speaker_embeddings, + cond_input ) + postnet_outputs = outputs['model_outputs'] # normalize tacotron output if model_name == "tacotron": mel_specs = []