Bug fix for Tacotron2: decoder return order fixed

Eren Golge 2019-10-29 13:32:20 +01:00
parent 5a56a2c096
commit 89ef71ead8
3 changed files with 10 additions and 11 deletions

View File

@@ -31,7 +31,7 @@
"reinit_layers": [],
"model": "Tacotron", // one of the model in models/
"model": "Tacotron2", // one of the model in models/
"grad_clip": 1, // upper limit for gradients for clipping.
"epochs": 1000, // total number of epochs to train.
"lr": 0.001, // Initial learning rate. If Noam decay is active, maximum learning rate.
@@ -56,7 +56,7 @@
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"eval_batch_size":16,
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
"gradual_training": [[0, 7, 96], [1, 5, 64], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
"gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"save_step": 10000, // Number of training steps expected to save traninpg stats and checkpoints.

View File

@@ -238,7 +238,7 @@ class Decoder(nn.Module):
else:
stop_token = self.stopnet(stopnet_input)
decoder_output = decoder_output[:, :self.r * self.memory_dim]
-return decoder_output, stop_token, self.attention.attention_weights
+return decoder_output, self.attention.attention_weights, stop_token
def forward(self, inputs, memories, mask, speaker_embeddings=None):
memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -254,15 +254,14 @@ class Decoder(nn.Module):
memory = memories[len(outputs)]
if speaker_embeddings is not None:
memory = torch.cat([memory, speaker_embeddings], dim=-1)
-mel_output, stop_token, attention_weights = self.decode(memory)
+mel_output, attention_weights, stop_token = self.decode(memory)
outputs += [mel_output.squeeze(1)]
stop_tokens += [stop_token.squeeze(1)]
alignments += [attention_weights]
outputs, stop_tokens, alignments = self._parse_outputs(
outputs, stop_tokens, alignments)
-return outputs, stop_tokens, alignments
+return outputs, alignments, stop_tokens
def inference(self, inputs, speaker_embeddings=None):
memory = self.get_go_frame(inputs)
@@ -279,7 +278,7 @@ class Decoder(nn.Module):
memory = self.prenet(memory)
if speaker_embeddings is not None:
memory = torch.cat([memory, speaker_embeddings], dim=-1)
-mel_output, stop_token, alignment = self.decode(memory)
+mel_output, alignment, stop_token = self.decode(memory)
stop_token = torch.sigmoid(stop_token.data)
outputs += [mel_output.squeeze(1)]
stop_tokens += [stop_token]
@@ -301,7 +300,7 @@ class Decoder(nn.Module):
outputs, stop_tokens, alignments = self._parse_outputs(
outputs, stop_tokens, alignments)
-return outputs, stop_tokens, alignments
+return outputs, alignments, stop_tokens
def inference_truncated(self, inputs):
"""
@@ -319,7 +318,7 @@ class Decoder(nn.Module):
stop_flags = [True, False, False]
while True:
memory = self.prenet(self.memory_truncated)
-mel_output, stop_token, alignment = self.decode(memory)
+mel_output, alignment, stop_token = self.decode(memory)
stop_token = torch.sigmoid(stop_token.data)
outputs += [mel_output.squeeze(1)]
stop_tokens += [stop_token]
@@ -341,7 +340,7 @@ class Decoder(nn.Module):
outputs, stop_tokens, alignments = self._parse_outputs(
outputs, stop_tokens, alignments)
-return outputs, stop_tokens, alignments
+return outputs, alignments, stop_tokens
def inference_step(self, inputs, t, memory=None):
"""

View File

@@ -66,7 +66,7 @@ class Tacotron2(nn.Module):
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
speaker_ids)
-decoder_outputs, stop_tokens, alignments = self.decoder(
+decoder_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, mask)
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = decoder_outputs + postnet_outputs
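With the caller's unpacking fixed to match the decoder, the second value is the attention alignment and the third is the stop-token logit, so downstream losses attach to the right tensors. A hedged sketch with placeholder shapes and names, not the repository's training loop:

import torch
import torch.nn.functional as F

B, T_out, n_mel, T_in = 2, 12, 80, 30
# Placeholder tensors standing in for the model's three decoder outputs.
decoder_outputs = torch.randn(B, T_out, n_mel)
alignments = torch.softmax(torch.randn(B, T_out, T_in), dim=-1)
stop_tokens = torch.randn(B, T_out)

mel_targets = torch.randn(B, T_out, n_mel)
stop_targets = torch.zeros(B, T_out)
stop_targets[:, -1] = 1.0  # mark the final decoder step as "stop"

mel_loss = F.l1_loss(decoder_outputs, mel_targets)
stop_loss = F.binary_cross_entropy_with_logits(stop_tokens, stop_targets)
print(mel_loss.item(), stop_loss.item())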