mirror of https://github.com/coqui-ai/TTS.git
No need to query every token when none were passed
This commit is contained in:
parent
84b7ab6ee6
commit
c2d8a338a1
|
@ -175,13 +175,7 @@ class TacotronAbstract(ABC, nn.Module):
|
|||
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
|
||||
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
|
||||
elif style_input is None:
|
||||
query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
|
||||
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
|
||||
gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
|
||||
for k_token in range(self.gst_style_tokens):
|
||||
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
|
||||
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
|
||||
gst_outputs = gst_outputs + gst_outputs_att * 0
|
||||
else:
|
||||
gst_outputs = self.gst_layer(style_input)
|
||||
embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
|
||||
|
|
Loading…
Reference in New Issue