diff --git a/mozilla_voice_tts/tts/models/tacotron_abstract.py b/mozilla_voice_tts/tts/models/tacotron_abstract.py index bc794d49..13c3e948 100644 --- a/mozilla_voice_tts/tts/models/tacotron_abstract.py +++ b/mozilla_voice_tts/tts/models/tacotron_abstract.py @@ -175,13 +175,7 @@ class TacotronAbstract(ABC, nn.Module): gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) gst_outputs = gst_outputs + gst_outputs_att * v_amplifier elif style_input is None: - query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device) - _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) - for k_token in range(self.gst_style_tokens): - key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) - gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) - gst_outputs = gst_outputs + gst_outputs_att * 0 else: gst_outputs = self.gst_layer(style_input) embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)