diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 7dbfdd09..8f3b3804 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -170,6 +170,8 @@ class GlowTTS(BaseTTS): if g is not None: if hasattr(self, "emb_g"): # use speaker embedding layer + if not g.size(): # if is a scalar + g = g.unsqueeze(0) # unsqueeze g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] else: # use d-vector