diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 6fe60fa0..f6442800 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -448,7 +448,8 @@ class Vits(BaseTTS): g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] # language embedding - if self.args.use_language_embedding: + lang_emb=None + if self.args.use_language_embedding and lid is not None: lang_emb = self.emb_l(lid).unsqueeze(-1) x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) @@ -530,6 +531,7 @@ class Vits(BaseTTS): g = self.emb_g(sid).unsqueeze(-1) # language embedding + lang_emb=None if self.args.use_language_embedding and lid is not None: lang_emb = self.emb_l(lid).unsqueeze(-1)