From 4e6596a8e105345caf837b29467f71ce225d5576 Mon Sep 17 00:00:00 2001
From: Eren G
Date: Tue, 17 Jul 2018 17:01:40 +0200
Subject: [PATCH] Loc sens attention

---
 layers/attention.py | 10 +++++-----
 layers/tacotron.py  |  3 ++-
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/layers/attention.py b/layers/attention.py
index efc0e3a6..5e468cfb 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -69,24 +69,25 @@ class LocationSensitiveAttention(nn.Module):
 
 
 class AttentionRNNCell(nn.Module):
-    def __init__(self, out_dim, annot_dim, memory_dim, align_model):
+    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model):
         r"""
         General Attention RNN wrapper
 
         Args:
             out_dim (int): context vector feature dimension.
+            rnn_dim (int): rnn hidden state dimension.
             annot_dim (int): annotation vector feature dimension.
             memory_dim (int): memory vector (decoder autogression) feature dimension.
             align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
         """
         super(AttentionRNNCell, self).__init__()
         self.align_model = align_model
-        self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
+        self.rnn_cell = nn.GRUCell(out_dim + memory_dim, rnn_dim)
         # pick bahdanau or location sensitive attention
         if align_model == 'b':
             self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
         if align_model == 'ls':
-            self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
+            self.alignment_model = LocationSensitiveAttention(annot_dim, rnn_dim, out_dim)
         else:
             raise RuntimeError(" Wrong alignment model name: {}. Use\
                 'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
@@ -100,11 +101,10 @@ class AttentionRNNCell(nn.Module):
             - context: (batch, dim)
             - rnn_state: (batch, out_dim)
             - annots: (batch, max_time, annot_dim)
-            - atten: (batch, max_time)
+            - atten: (batch, 2, max_time)
             - annot_lens: (batch,)
         """
         # Concat input query and previous context context
-        print(context.shape)
         rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
         # s_i = f(y_{i-1}, c_{i}, s_{i-1})
diff --git a/layers/tacotron.py b/layers/tacotron.py
index 8ac5927d..e021cd07 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -203,7 +203,8 @@ class Decoder(nn.Module):
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.attention_rnn = AttentionRNNCell(128, in_features, 128, align_model='ls')
+        self.attention_rnn = AttentionRNNCell(out_dim=128, rnn_dim=256, annot_dim=in_features,
+                                              memory_dim=128, align_model='ls')
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256+in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
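
Note for reviewers: below is a minimal standalone sketch of the dimension wiring this patch introduces (the attention GRU hidden state becomes rnn_dim=256 while the context vector stays out_dim=128). The tensor shapes and the single GRU step are assumptions reconstructed only from the patched lines shown above, not from the rest of the repo; the variable names are illustrative.

    import torch
    import torch.nn as nn

    # Dimensions as wired in Decoder.__init__ after this patch (assumed values).
    out_dim, rnn_dim, memory_dim = 128, 256, 128

    # The attention GRU hidden state is now rnn_dim-sized instead of out_dim-sized.
    rnn_cell = nn.GRUCell(out_dim + memory_dim, rnn_dim)

    batch = 4
    memory = torch.zeros(batch, memory_dim)     # processed memory (prenet output) for y_{i-1}
    context = torch.zeros(batch, out_dim)       # previous attention context c_{i-1}
    rnn_state = torch.zeros(batch, rnn_dim)     # previous attention RNN state s_{i-1}

    # s_i = f(y_{i-1}, c_{i}, s_{i-1}), mirroring the concat-then-GRU step in AttentionRNNCell
    rnn_input = torch.cat((memory, context), -1)
    rnn_state = rnn_cell(rnn_input, rnn_state)
    assert rnn_state.shape == (batch, rnn_dim)  # (4, 256)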