diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index de00f6c7..707fc9c3 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -102,7 +102,7 @@ class BaseTTS(BaseModel):
             config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
         )
         # init speaker embedding layer
-        if config.use_speaker_embedding and not config.use_d_vector_file:
+        if config.use_speaker_embedding:
             print(" > Init speaker_embedding layer.")
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 09537905..4d47cde1 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -578,7 +578,7 @@ class Vits(BaseTTS):
         outputs = {}
         sid, g, lid = self._set_cond_input(aux_input)
         # speaker embedding
-        if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
+        if self.args.use_speaker_embedding and sid is not None:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

         # language embedding
@@ -801,7 +801,7 @@ class Vits(BaseTTS):
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)

         # speaker embedding
-        if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
+        if self.args.use_speaker_embedding and sid is not None:
             g = self.emb_g(sid).unsqueeze(-1)

         # language embedding