mirror of https://github.com/coqui-ai/TTS.git
No need to query every token when none were passed
This commit is contained in:
parent
b71f31eae4
commit
998f33a104
|
@ -136,7 +136,7 @@
|
|||
"gst_style_input": null, // Condition the style input either on a
|
||||
// -> wave file [path to wave] or
|
||||
// -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15}
|
||||
// with the dictionary being len(dict) == len(gst_style_tokens).
|
||||
// with the dictionary being len(dict) <= len(gst_style_tokens).
|
||||
"gst_embedding_dim": 512,
|
||||
"gst_num_heads": 4,
|
||||
"gst_style_tokens": 10
|
||||
|
|
|
@ -175,13 +175,7 @@ class TacotronAbstract(ABC, nn.Module):
|
|||
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
|
||||
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
|
||||
elif style_input is None:
|
||||
query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
|
||||
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
|
||||
gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
|
||||
for k_token in range(self.gst_style_tokens):
|
||||
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
|
||||
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
|
||||
gst_outputs = gst_outputs + gst_outputs_att * 0
|
||||
else:
|
||||
gst_outputs = self.gst_layer(style_input)
|
||||
embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
|
||||
|
|
Loading…
Reference in New Issue