From ed1f648b83d11d8bf2305075194368dff304e338 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Sun, 13 Jan 2019 19:09:44 +0100
Subject: [PATCH] Enable attention windowing and make it configurable at model level.

---
 layers/attention.py | 21 ++++++++++++++++++++-
 layers/tacotron.py  |  5 +++--
 models/tacotron.py  |  5 +++--
 3 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/layers/attention.py b/layers/attention.py
index ea31768a..83ccb504 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -100,7 +100,7 @@ class LocationSensitiveAttention(nn.Module):
 
 
 class AttentionRNNCell(nn.Module):
-    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model):
+    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False):
         r"""
         General Attention RNN wrapper
 
@@ -110,10 +110,17 @@ class AttentionRNNCell(nn.Module):
             annot_dim (int): annotation vector feature dimension.
             memory_dim (int): memory vector (decoder output) feature dimension.
             align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
+            windowing (bool): attention windowing forcing monotonic attention.
+                It is only active in eval mode.
         """
         super(AttentionRNNCell, self).__init__()
         self.align_model = align_model
         self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
+        self.windowing = windowing
+        if self.windowing:
+            self.win_back = 1
+            self.win_front = 3
+            self.win_idx = None
         # pick bahdanau or location sensitive attention
         if align_model == 'b':
             self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
@@ -138,6 +145,7 @@ class AttentionRNNCell(nn.Module):
         """
         if t == 0:
             self.alignment_model.reset()
+            self.win_idx = 0
         # Feed it to RNN
         # s_i = f(y_{i-1}, c_{i}, s_{i-1})
         rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
@@ -151,6 +159,17 @@ class AttentionRNNCell(nn.Module):
         if mask is not None:
             mask = mask.view(memory.size(0), -1)
             alignment.masked_fill_(1 - mask, -float("inf"))
+        # Windowing
+        if not self.training:
+            # print(" > Windowing active")
+            back_win = self.win_idx - self.win_back
+            front_win = self.win_idx + self.win_front
+            if back_win > 0:
+                alignment[:, :back_win] = -float("inf")
+            if front_win < memory.shape[1]:
+                alignment[:, front_win:] = -float("inf")
+            # Update the window
+            self.win_idx = torch.argmax(alignment,1).long()[0].item()
         # Normalize context weight
         # alignment = F.softmax(alignment, dim=-1)
         # alignment = 5 * alignment
diff --git a/layers/tacotron.py b/layers/tacotron.py
index 7b159c67..f95b1bc1 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -304,7 +304,7 @@ class Decoder(nn.Module):
         r (int): number of outputs per time step.
     """
 
-    def __init__(self, in_features, memory_dim, r):
+    def __init__(self, in_features, memory_dim, r, attn_windowing):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -318,7 +318,8 @@ class Decoder(nn.Module):
             rnn_dim=256,
             annot_dim=in_features,
             memory_dim=128,
-            align_model='ls')
+            align_model='ls',
+            windowing=attn_windowing)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
diff --git a/models/tacotron.py b/models/tacotron.py
index 5eb7dfac..844f69b3 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -12,7 +12,8 @@ class Tacotron(nn.Module):
                  linear_dim=1025,
                  mel_dim=80,
                  r=5,
-                 padding_idx=None):
+                 padding_idx=None,
+                 attn_windowing=False):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
@@ -24,7 +25,7 @@ class Tacotron(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(embedding_dim)
-        self.decoder = Decoder(256, mel_dim, r)
+        self.decoder = Decoder(256, mel_dim, r, attn_windowing)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),