From 8d0d4919fdce201b138deb774ef5aec5cd60d04e Mon Sep 17 00:00:00 2001 From: SanjaESC Date: Fri, 10 Jul 2020 12:46:43 +0200 Subject: [PATCH] No need to query every token when none were passed --- mozilla_voice_tts/tts/configs/config.json | 2 +- mozilla_voice_tts/tts/models/tacotron_abstract.py | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/mozilla_voice_tts/tts/configs/config.json b/mozilla_voice_tts/tts/configs/config.json index 9068e2c4..8f56816e 100644 --- a/mozilla_voice_tts/tts/configs/config.json +++ b/mozilla_voice_tts/tts/configs/config.json @@ -137,7 +137,7 @@ "gst_style_input": null, // Condition the style input either on a // -> wave file [path to wave] or // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15} - // with the dictionary being len(dict) == len(gst_style_tokens). + // with the dictionary being len(dict) <= len(gst_style_tokens). "gst_embedding_dim": 512, "gst_num_heads": 4, "gst_style_tokens": 10 diff --git a/mozilla_voice_tts/tts/models/tacotron_abstract.py b/mozilla_voice_tts/tts/models/tacotron_abstract.py index bc794d49..13c3e948 100644 --- a/mozilla_voice_tts/tts/models/tacotron_abstract.py +++ b/mozilla_voice_tts/tts/models/tacotron_abstract.py @@ -175,13 +175,7 @@ class TacotronAbstract(ABC, nn.Module): gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) gst_outputs = gst_outputs + gst_outputs_att * v_amplifier elif style_input is None: - query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device) - _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) - for k_token in range(self.gst_style_tokens): - key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) - gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) - gst_outputs = gst_outputs + gst_outputs_att * 0 else: gst_outputs = self.gst_layer(style_input) embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)