mirror of https://github.com/coqui-ai/TTS.git

commit 89ef71ead8 ("bug fix tacotron2, decoder return order fixed")
parent 5a56a2c096
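In short: Decoder.decode returned (decoder_output, stop_token, attention_weights), while its callers and the Tacotron2 model unpacked the tuple in other orders, so stop tokens and attention alignments could end up swapped. This commit settles on a single order, (output, alignment, stop_token) from decode and (outputs, alignments, stop_tokens) from forward/inference, and updates every call site to match. The config is also switched from Tacotron to Tacotron2, and the first gradual-training phase's batch size drops from 96 to 64.

This class of bug is silent in PyTorch: both tensors are plain floats, so a mismatched unpack raises no error. A minimal sketch (shapes are illustrative, not the repo's actual values):

    import torch

    stop_token = torch.zeros(2, 7)    # e.g. batch=2, r=7 decoder frames
    alignment = torch.zeros(2, 50)    # e.g. attention over 50 encoder steps

    def decode_old():
        # pre-fix order: (output, stop_token, alignment)
        return torch.zeros(2, 560), stop_token, alignment

    # A caller written against the post-fix order unpacks the old return
    # without any runtime error; the bug only shows up later as a broken
    # stop-token loss or a garbage alignment plot.
    output, align, stop = decode_old()
    print(align.shape)  # torch.Size([2, 7]): actually the stop token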
config.json
@@ -31,7 +31,7 @@
     "reinit_layers": [],

-    "model": "Tacotron", // one of the models in models/
+    "model": "Tacotron2", // one of the models in models/
     "grad_clip": 1, // upper limit for gradients for clipping.
     "epochs": 1000, // total number of epochs to train.
     "lr": 0.001, // Initial learning rate. If Noam decay is active, maximum learning rate.
@@ -56,7 +56,7 @@
     "batch_size": 32, // Batch size for training. Values lower than 32 might make attention hard to learn. It is overwritten by 'gradual_training'.
     "eval_batch_size": 16,
     "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
-    "gradual_training": [[0, 7, 96], [1, 5, 64], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
+    "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
     "wd": 0.000001, // Weight decay weight.
     "checkpoint": true, // If true, it saves checkpoints per "save_step".
     "save_step": 10000, // Number of training steps expected to save training stats and checkpoints.
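A note on the schedule edited above: each gradual_training entry is [first_step, r, batch_size], and the run adopts the r and batch size of the last entry whose first_step has been passed (overriding the top-level "batch_size" and "r"). A minimal sketch of that lookup; the helper name and exact semantics are assumptions, not the repo's actual code:

    def gradual_training_params(global_step, schedule):
        # Return the (r, batch_size) of the last phase already reached.
        r, batch_size = schedule[0][1], schedule[0][2]
        for first_step, phase_r, phase_bs in schedule:
            if global_step >= first_step:
                r, batch_size = phase_r, phase_bs
        return r, batch_size

    schedule = [[0, 7, 64], [1, 5, 64], [50000, 3, 32],
                [130000, 2, 16], [290000, 1, 8]]
    assert gradual_training_params(0, schedule) == (7, 64)
    assert gradual_training_params(60000, schedule) == (3, 32)
    assert gradual_training_params(300000, schedule) == (1, 8)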
layers/tacotron2.py
@@ -238,7 +238,7 @@ class Decoder(nn.Module):
         else:
             stop_token = self.stopnet(stopnet_input)
         decoder_output = decoder_output[:, :self.r * self.memory_dim]
-        return decoder_output, stop_token, self.attention.attention_weights
+        return decoder_output, self.attention.attention_weights, stop_token

     def forward(self, inputs, memories, mask, speaker_embeddings=None):
         memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -254,15 +254,14 @@ class Decoder(nn.Module):
             memory = memories[len(outputs)]
             if speaker_embeddings is not None:
                 memory = torch.cat([memory, speaker_embeddings], dim=-1)
-            mel_output, stop_token, attention_weights = self.decode(memory)
+            mel_output, attention_weights, stop_token = self.decode(memory)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token.squeeze(1)]
             alignments += [attention_weights]

         outputs, stop_tokens, alignments = self._parse_outputs(
             outputs, stop_tokens, alignments)
-
-        return outputs, stop_tokens, alignments
+        return outputs, alignments, stop_tokens

     def inference(self, inputs, speaker_embeddings=None):
         memory = self.get_go_frame(inputs)
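For context, _parse_outputs above turns the per-step Python lists into batch-first tensors just before the (now reordered) return. A sketch of what such a helper typically does; this is an assumption about its behavior, not the repo's exact implementation:

    import torch

    def parse_outputs_sketch(outputs, stop_tokens, alignments):
        # Stack T_out per-step tensors into (B, T_out, ...) tensors.
        alignments = torch.stack(alignments).transpose(0, 1)
        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
        outputs = torch.stack(outputs).transpose(0, 1)
        return outputs, stop_tokens, alignments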
@@ -279,7 +278,7 @@ class Decoder(nn.Module):
             memory = self.prenet(memory)
             if speaker_embeddings is not None:
                 memory = torch.cat([memory, speaker_embeddings], dim=-1)
-            mel_output, stop_token, alignment = self.decode(memory)
+            mel_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token]
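The sigmoid on stop_token above appears only in the inference loops: the stopnet emits a raw logit (typically trained with a logit-based loss), and inference squashes it to a probability to decide when to stop decoding. A sketch of the usual stopping rule; the 0.5 threshold is an assumption, not necessarily what this repo uses:

    import torch

    def should_stop(stop_logit, threshold=0.5):
        # Stop decoding once the predicted stop probability for the
        # current frame passes the threshold (stop_logit is a scalar).
        return torch.sigmoid(stop_logit).item() > threshold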
@@ -301,7 +300,7 @@ class Decoder(nn.Module):
         outputs, stop_tokens, alignments = self._parse_outputs(
             outputs, stop_tokens, alignments)

-        return outputs, stop_tokens, alignments
+        return outputs, alignments, stop_tokens

     def inference_truncated(self, inputs):
         """
@@ -319,7 +318,7 @@ class Decoder(nn.Module):
         stop_flags = [True, False, False]
         while True:
             memory = self.prenet(self.memory_truncated)
-            mel_output, stop_token, alignment = self.decode(memory)
+            mel_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token]
@@ -341,7 +340,7 @@ class Decoder(nn.Module):
         outputs, stop_tokens, alignments = self._parse_outputs(
             outputs, stop_tokens, alignments)

-        return outputs, stop_tokens, alignments
+        return outputs, alignments, stop_tokens

     def inference_step(self, inputs, t, memory=None):
         """
models/tacotron2.py
@@ -66,7 +66,7 @@ class Tacotron2(nn.Module):
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
         encoder_outputs = self._add_speaker_embedding(encoder_outputs,
                                                       speaker_ids)
-        decoder_outputs, stop_tokens, alignments = self.decoder(
+        decoder_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask)
         postnet_outputs = self.postnet(decoder_outputs)
         postnet_outputs = decoder_outputs + postnet_outputs
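With the model-side unpack updated, decode, forward/inference, and the Tacotron2 wrapper all agree on the (outputs, alignments, stop_tokens) order. A quick shape-based sanity check in the spirit of the fix; the dimensions are illustrative assumptions:

    import torch

    B, T_in, T_out, n_mels = 2, 50, 120, 80
    decoder_outputs = torch.zeros(B, n_mels, T_out)
    alignments = torch.zeros(B, T_out, T_in)
    stop_tokens = torch.zeros(B, T_out)

    # A caller can distinguish the two trailing returns by rank:
    # alignments carry the extra encoder axis, stop tokens do not.
    assert alignments.dim() == 3 and stop_tokens.dim() == 2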