update speaker id casting for glow-tts

This commit is contained in:
erogol 2020-12-14 16:58:47 +01:00
parent 999120ecdf
commit 639fa29261
1 changed files with 2 additions and 1 deletions

View File

@ -149,7 +149,8 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
def id_to_torch(speaker_id, cuda=False):
if speaker_id is not None:
speaker_id = np.asarray(speaker_id)
speaker_id = torch.from_numpy(speaker_id).unsqueeze(0)
# TODO: test this for tacotron models
speaker_id = torch.from_numpy(speaker_id)
if cuda:
return speaker_id.cuda()
return speaker_id