No need to query every token when none were passed

This commit is contained in:
SanjaESC 2020-07-10 12:46:43 +02:00 committed by thllwg
parent b71f31eae4
commit 998f33a104
2 changed files with 1 additions and 7 deletions

View File

@ -136,7 +136,7 @@
"gst_style_input": null, // Condition the style input either on a
// -> wave file [path to wave] or
// -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15}
// with the dictionary being len(dict) == len(gst_style_tokens).
// with the dictionary being len(dict) <= len(gst_style_tokens).
"gst_embedding_dim": 512,
"gst_num_heads": 4,
"gst_style_tokens": 10

View File

@ -175,13 +175,7 @@ class TacotronAbstract(ABC, nn.Module):
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
elif style_input is None:
query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
for k_token in range(self.gst_style_tokens):
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
gst_outputs = gst_outputs + gst_outputs_att * 0
else:
gst_outputs = self.gst_layer(style_input)
embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)