mirror of https://github.com/coqui-ai/TTS.git
bug fix tacotron2, decoder return order fixed
parent 5a56a2c096
commit 89ef71ead8
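
This commit reorders every decoder return from (outputs, stop_tokens, alignments) to (outputs, alignments, stop_tokens). A tuple-order mismatch like this fails silently, since all three values are tensors; below is a minimal stub illustration of the bug class (stub tensors, not the real modules):

    # Stub illustration (not the real modules): a swapped tuple order
    # raises no error, it just mislabels tensors downstream.
    import torch

    def decode_old(memory):
        # pre-commit order: (mel_output, stop_token, attention_weights)
        return torch.zeros(1, 80), torch.zeros(1, 1), torch.zeros(1, 50)

    def decode_new(memory):
        # post-commit order: (mel_output, attention_weights, stop_token)
        return torch.zeros(1, 80), torch.zeros(1, 50), torch.zeros(1, 1)

    # A caller written for the new order works only against decode_new;
    # against decode_old it would bind the stop token to `alignment`.
    mel, alignment, stop = decode_new(torch.zeros(1, 128))
    assert alignment.shape == (1, 50) and stop.shape == (1, 1)
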
@@ -31,7 +31,7 @@
     "reinit_layers": [],

-    "model": "Tacotron", // one of the models in models/
+    "model": "Tacotron2", // one of the models in models/
     "grad_clip": 1, // upper limit for gradients for clipping.
     "epochs": 1000, // total number of epochs to train.
     "lr": 0.001, // Initial learning rate. If Noam decay is active, maximum learning rate.
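
The "model" field selects the architecture by name, as its comment says. A hypothetical sketch of that string-to-class dispatch (the repository resolves this in its own setup code; the class bodies here are stand-ins):

    # Hypothetical sketch of config-driven model selection; the real
    # factory lives in the repository's setup code.
    class Tacotron:  # stand-in
        pass

    class Tacotron2:  # stand-in
        pass

    MODELS = {"Tacotron": Tacotron, "Tacotron2": Tacotron2}

    def setup_model(config):
        try:
            return MODELS[config["model"]]()
        except KeyError:
            raise ValueError("unknown model: %r" % config["model"])

    model = setup_model({"model": "Tacotron2"})  # matches the new config value
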
@@ -56,7 +56,7 @@
     "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
     "eval_batch_size": 16,
     "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
-    "gradual_training": [[0, 7, 96], [1, 5, 64], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
+    "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
     "wd": 0.000001, // Weight decay weight.
     "checkpoint": true, // If true, it saves checkpoints per "save_step".
     "save_step": 10000, // Number of training steps expected to save training stats and checkpoints.
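
Each "gradual_training" entry is [first_step, r, batch_size], and the changed first entry lowers the initial batch size from 96 to 64. A minimal sketch of a scheduler consistent with that format, assuming the last entry whose first_step is at or below the current step applies:

    # Minimal sketch: pick the (r, batch_size) in force at `global_step`,
    # assuming the last entry with first_step <= global_step applies.
    SCHEDULE = [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]]

    def gradual_params(global_step, schedule=SCHEDULE):
        r, batch_size = schedule[0][1], schedule[0][2]
        for first_step, new_r, new_bs in schedule:
            if global_step >= first_step:
                r, batch_size = new_r, new_bs
        return r, batch_size

    assert gradual_params(0) == (7, 64)
    assert gradual_params(60000) == (3, 32)
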
@@ -238,7 +238,7 @@ class Decoder(nn.Module):
         else:
             stop_token = self.stopnet(stopnet_input)
         decoder_output = decoder_output[:, :self.r * self.memory_dim]
-        return decoder_output, stop_token, self.attention.attention_weights
+        return decoder_output, self.attention.attention_weights, stop_token

     def forward(self, inputs, memories, mask, speaker_embeddings=None):
         memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -254,15 +254,14 @@ class Decoder(nn.Module):
             memory = memories[len(outputs)]
             if speaker_embeddings is not None:
                 memory = torch.cat([memory, speaker_embeddings], dim=-1)
-            mel_output, stop_token, attention_weights = self.decode(memory)
+            mel_output, attention_weights, stop_token = self.decode(memory)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token.squeeze(1)]
             alignments += [attention_weights]

         outputs, stop_tokens, alignments = self._parse_outputs(
             outputs, stop_tokens, alignments)

-        return outputs, stop_tokens, alignments
+        return outputs, alignments, stop_tokens

     def inference(self, inputs, speaker_embeddings=None):
         memory = self.get_go_frame(inputs)
@@ -279,7 +278,7 @@
             memory = self.prenet(memory)
             if speaker_embeddings is not None:
                 memory = torch.cat([memory, speaker_embeddings], dim=-1)
-            mel_output, stop_token, alignment = self.decode(memory)
+            mel_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token]
@@ -301,7 +300,7 @@
         outputs, stop_tokens, alignments = self._parse_outputs(
             outputs, stop_tokens, alignments)

-        return outputs, stop_tokens, alignments
+        return outputs, alignments, stop_tokens

     def inference_truncated(self, inputs):
         """
@@ -319,7 +318,7 @@
         stop_flags = [True, False, False]
         while True:
             memory = self.prenet(self.memory_truncated)
-            mel_output, stop_token, alignment = self.decode(memory)
+            mel_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token]
@@ -341,7 +340,7 @@
         outputs, stop_tokens, alignments = self._parse_outputs(
             outputs, stop_tokens, alignments)

-        return outputs, stop_tokens, alignments
+        return outputs, alignments, stop_tokens

     def inference_step(self, inputs, t, memory=None):
         """
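
After the hunks above, forward, inference, and inference_truncated all return (outputs, alignments, stop_tokens). With alignments in a fixed slot, downstream diagnostics can unpack uniformly; a stub sketch with assumed shapes (alignments as [batch, decoder_steps, encoder_steps]):

    # Stub sketch with assumed shapes; the fixed return slot makes
    # alignment diagnostics uniform across all decoder entry points.
    import torch

    outputs = torch.randn(2, 80, 12)                        # hypothetical mel layout
    alignments = torch.softmax(torch.randn(2, 12, 50), -1)  # [batch, dec_steps, enc_steps]
    stop_tokens = torch.randn(2, 12)

    peaks = alignments.argmax(dim=-1)  # attended encoder position per decoder step
    assert peaks.shape == (2, 12)
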
@@ -66,7 +66,7 @@ class Tacotron2(nn.Module):
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
         encoder_outputs = self._add_speaker_embedding(encoder_outputs,
                                                       speaker_ids)
-        decoder_outputs, stop_tokens, alignments = self.decoder(
+        decoder_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask)
         postnet_outputs = self.postnet(decoder_outputs)
         postnet_outputs = decoder_outputs + postnet_outputs
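
At the model level, the unpacking above now matches what the decoder actually returns, and the postnet output is added back to the decoder output as a residual. A hedged end-to-end stub of that contract (stub tensors and an assumed mel layout, not the real modules):

    # Hypothetical stub of the post-fix model contract:
    # decoder(...) -> (decoder_outputs, alignments, stop_tokens).
    import torch

    def stub_decoder(encoder_outputs, mel_specs, mask):
        B, T_en, _ = encoder_outputs.shape
        T_de = mel_specs.shape[1]
        return (torch.zeros(B, T_de, 80),    # decoder_outputs (assumed layout)
                torch.zeros(B, T_de, T_en),  # alignments
                torch.zeros(B, T_de))        # stop_tokens

    decoder_outputs, alignments, stop_tokens = stub_decoder(
        torch.zeros(2, 50, 512), torch.zeros(2, 100, 80), None)
    postnet_outputs = decoder_outputs  # stand-in for self.postnet(decoder_outputs)
    postnet_outputs = decoder_outputs + postnet_outputs  # residual, as in the diff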