diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
index 517a281e..1a77e45a 100755
--- a/TTS/bin/extract_tts_spectrograms.py
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -124,6 +124,49 @@ def format_data(data):
     )
 
 @torch.no_grad()
+def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask=None, speaker_ids=None, speaker_embeddings=None):
+    if model_name == "glow_tts":
+        mel_input = mel_input.permute(0, 2, 1)  # B x D x T
+        speaker_c = None
+        if speaker_ids is not None:
+            speaker_c = speaker_ids
+        elif speaker_embeddings is not None:
+            speaker_c = speaker_embeddings
+
+        model_output, _, _, _, _, _, _ = model.inference_with_MAS(
+            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
+        )
+        model_output = model_output.transpose(1, 2).detach().cpu().numpy()
+
+    elif "tacotron" in model_name:
+        if c.bidirectional_decoder or c.double_decoder_consistency:
+            (
+                _,
+                postnet_outputs,
+                _,
+                _,
+                _,
+                _,
+            ) = model(
+                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
+            )
+        else:
+            _, postnet_outputs, _, _ = model(
+                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
+            )
+        # normalize tacotron output
+        if model_name == "tacotron":
+            mel_specs = []
+            postnet_outputs = postnet_outputs.data.cpu().numpy()
+            for b in range(postnet_outputs.shape[0]):
+                postnet_output = postnet_outputs[b]
+                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())
+            model_output = torch.stack(mel_specs)
+
+        elif model_name == "tacotron2":
+            model_output = postnet_outputs.detach().cpu().numpy()
+    return model_output
+
 def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"):
     model.eval()
     export_metadata = []
@@ -143,46 +186,7 @@ def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=Fals
                 item_idx,
             ) = format_data(data)
 
-            if c.model.lower() == "glow_tts":
-                mel_input = mel_input.permute(0, 2, 1)  # B x D x T
-                speaker_c = None
-                if speaker_ids is not None:
-                    speaker_c = speaker_ids
-                elif speaker_embeddings is not None:
-                    speaker_c = speaker_embeddings
-
-                model_output, _, _, _, _, _, _ = model.inference_with_MAS(
-                    text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
-                )
-                model_output = model_output.transpose(1, 2).detach().cpu().numpy()
-
-            elif "tacotron" in c.model.lower():
-                if c.bidirectional_decoder or c.double_decoder_consistency:
-                    (
-                        _,
-                        postnet_outputs,
-                        _,
-                        _,
-                        _,
-                        _,
-                    ) = model(
-                        text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
-                    )
-                else:
-                    _, postnet_outputs, _, _ = model(
-                        text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
-                    )
-                # normalize tacotron output
-                if c.model.lower() == "tacotron":
-                    mel_specs = []
-                    postnet_outputs = postnet_outputs.data.cpu().numpy()
-                    for b in range(postnet_outputs.shape[0]):
-                        postnet_output = postnet_outputs[b]
-                        mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())
-                    model_output = torch.stack(mel_specs)
-
-                elif c.model.lower() == "tacotron2":
-                    model_output = postnet_outputs.detach().cpu().numpy()
+            model_output = inference(c.model.lower(), model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask, speaker_ids, speaker_embeddings)
 
             for idx in range(text_input.shape[0]):
                 wav_file_path = item_idx[idx]
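
Not part of the patch: a minimal sketch of how the extracted helper is driven per batch inside extract_spectrograms() after this change. The shape comments are assumptions read off the surrounding code rather than documented contracts; `c`, `model`, `ap`, and the batch tensors are the objects the script already has in scope from format_data().

    # Usage sketch (mirrors the call added in the second hunk).
    model_output = inference(
        c.model.lower(),     # "glow_tts", "tacotron", or "tacotron2"
        model,               # the loaded TTS model
        ap,                  # AudioProcessor; used only by the "tacotron" branch
        text_input,          # padded token ids, shape [B, T_text] (assumed)
        text_lengths,        # true lengths before padding, shape [B]
        mel_input,           # target mel frames, shape [B, T_mel, D] (assumed)
        mel_lengths,         # shape [B]
        attn_mask,           # consumed only by the glow_tts branch
        speaker_ids,         # optional multi-speaker conditioning
        speaker_embeddings,
    )

One behavioral wrinkle for reviewers: the glow_tts and tacotron2 branches return a CPU numpy array, while the tacotron branch returns a stacked CUDA FloatTensor, so downstream code must accept both types. This matches the pre-refactor behavior exactly; the patch only moves the code.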