inference for SS

This commit is contained in:
erogol 2020-12-28 13:53:31 +01:00
parent 30788960a8
commit 13c6665c92
1 changed file with 11 additions and 0 deletions

View File

@ -71,6 +71,17 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
# these only belong to tacotron models.
decoder_output = None
stop_tokens = None
elif 'speedy_speech' in CONFIG.model.lower():
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
if hasattr(model, 'module'):
# distributed model
postnet_output, alignments= model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
else:
postnet_output, alignments= model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
postnet_output = postnet_output.permute(0, 2, 1)
# these only belong to tacotron models.
decoder_output = None
stop_tokens = None
return decoder_output, postnet_output, alignments, stop_tokens