diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index a5db64e9..525eb8b3 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -255,6 +255,7 @@ class Tacotron2(TacotronAbstract): if self.num_speakers > 1: if not self.embeddings_per_sample: speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + speaker_embeddings = torch.unsqueeze(speaker_embeddings, 0).transpose(1, 2) encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)