extracted id to torch code

This commit is contained in:
Thomas Werkmeister 2019-07-02 14:40:01 +02:00
parent 81c5df71f6
commit d23e29ea1f
1 changed files with 9 additions and 4 deletions

View File

@ -68,6 +68,13 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
return wav return wav
def id_to_torch(speaker_id):
if speaker_id is not None:
speaker_id = np.asarray(speaker_id)
speaker_id = torch.from_numpy(speaker_id).unsqueeze(0)
return speaker_id
def synthesis(model, def synthesis(model,
text, text,
CONFIG, CONFIG,
@ -100,9 +107,7 @@ def synthesis(model,
style_mel = compute_style_mel(style_wav, ap, use_cuda) style_mel = compute_style_mel(style_wav, ap, use_cuda)
# preprocess the given text # preprocess the given text
inputs = text_to_seqvec(text, CONFIG, use_cuda) inputs = text_to_seqvec(text, CONFIG, use_cuda)
if speaker_id is not None: speaker_id = id_to_torch(speaker_id)
speaker_id = np.asarray(speaker_id)
speaker_id = torch.from_numpy(speaker_id).unsqueeze(0)
if use_cuda: if use_cuda:
speaker_id.cuda() speaker_id.cuda()
# synthesize voice # synthesize voice
@ -116,4 +121,4 @@ def synthesis(model,
# trim silence # trim silence
if trim_silence: if trim_silence:
wav = trim_silence(wav) wav = trim_silence(wav)
return wav, alignment, decoder_output, postnet_output, stop_tokens return wav, alignment, decoder_output, postnet_output, stop_tokens