No need to query every token when none were passed

This commit is contained in:
SanjaESC 2020-07-10 12:46:43 +02:00 committed by erogol
parent 84b7ab6ee6
commit c2d8a338a1
1 changed files with 0 additions and 6 deletions

View File

@ -175,13 +175,7 @@ class TacotronAbstract(ABC, nn.Module):
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
elif style_input is None:
query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
for k_token in range(self.gst_style_tokens):
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
gst_outputs = gst_outputs + gst_outputs_att * 0
else:
gst_outputs = self.gst_layer(style_input)
embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)