diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index af5979dd..b55ba1b1 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -131,12 +131,12 @@ class GlowTts(nn.Module): def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None): """ - Shapes: - x: B x T - x_lenghts: B - y: B x C x T - y_lengths: B - g: B x C or B + Shapes: + x: [B, T] + x_lenghts: B + y: [B, C, T] + y_lengths: B + g: [B, C] or B """ y_max_length = y.size(2) # norm speaker embeddings @@ -180,7 +180,6 @@ class GlowTts(nn.Module): @torch.no_grad() def inference(self, x, x_lengths, g=None): - if g is not None: if self.external_speaker_embedding_dim: g = F.normalize(g).unsqueeze(-1)