diff --git a/mozilla_voice_tts/tts/models/tacotron2.py b/mozilla_voice_tts/tts/models/tacotron2.py
index 52da2f39..59e44fb2 100644
--- a/mozilla_voice_tts/tts/models/tacotron2.py
+++ b/mozilla_voice_tts/tts/models/tacotron2.py
@@ -47,17 +47,22 @@ class Tacotron2(TacotronAbstract):
         decoder_in_features = 512+speaker_embedding_dim+gst_embedding_dim
         encoder_in_features = 512 if num_speakers > 1 else 512
         proj_speaker_dim = 80 if num_speakers > 1 else 0
-        # base layers
+
+        # embedding layer
         self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
+
+        # speaker embedding layer
         if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
+
         self.encoder = Encoder(encoder_in_features)
         self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type,
                                attn_win, attn_norm, prenet_type, prenet_dropout,
                                forward_attn, trans_agent, forward_attn_mask,
                                location_attn, attn_K, separate_stopnet,
                                proj_speaker_dim)
         self.postnet = Postnet(self.postnet_output_dim)
+        # global style token layers
         if self.gst:
             self.gst_layer = GST(num_mel=80,
@@ -81,6 +86,24 @@ class Tacotron2(TacotronAbstract):
         mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
         return mel_outputs, mel_outputs_postnet, alignments
 
+    def compute_gst(self, inputs, style_input):
+        """ Compute global style token """
+        device = inputs.device
+        if isinstance(style_input, dict):
+            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
+            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            for k_token, v_amplifier in style_input.items():
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
+        elif style_input is None:
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+        else:
+            gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable
+        embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
+        return inputs, embedded_gst
+
     def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
         # compute mask for padding
         # B x T_in_max (boolean)
diff --git a/mozilla_voice_tts/tts/models/tacotron_abstract.py b/mozilla_voice_tts/tts/models/tacotron_abstract.py
index 9b2ef148..a4b8c227 100644
--- a/mozilla_voice_tts/tts/models/tacotron_abstract.py
+++ b/mozilla_voice_tts/tts/models/tacotron_abstract.py
@@ -164,22 +164,12 @@ class TacotronAbstract(ABC, nn.Module):
             self.speaker_embeddings_projected = self.speaker_project_mel(
                 self.speaker_embeddings).squeeze(1)
 
-    def compute_gst(self, inputs, style_input):
-        device = inputs.device
-        if isinstance(style_input, dict):
-            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
-            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
-            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
-            for k_token, v_amplifier in style_input.items():
-                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
-                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
-                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
-        elif style_input is None:
-            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
-        else:
-            gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable
-        embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
-        return inputs, embedded_gst
+    def compute_gst(self, inputs, mel_specs):
+        """ Compute global style token """
+        # pylint: disable=not-callable
+        gst_outputs = self.gst_layer(mel_specs)
+        inputs = self._add_speaker_embedding(inputs, gst_outputs)
+        return inputs
 
     @staticmethod
     def _add_speaker_embedding(outputs, speaker_embeddings):
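
For context on the dict branch of the new `Tacotron2.compute_gst`: it lets inference blend individual style tokens with signed amplifiers instead of encoding a reference mel. Below is a minimal standalone sketch of that weighting scheme; the sizes (`GST_EMBEDDING_DIM`, `NUM_STYLE_TOKENS`, `TOKEN_DIM`), the `value_proj` layer, and the single-head scaled dot-product attention are illustrative stand-ins for the multi-head attention module inside `gst_layer.style_token_layer`, not repo code.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only; the real values come from the model config.
GST_EMBEDDING_DIM = 256
NUM_STYLE_TOKENS = 10
TOKEN_DIM = GST_EMBEDDING_DIM // 2        # query width used in compute_gst

# Stand-ins for the learned internals of gst_layer.style_token_layer.
style_tokens = torch.randn(NUM_STYLE_TOKENS, TOKEN_DIM)
value_proj = torch.nn.Linear(TOKEN_DIM, GST_EMBEDDING_DIM)

def blend_tokens(style_input):
    """Weighted sum of per-token attention outputs, as in the dict branch."""
    query = torch.zeros(1, 1, TOKEN_DIM)            # zero query, as in compute_gst
    tokens = torch.tanh(style_tokens)               # tanh-squashed tokens
    gst_outputs = torch.zeros(1, 1, GST_EMBEDDING_DIM)
    for k_token, v_amplifier in style_input.items():
        key = tokens[int(k_token)].view(1, 1, TOKEN_DIM)
        # single-head scaled dot-product attention over the one selected token
        scores = query @ key.transpose(-2, -1) / TOKEN_DIM ** 0.5
        att = F.softmax(scores, dim=-1) @ value_proj(key)
        gst_outputs = gst_outputs + att * v_amplifier   # amplifier scales the token
    return gst_outputs

# e.g. add a bit of token 0, subtract a little of token 3, then repeat
# the resulting GST vector over the encoder time axis as the method does:
gst = blend_tokens({"0": 0.2, "3": -0.1})
embedded_gst = gst.repeat(1, 50, 1)       # 50 = assumed number of encoder steps
```

In the patch itself, only the reference-mel path stays on `TacotronAbstract.compute_gst` (training-time GST from `mel_specs`), while the dict and `None` inference conveniences now live on `Tacotron2`.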