mirror of https://github.com/coqui-ai/TTS.git

override compute_gst in tacotron2 model

parent 77bfb881d7
commit 1436206224
@@ -46,13 +46,15 @@ class Tacotron2(TacotronAbstract):
         decoder_in_features = 512+speaker_embedding_dim+gst_embedding_dim
         encoder_in_features = 512 if num_speakers > 1 else 512
         proj_speaker_dim = 80 if num_speakers > 1 else 0
 
-        # base layers
+        # embedding layer
         self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
+
         # speaker embedding layer
         if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
+
         self.encoder = Encoder(encoder_in_features)
         self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
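Note on the dimension bookkeeping in this hunk: decoder_in_features is sized 512 + speaker_embedding_dim + gst_embedding_dim because the per-utterance speaker and style vectors are broadcast along the time axis and concatenated channel-wise onto the 512-channel encoder outputs before the decoder consumes them. A minimal shape sketch under that assumption (all names and sizes below are illustrative, not code from this commit):

    import torch

    B, T, enc_dim = 2, 50, 512        # batch, input time steps, encoder channels
    speaker_dim, gst_dim = 64, 512    # hypothetical embedding sizes

    encoder_outputs = torch.randn(B, T, enc_dim)
    speaker_emb = torch.randn(B, 1, speaker_dim).expand(B, T, speaker_dim)
    gst_emb = torch.randn(B, 1, gst_dim).repeat(1, T, 1)

    # decoder_in_features = 512 + speaker_embedding_dim + gst_embedding_dim
    decoder_inputs = torch.cat([encoder_outputs, speaker_emb, gst_emb], dim=-1)
    assert decoder_inputs.shape == (B, T, enc_dim + speaker_dim + gst_dim)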
@@ -83,6 +85,24 @@ class Tacotron2(TacotronAbstract):
         mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
         return mel_outputs, mel_outputs_postnet, alignments
 
+    def compute_gst(self, inputs, style_input):
+        """ Compute global style token """
+        device = inputs.device
+        if isinstance(style_input, dict):
+            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
+            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            for k_token, v_amplifier in style_input.items():
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
+        elif style_input is None:
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+        else:
+            gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable
+        embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
+        return inputs, embedded_gst
+
     def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
         # compute mask for padding
         # B x T_in_max (boolean)
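The override added above gives Tacotron2 three ways to condition on style: a dict mapping a style-token index to an amplifier weight (each selected token is attended to and its output scaled and summed), None (a zero style vector), or a reference mel spectrogram passed through the full GST layer. A usage sketch, assuming `model` is an already-constructed GST-enabled Tacotron2 instance (the tensor shapes are illustrative):

    import torch

    inputs = torch.randn(1, 30, 512)     # B x T_in x encoder channels

    # 1) hand-picked style tokens: token index -> amplifier weight
    inputs_out, gst = model.compute_gst(inputs, {"0": 0.3, "4": 0.7})

    # 2) no style conditioning: embedded_gst is all zeros
    inputs_out, gst = model.compute_gst(inputs, None)

    # 3) style transfer from a reference mel spectrogram
    ref_mel = torch.randn(1, 120, 80)    # B x T_mel x n_mel_channels
    inputs_out, gst = model.compute_gst(inputs, ref_mel)

In every branch the method returns the unchanged inputs plus the style embedding repeated across the input time axis, so the caller decides how to combine the two.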
@@ -164,22 +164,12 @@ class TacotronAbstract(ABC, nn.Module):
             self.speaker_embeddings_projected = self.speaker_project_mel(
                 self.speaker_embeddings).squeeze(1)
 
-    def compute_gst(self, inputs, style_input):
-        device = inputs.device
-        if isinstance(style_input, dict):
-            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
-            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
-            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
-            for k_token, v_amplifier in style_input.items():
-                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
-                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
-                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
-        elif style_input is None:
-            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
-        else:
-            gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable
-        embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
-        return inputs, embedded_gst
+    def compute_gst(self, inputs, mel_specs):
+        """ Compute global style token """
+        # pylint: disable=not-callable
+        gst_outputs = self.gst_layer(mel_specs)
+        inputs = self._add_speaker_embedding(inputs, gst_outputs)
+        return inputs
 
     @staticmethod
     def _add_speaker_embedding(outputs, speaker_embeddings):
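After this change the base-class compute_gst covers only the mel-conditioned training path: it runs the target spectrogram through the GST layer and folds the resulting style vector into the encoder outputs via _add_speaker_embedding, while the dict/None inference conveniences now live in the Tacotron2 override. A self-contained sketch of that reduced flow (the FakeGSTLayer stub, the 512/80 sizes, and the expand-and-sum merge are illustrative assumptions, not code from this repository):

    import torch
    import torch.nn as nn

    class FakeGSTLayer(nn.Module):
        """Hypothetical stub standing in for the real GST reference encoder."""
        def __init__(self, n_mel, gst_dim):
            super().__init__()
            self.proj = nn.Linear(n_mel, gst_dim)

        def forward(self, mel_specs):  # B x T_mel x n_mel
            # collapse time, then project to one style vector per utterance
            return self.proj(mel_specs.mean(dim=1, keepdim=True))  # B x 1 x gst_dim

    gst_layer = FakeGSTLayer(n_mel=80, gst_dim=512)
    inputs = torch.randn(2, 30, 512)       # B x T_in x encoder channels
    mel_specs = torch.randn(2, 120, 80)    # B x T_mel x n_mel

    gst_outputs = gst_layer(mel_specs)                        # B x 1 x gst_dim
    gst_outputs = gst_outputs.expand(-1, inputs.size(1), -1)  # broadcast over T_in
    inputs = inputs + gst_outputs                             # assumed expand-and-sum merge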