fix `Synthesized` for the new `synthesis()`

This commit is contained in:
Eren Gölge 2021-05-27 11:39:34 +02:00
parent 73bf9673ed
commit c680a07a20
2 changed files with 4 additions and 38 deletions

View File

@@ -78,42 +78,6 @@ def run_model_torch(model,
'x_vector': x_vector,
'style_mel': style_mel
})
# elif "glow" in CONFIG.model.lower():
# inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
# if hasattr(model, "module"):
# # distributed model
# postnet_output, _, _, _, alignments, _, _ = model.module.inference(
# inputs,
# inputs_lengths,
# g=speaker_id if speaker_id is not None else speaker_embeddings)
# else:
# postnet_output, _, _, _, alignments, _, _ = model.inference(
# inputs,
# inputs_lengths,
# g=speaker_id if speaker_id is not None else speaker_embeddings)
# postnet_output = postnet_output.permute(0, 2, 1)
# # these only belong to tacotron models.
# decoder_output = None
# stop_tokens = None
# elif CONFIG.model.lower() in ["speedy_speech", "align_tts"]:
# inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
# if hasattr(model, "module"):
# # distributed model
# postnet_output, alignments = model.module.inference(
# inputs,
# inputs_lengths,
# g=speaker_id if speaker_id is not None else speaker_embeddings)
# else:
# postnet_output, alignments = model.inference(
# inputs,
# inputs_lengths,
# g=speaker_id if speaker_id is not None else speaker_embeddings)
# postnet_output = postnet_output.permute(0, 2, 1)
# # these only belong to tacotron models.
# decoder_output = None
# stop_tokens = None
# else:
# raise ValueError("[!] Unknown model name.")
return outputs

View File

@@ -222,7 +222,7 @@ class Synthesizer(object):
for sen in sens:
# synthesize voice
waveform, _, _, mel_postnet_spec, _, _ = synthesis(
outputs = synthesis(
model=self.tts_model,
text=sen,
CONFIG=self.tts_config,
@@ -232,8 +232,10 @@
style_wav=style_wav,
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
use_griffin_lim=use_gl,
speaker_embedding=speaker_embedding,
x_vector=speaker_embedding,
)
waveform = outputs['wav']
mel_postnet_spec = outputs['model_outputs']
if not use_gl:
# denormalize tts output based on tts audio config
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T