diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index fded8f87..a5db64e9 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -277,6 +277,7 @@ class Tacotron2(TacotronAbstract): if self.num_speakers > 1: if not self.embeddings_per_sample: speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + speaker_embeddings = torch.unsqueeze(speaker_embeddings, 0).transpose(1, 2) encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(encoder_outputs)