From 4b116a2a88acdbcf5ba7d1091c71eef8de27ca8f Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Wed, 6 Mar 2019 23:46:02 +0100
Subject: [PATCH] Look for the last two attention values for stop condition
 and attend to the first encoder vector if it is the first decoder iteration

---
 layers/tacotron2.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 296ea7ec..bff9cc99 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -105,7 +105,7 @@ class Attention(nn.Module):
         self.win_idx = None
 
     def init_win_idx(self):
-        self.win_idx = 0
+        self.win_idx = -1
 
     def get_attention(self, query, processed_inputs, attention_cat):
         processed_query = self.query_layer(query.unsqueeze(1))
@@ -132,6 +132,10 @@ class Attention(nn.Module):
             attention[:, :back_win] = -float("inf")
         if front_win < inputs.shape[1]:
             attention[:, front_win:] = -float("inf")
+        # Trick: on the first decoder iteration (win_idx == -1) force the
+        # attention to start at the first encoder step; harmless otherwise.
+        if self.win_idx == -1:
+            attention[:, 0] = attention.max()
         # Update the window
         self.win_idx = torch.argmax(attention, 1).long()[0].item()
         alignment = torch.sigmoid(attention) / torch.sigmoid(
@@ -355,7 +359,7 @@ class Decoder(nn.Module):
         alignments += [alignment]
 
         stop_flags[0] = stop_flags[0] or gate_output > 0.5
-        stop_flags[1] = stop_flags[1] or alignment[0, -3:].sum() > 0.5
+        stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
         if all(stop_flags):
             break
         elif len(outputs) == self.max_decoder_steps:
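
Note: for context, below is a minimal standalone sketch of the windowed-attention
behavior this patch changes. The function name window_attention and the window
sizes are illustrative assumptions, not the project's API. With win_idx == -1
(the new initial value), the first encoder step is pinned to the maximum score,
so decoding starts attending at the beginning of the input sequence.

    import torch

    def window_attention(attention, win_idx, win_back=3, win_front=6):
        # attention: raw scores of shape [batch, n_encoder_steps];
        # mask everything outside [win_idx - win_back, win_idx + win_front)
        back_win = win_idx - win_back
        front_win = win_idx + win_front
        if back_win > 0:
            attention[:, :back_win] = -float("inf")
        if front_win < attention.shape[1]:
            attention[:, front_win:] = -float("inf")
        if win_idx == -1:
            # first decoder iteration: attend to the first encoder vector
            attention[:, 0] = attention.max()
        # update the window position for the next decoder step
        win_idx = torch.argmax(attention, 1).long()[0].item()
        return attention, win_idx

    scores = torch.randn(1, 20)
    scores, win_idx = window_attention(scores, win_idx=-1)
    # argmax ties break toward the first index in current PyTorch, so the
    # window now starts at the first encoder step
    assert win_idx == 0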
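Similarly, a sketch of the relaxed stop heuristic; should_stop, the scalar
gate_output, and the toy alignment tensor are assumptions for illustration.
Each flag latches once its condition has been met at any step, and decoding
halts only when both have fired: the stop gate exceeded 0.5, and the attention
mass on the last two encoder steps (previously the last three) exceeded 0.5.

    import torch

    def should_stop(gate_output, alignment, stop_flags):
        # each flag latches once its condition has held at any step
        stop_flags[0] = stop_flags[0] or gate_output > 0.5
        stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
        return all(stop_flags)

    flags = [False, False]
    alignment = torch.tensor([[0.05, 0.05, 0.3, 0.6]])  # mass near the end
    print(should_stop(gate_output=0.8, alignment=alignment, stop_flags=flags))
    # -> True: the gate fired and the last two steps sum to 0.9 > 0.5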