mirror of https://github.com/coqui-ai/TTS.git
create inference function
This commit is contained in:
parent 20e42a3381
commit 446b1da936
@@ -124,6 +124,49 @@ def format_data(data):
     )

 @torch.no_grad()
+def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask=None, speaker_ids=None, speaker_embeddings=None):
+    if model_name == "glow_tts":
+        mel_input = mel_input.permute(0, 2, 1)  # B x D x T
+        speaker_c = None
+        if speaker_ids is not None:
+            speaker_c = speaker_ids
+        elif speaker_embeddings is not None:
+            speaker_c = speaker_embeddings
+
+        model_output, _, _, _, _, _, _ = model.inference_with_MAS(
+            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
+        )
+        model_output = model_output.transpose(1, 2).detach().cpu().numpy()
+
+    elif "tacotron" in model_name:
+        if c.bidirectional_decoder or c.double_decoder_consistency:
+            (
+                _,
+                postnet_outputs,
+                _,
+                _,
+                _,
+                _,
+            ) = model(
+                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
+            )
+        else:
+            _, postnet_outputs, _, _ = model(
+                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
+            )
+        # normalize tacotron output
+        if model_name == "tacotron":
+            mel_specs = []
+            postnet_outputs = postnet_outputs.data.cpu().numpy()
+            for b in range(postnet_outputs.shape[0]):
+                postnet_output = postnet_outputs[b]
+                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())
+            model_output = torch.stack(mel_specs)
+
+        elif model_name == "tacotron2":
+            model_output = postnet_outputs.detach().cpu().numpy()
+    return model_output
+
 def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"):
     model.eval()
     export_metadata = []
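The hunk above adds a single `inference()` helper that dispatches on the model name: glow_tts goes through `model.inference_with_MAS()`, while tacotron/tacotron2 run the regular forward pass and, for tacotron only, convert the linear postnet output to a mel spectrogram with `ap.out_linear_to_mel()`. As a rough illustration of how the helper is meant to be called, here is a minimal sketch; the dummy tensor shapes, the CUDA placement, and the `attn_mask=None` choice are assumptions made for this example only. In the script itself the real inputs come from `format_data(data)`, and `model`, `ap`, and the global config `c` are set up elsewhere.

# Hypothetical smoke test for the new inference() helper (illustration only, not part of this commit).
# Shapes are assumptions: batch of 2, 50 token IDs, 80 mel channels, 120 frames.
import torch

text_input = torch.randint(0, 100, (2, 50)).cuda()   # dummy character/phoneme IDs
text_lengths = torch.tensor([50, 42]).cuda()
mel_input = torch.randn(2, 120, 80).cuda()            # B x T x D, as produced by format_data()
mel_lengths = torch.tensor([120, 96]).cuda()

# model, ap and the config object c are assumed to be initialised elsewhere in the script
model_output = inference(
    c.model.lower(), model, ap,
    text_input, text_lengths, mel_input, mel_lengths,
    attn_mask=None, speaker_ids=None, speaker_embeddings=None,
)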
@@ -143,46 +186,7 @@ def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=Fals
             item_idx,
         ) = format_data(data)

-        if c.model.lower() == "glow_tts":
-            mel_input = mel_input.permute(0, 2, 1)  # B x D x T
-            speaker_c = None
-            if speaker_ids is not None:
-                speaker_c = speaker_ids
-            elif speaker_embeddings is not None:
-                speaker_c = speaker_embeddings
-
-            model_output, _, _, _, _, _, _ = model.inference_with_MAS(
-                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
-            )
-            model_output = model_output.transpose(1, 2).detach().cpu().numpy()
-
-        elif "tacotron" in c.model.lower():
-            if c.bidirectional_decoder or c.double_decoder_consistency:
-                (
-                    _,
-                    postnet_outputs,
-                    _,
-                    _,
-                    _,
-                    _,
-                ) = model(
-                    text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
-                )
-            else:
-                _, postnet_outputs, _, _ = model(
-                    text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
-                )
-            # normalize tacotron output
-            if c.model.lower() == "tacotron":
-                mel_specs = []
-                postnet_outputs = postnet_outputs.data.cpu().numpy()
-                for b in range(postnet_outputs.shape[0]):
-                    postnet_output = postnet_outputs[b]
-                    mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())
-                model_output = torch.stack(mel_specs)
-
-            elif c.model.lower() == "tacotron2":
-                model_output = postnet_outputs.detach().cpu().numpy()
+        model_output = inference(c.model.lower(), model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask, speaker_ids, speaker_embeddings)

         for idx in range(text_input.shape[0]):
             wav_file_path = item_idx[idx]
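Net effect of the second hunk: the inline glow_tts / tacotron branches inside the extraction loop are collapsed into the single call to `inference()`, so the per-model logic now lives in one place and the loop body only iterates the batch (`for idx in range(text_input.shape[0]):`) and handles the per-item bookkeeping starting from `wav_file_path = item_idx[idx]`.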