From 70c83671e68b0b73bf41605ed91b7134e1cba159 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 22 Jun 2020 14:54:23 +0200 Subject: [PATCH] init coarse decoder with argument list --- models/tacotron.py | 16 ++++++++++------ models/tacotron2.py | 5 ++++- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/models/tacotron.py b/models/tacotron.py index c526374a..ba42610c 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -46,11 +46,11 @@ class Tacotron(TacotronAbstract): self.embedding = nn.Embedding(num_chars, 256, padding_idx=0) self.embedding.weight.data.normal_(0, 0.3) self.encoder = Encoder(encoder_in_features) - self.decoder = Decoder(decoder_in_features, decoder_output_dim, r, memory_size, attn_type, attn_win, - attn_norm, prenet_type, prenet_dropout, - forward_attn, trans_agent, forward_attn_mask, - location_attn, attn_K, separate_stopnet, - proj_speaker_dim) + self.decoder = Decoder(decoder_in_features, decoder_output_dim, r, + memory_size, attn_type, attn_win, attn_norm, + prenet_type, prenet_dropout, forward_attn, + trans_agent, forward_attn_mask, location_attn, + attn_K, separate_stopnet, proj_speaker_dim) self.postnet = PostCBHG(decoder_output_dim) self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim) @@ -74,7 +74,11 @@ class Tacotron(TacotronAbstract): self._init_backward_decoder() # setup DDC if self.double_decoder_consistency: - self._init_coarse_decoder() + self.coarse_decoder = Decoder( + decoder_in_features, decoder_output_dim, ddc_r, memory_size, + attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, + forward_attn, trans_agent, forward_attn_mask, location_attn, + attn_K, separate_stopnet, proj_speaker_dim) def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None): diff --git a/models/tacotron2.py b/models/tacotron2.py index 46b915b5..7a56212d 100644 --- a/models/tacotron2.py +++ b/models/tacotron2.py @@ -68,7 +68,10 @@ class Tacotron2(TacotronAbstract): self._init_backward_decoder() # setup DDC if self.double_decoder_consistency: - self._init_coarse_decoder() + self.coarse_decoder = Decoder(decoder_in_features, self.decoder_output_dim, ddc_r, attn_type, attn_win, + attn_norm, prenet_type, prenet_dropout, + forward_attn, trans_agent, forward_attn_mask, + location_attn, attn_K, separate_stopnet, proj_speaker_dim) @staticmethod def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):