diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 76fdb3de..d75ac64b 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -149,7 +149,8 @@ def inv_spectrogram(postnet_output, ap, CONFIG): def id_to_torch(speaker_id, cuda=False): if speaker_id is not None: speaker_id = np.asarray(speaker_id) - speaker_id = torch.from_numpy(speaker_id).unsqueeze(0) + # TODO: test this for tacotron models + speaker_id = torch.from_numpy(speaker_id) if cuda: return speaker_id.cuda() return speaker_id