diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 408f334e..deac7fc5 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -146,20 +146,22 @@ def inference( elif speaker_embeddings is not None: speaker_c = speaker_embeddings - model_output, *_ = model.inference_with_MAS( + outputs = model.inference_with_MAS( text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c ) + model_output = outputs['model_outputs'] model_output = model_output.transpose(1, 2).detach().cpu().numpy() elif "tacotron" in model_name: - _, postnet_outputs, *_ = model( + cond_input = {'speaker_ids': speaker_ids, 'x_vectors': speaker_embeddings} + outputs = model( text_input, text_lengths, mel_input, mel_lengths, - speaker_ids=speaker_ids, - speaker_embeddings=speaker_embeddings, + cond_input ) + postnet_outputs = outputs['model_outputs'] # normalize tacotron output if model_name == "tacotron": mel_specs = []