diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
index 5e230bbd..d5c23ccd 100755
--- a/TTS/bin/extract_tts_spectrograms.py
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -124,7 +124,7 @@ def format_data(data):
     )
 
 @torch.no_grad()
-def inference(model_name, model, config, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask=None, speaker_ids=None, speaker_embeddings=None):
+def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask=None, speaker_ids=None, speaker_embeddings=None):
     if model_name == "glow_tts":
         mel_input = mel_input.permute(0, 2, 1)  # B x D x T
         speaker_c = None
@@ -133,32 +133,19 @@ def inference(model_name, model, config, ap, text_input, text_lengths, mel_input
         elif speaker_embeddings is not None:
             speaker_c = speaker_embeddings
-        model_output, _, _, _, _, _, _ = model.inference_with_MAS(
+        model_output, *_ = model.inference_with_MAS(
             text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
         )
         model_output = model_output.transpose(1, 2).detach().cpu().numpy()
     elif "tacotron" in model_name:
-        if config.bidirectional_decoder or config.double_decoder_consistency:
-            (
-                _,
-                postnet_outputs,
-                _,
-                _,
-                _,
-                _,
-            ) = model(
-                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
-            )
-        else:
-            _, postnet_outputs, _, _ = model(
-                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
-            )
+        _, postnet_outputs, *_ = model(
+            text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
         # normalize tacotron output
         if model_name == "tacotron":
             mel_specs = []
             postnet_outputs = postnet_outputs.data.cpu().numpy()
             for b in range(postnet_outputs.shape[0]):
                 postnet_output = postnet_outputs[b]
-                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())
-            model_output = torch.stack(mel_specs)
+                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T))
+            model_output = torch.stack(mel_specs).cpu().numpy()
         elif model_name == "tacotron2":
             model_output = postnet_outputs.detach().cpu().numpy()
@@ -187,5 +174,5 @@ def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=Fals
                 item_idx,
             ) = format_data(data)
-            model_output = inference(c.model.lower(), model, c, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask, speaker_ids, speaker_embeddings)
+            model_output = inference(c.model.lower(), model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask, speaker_ids, speaker_embeddings)
             for idx in range(text_input.shape[0]):
                 wav_file_path = item_idx[idx]