From 998f33a104652a2253e570b9de4431d55fb94511 Mon Sep 17 00:00:00 2001
From: SanjaESC
Date: Fri, 10 Jul 2020 12:46:43 +0200
Subject: [PATCH] No need to query every token when none were passed

---
 config.json                 | 2 +-
 models/tacotron_abstract.py | 6 ------
 2 files changed, 1 insertion(+), 7 deletions(-)

diff --git a/config.json b/config.json
index 9c4b2271..bf2ad383 100644
--- a/config.json
+++ b/config.json
@@ -136,7 +136,7 @@
     "gst_style_input": null,        // Condition the style input either on a
                                     // -> wave file [path to wave] or
                                     // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15}
-                                    // with the dictionary being len(dict) == len(gst_style_tokens).
+                                    // with the dictionary being len(dict) <= len(gst_style_tokens).
     "gst_embedding_dim": 512,
     "gst_num_heads": 4,
     "gst_style_tokens": 10
diff --git a/models/tacotron_abstract.py b/models/tacotron_abstract.py
index b7d9faf2..c868a18a 100644
--- a/models/tacotron_abstract.py
+++ b/models/tacotron_abstract.py
@@ -175,13 +175,7 @@ class TacotronAbstract(ABC, nn.Module):
                 gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                 gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
         elif style_input is None:
-            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
-            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
             gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
-            for k_token in range(self.gst_style_tokens):
-                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
-                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
-                gst_outputs = gst_outputs + gst_outputs_att * 0
         else:
             gst_outputs = self.gst_layer(style_input)
         embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
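
Illustrative sketch (not part of the patch; it assumes the gst_embedding_dim of 512 and the 10 gst_style_tokens from config.json): in the removed style_input-is-None branch, every token's attention output was multiplied by 0 before being accumulated, so the loop could only ever reproduce the zero tensor that the branch now returns directly.

import torch

gst_embedding_dim = 512   # "gst_embedding_dim" in config.json
gst_style_tokens = 10     # "gst_style_tokens" in config.json

# What the removed loop effectively computed: a zero tensor plus
# per-token attention outputs scaled by 0 (faked here with random tensors).
gst_outputs = torch.zeros(1, 1, gst_embedding_dim)
for _ in range(gst_style_tokens):
    fake_att = torch.randn(1, 1, gst_embedding_dim)  # stand-in for the attention output
    gst_outputs = gst_outputs + fake_att * 0

# What the patched branch returns directly.
assert torch.equal(gst_outputs, torch.zeros(1, 1, gst_embedding_dim))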