diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 63fe92c3..6d998492 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -175,6 +175,7 @@ def embedding_to_torch(d_vector, cuda=False): if d_vector is not None: d_vector = np.asarray(d_vector) d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor) + d_vector = d_vector.squeeze().unsqueeze(0) if cuda: return d_vector.cuda() return d_vector