bug fix tacotron2, decoder return order fixed

Eren Golge 2019-10-29 13:32:20 +01:00
parent 5a56a2c096
commit 89ef71ead8
3 changed files with 10 additions and 11 deletions

File 1 of 3

@@ -31,7 +31,7 @@
     "reinit_layers": [],
-    "model": "Tacotron",            // one of the models in models/
+    "model": "Tacotron2",           // one of the models in models/
     "grad_clip": 1,                 // upper limit for gradients for clipping.
     "epochs": 1000,                 // total number of epochs to train.
     "lr": 0.001,                    // Initial learning rate. If Noam decay is active, maximum learning rate.
@@ -56,7 +56,7 @@
     "batch_size": 32,       // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
     "eval_batch_size": 16,
     "r": 7,                 // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
-    "gradual_training": [[0, 7, 96], [1, 5, 64], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
+    "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
     "wd": 0.000001,         // Weight decay weight.
     "checkpoint": true,     // If true, it saves checkpoints per "save_step"
     "save_step": 10000,     // Number of training steps expected to save training stats and checkpoints.

File 2 of 3

@@ -238,7 +238,7 @@ class Decoder(nn.Module):
         else:
             stop_token = self.stopnet(stopnet_input)
         decoder_output = decoder_output[:, :self.r * self.memory_dim]
-        return decoder_output, stop_token, self.attention.attention_weights
+        return decoder_output, self.attention.attention_weights, stop_token

    def forward(self, inputs, memories, mask, speaker_embeddings=None):
        memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -254,15 +254,14 @@
             memory = memories[len(outputs)]
             if speaker_embeddings is not None:
                 memory = torch.cat([memory, speaker_embeddings], dim=-1)
-            mel_output, stop_token, attention_weights = self.decode(memory)
+            mel_output, attention_weights, stop_token = self.decode(memory)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token.squeeze(1)]
             alignments += [attention_weights]

         outputs, stop_tokens, alignments = self._parse_outputs(
             outputs, stop_tokens, alignments)
-
-        return outputs, stop_tokens, alignments
+        return outputs, alignments, stop_tokens

     def inference(self, inputs, speaker_embeddings=None):
         memory = self.get_go_frame(inputs)
@@ -279,7 +278,7 @@
             memory = self.prenet(memory)
             if speaker_embeddings is not None:
                 memory = torch.cat([memory, speaker_embeddings], dim=-1)
-            mel_output, stop_token, alignment = self.decode(memory)
+            mel_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token]
@@ -301,7 +300,7 @@
         outputs, stop_tokens, alignments = self._parse_outputs(
             outputs, stop_tokens, alignments)
-        return outputs, stop_tokens, alignments
+        return outputs, alignments, stop_tokens

     def inference_truncated(self, inputs):
         """
@@ -319,7 +318,7 @@
         stop_flags = [True, False, False]
         while True:
             memory = self.prenet(self.memory_truncated)
-            mel_output, stop_token, alignment = self.decode(memory)
+            mel_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token]
@@ -341,7 +340,7 @@
         outputs, stop_tokens, alignments = self._parse_outputs(
             outputs, stop_tokens, alignments)
-        return outputs, stop_tokens, alignments
+        return outputs, alignments, stop_tokens

     def inference_step(self, inputs, t, memory=None):
         """

File 3 of 3

@@ -66,7 +66,7 @@ class Tacotron2(nn.Module):
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
         encoder_outputs = self._add_speaker_embedding(encoder_outputs,
                                                       speaker_ids)
-        decoder_outputs, stop_tokens, alignments = self.decoder(
+        decoder_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask)
         postnet_outputs = self.postnet(decoder_outputs)
         postnet_outputs = decoder_outputs + postnet_outputs
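The unchanged context lines also show the postnet applied as a residual correction on the raw decoder spectrogram. A toy illustration with a stand-in postnet (a single convolution here, not the repo's multi-layer conv stack; the (batch, n_mel, frames) layout is an assumption):

import torch
import torch.nn as nn

# Stand-in postnet: one conv layer in place of the real conv stack.
postnet = nn.Conv1d(80, 80, kernel_size=5, padding=2)

decoder_outputs = torch.randn(2, 80, 100)  # (batch, n_mel, frames)
# The postnet predicts a correction that is added back onto the
# decoder output, as in the residual add in the hunk above.
postnet_outputs = decoder_outputs + postnet(decoder_outputs)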