diff --git a/layers/losses.py b/layers/losses.py
index 6ccb3986..ab472519 100644
--- a/layers/losses.py
+++ b/layers/losses.py
@@ -61,7 +61,7 @@ class AttentionEntropyLoss(nn.Module):
     def forward(self, align):
         """
         Forces attention to be more decisive by penalizing
-        soft attention weights
+        soft attention weights
         TODO: arguments
         TODO: unit_test
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index ea55cbed..0d7472fd 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -162,15 +162,16 @@ class Decoder(nn.Module):
         B = inputs.size(0)
         # T = inputs.size(1)
         if not keep_states:
-            self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim)
-            self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B,
-                                                                                        self.query_dim)
-            self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B,
-                                                                              self.decoder_rnn_dim)
-            self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B,
-                                                                            self.decoder_rnn_dim)
-            self.context = torch.zeros(1, device=inputs.device).repeat(B,
-                                                                       self.encoder_embedding_dim)
+            self.query = torch.zeros(1, device=inputs.device).repeat(
+                B, self.query_dim)
+            self.attention_rnn_cell_state = torch.zeros(
+                1, device=inputs.device).repeat(B, self.query_dim)
+            self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(
+                B, self.decoder_rnn_dim)
+            self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(
+                B, self.decoder_rnn_dim)
+            self.context = torch.zeros(1, device=inputs.device).repeat(
+                B, self.encoder_embedding_dim)
         self.inputs = inputs
         self.processed_inputs = self.attention.inputs_layer(inputs)
         self.mask = mask
@@ -277,7 +278,7 @@ class Decoder(nn.Module):
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 break
-            elif len(outputs) == self.max_decoder_steps:
+            if len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
@@ -317,7 +318,7 @@ class Decoder(nn.Module):
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 break
-            elif len(outputs) == self.max_decoder_steps:
+            if len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
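
Note (not part of the patch): the reflowed initializers above all use the
torch.zeros(1, device=inputs.device).repeat(B, dim) pattern to build a
(B, dim) zero tensor on the input's device. A minimal sketch of why the
two-step form is equivalent to allocating the full shape directly; B,
query_dim, and the device below are illustrative values, not taken from
the patch:

    import torch

    B, query_dim = 4, 256
    device = torch.device("cpu")  # stands in for inputs.device

    # Two-step pattern used in the decoder state initialization above:
    # a shape-(1,) zero tensor repeated out to shape (B, query_dim).
    query = torch.zeros(1, device=device).repeat(B, query_dim)

    # Direct allocation yields an identical tensor.
    assert torch.equal(query, torch.zeros(B, query_dim, device=device))

The elif -> if changes in the two decoding loops are behavior-preserving:
the preceding branch ends in break, so the second condition is only ever
evaluated when the first was false, with or without elif.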