synthesis update for glow tts

This commit is contained in:
erogol 2020-08-04 22:22:35 +02:00
parent 89d15bf118
commit d5c6d60884
1 changed files with 17 additions and 9 deletions

View File

@ -46,16 +46,24 @@ def compute_style_mel(style_wav, ap, cuda=False):
def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
if CONFIG.use_gst:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
else:
if truncated:
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
else:
if 'tacotron' in CONFIG.model.lower():
if CONFIG.use_gst:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
else:
if truncated:
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
else:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
elif 'glow' in CONFIG.model.lower():
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths)
postnet_output = postnet_output.permute(0, 2, 1)
# these only belong to tacotron models.
decoder_output = None
stop_tokens = None
return decoder_output, postnet_output, alignments, stop_tokens