Remove unused variables in the extract-TTS-spectrograms script

This commit is contained in:
Edresson 2021-05-04 19:04:13 -03:00
parent 3ecd556bbe
commit 501c8e0302
1 changed file with 7 additions and 20 deletions

View File

@ -124,7 +124,7 @@ def format_data(data):
) )
@torch.no_grad() @torch.no_grad()
def inference(model_name, model, config, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask=None, speaker_ids=None, speaker_embeddings=None): def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask=None, speaker_ids=None, speaker_embeddings=None):
if model_name == "glow_tts": if model_name == "glow_tts":
mel_input = mel_input.permute(0, 2, 1) # B x D x T mel_input = mel_input.permute(0, 2, 1) # B x D x T
speaker_c = None speaker_c = None
@ -133,35 +133,22 @@ def inference(model_name, model, config, ap, text_input, text_lengths, mel_input
elif speaker_embeddings is not None: elif speaker_embeddings is not None:
speaker_c = speaker_embeddings speaker_c = speaker_embeddings
model_output, _, _, _, _, _, _ = model.inference_with_MAS( model_output, *_ = model.inference_with_MAS(
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
) )
model_output = model_output.transpose(1, 2).detach().cpu().numpy() model_output = model_output.transpose(1, 2).detach().cpu().numpy()
elif "tacotron" in model_name: elif "tacotron" in model_name:
if config.bidirectional_decoder or config.double_decoder_consistency: _, postnet_outputs, *_ = model(
( text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
_,
postnet_outputs,
_,
_,
_,
_,
) = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
)
else:
_, postnet_outputs, _, _ = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
)
# normalize tacotron output # normalize tacotron output
if model_name == "tacotron": if model_name == "tacotron":
mel_specs = [] mel_specs = []
postnet_outputs = postnet_outputs.data.cpu().numpy() postnet_outputs = postnet_outputs.data.cpu().numpy()
for b in range(postnet_outputs.shape[0]): for b in range(postnet_outputs.shape[0]):
postnet_output = postnet_outputs[b] postnet_output = postnet_outputs[b]
mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda()) mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T))
model_output = torch.stack(mel_specs) model_output = torch.stack(mel_specs).cpu().numpy()
elif model_name == "tacotron2": elif model_name == "tacotron2":
model_output = postnet_outputs.detach().cpu().numpy() model_output = postnet_outputs.detach().cpu().numpy()
@ -186,7 +173,7 @@ def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=Fals
item_idx, item_idx,
) = format_data(data) ) = format_data(data)
model_output = inference(c.model.lower(), model, c, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask, speaker_ids, speaker_embeddings) model_output = inference(c.model.lower(), model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask, speaker_ids, speaker_embeddings)
for idx in range(text_input.shape[0]): for idx in range(text_input.shape[0]):
wav_file_path = item_idx[idx] wav_file_path = item_idx[idx]