From ddaf414434f40b7375f80070751d4ce8466b9472 Mon Sep 17 00:00:00 2001
From: Eren G
Date: Tue, 17 Jul 2018 16:24:39 +0200
Subject: [PATCH] attention update

---
 layers/attention.py | 21 +++++++++++----------
 layers/tacotron.py  |  4 ++--
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/layers/attention.py b/layers/attention.py
index 5c6c3f02..efc0e3a6 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -5,11 +5,11 @@ from utils.generic_utils import sequence_mask
 
 
 class BahdanauAttention(nn.Module):
-    def __init__(self, annot_dim, query_dim, hidden_dim):
+    def __init__(self, annot_dim, query_dim, attn_dim):
         super(BahdanauAttention, self).__init__()
-        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
-        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
-        self.v = nn.Linear(hidden_dim, 1, bias=False)
+        self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
+        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
+        self.v = nn.Linear(attn_dim, 1, bias=False)
 
     def forward(self, annots, query):
         """
@@ -33,8 +33,8 @@ class BahdanauAttention(nn.Module):
 class LocationSensitiveAttention(nn.Module):
     """Location sensitive attention following
     https://arxiv.org/pdf/1506.07503.pdf"""
-    def __init__(self, annot_dim, query_dim, hidden_dim,
-                 kernel_size=7, filters=20):
+    def __init__(self, annot_dim, query_dim, attn_dim,
+                 kernel_size=31, filters=32):
         super(LocationSensitiveAttention, self).__init__()
         self.kernel_size = kernel_size
         self.filters = filters
@@ -42,10 +42,10 @@ class LocationSensitiveAttention(nn.Module):
         self.loc_conv = nn.Conv1d(1, filters,
                                   kernel_size=kernel_size, stride=1,
                                   padding=padding, bias=False)
-        self.loc_linear = nn.Linear(filters, hidden_dim)
-        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
-        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
-        self.v = nn.Linear(hidden_dim, 1, bias=False)
+        self.loc_linear = nn.Linear(filters, attn_dim)
+        self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
+        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
+        self.v = nn.Linear(attn_dim, 1, bias=False)
 
     def forward(self, annot, query, loc):
         """
@@ -104,6 +104,7 @@ class AttentionRNNCell(nn.Module):
             - annot_lens: (batch,)
         """
         # Concat input query and previous context context
+        print(context.shape)
         rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
         # s_i = f(y_{i-1}, c_{i}, s_{i-1})
diff --git a/layers/tacotron.py b/layers/tacotron.py
index f6ab85e7..8ac5927d 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -203,7 +203,7 @@ class Decoder(nn.Module):
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.attention_rnn = AttentionRNNCell(256, in_features, 128, align_model='ls')
+        self.attention_rnn = AttentionRNNCell(128, in_features, 128, align_model='ls')
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256+in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
@@ -248,7 +248,7 @@ class Decoder(nn.Module):
         attention_rnn_hidden = inputs.data.new(B, 256).zero_()
         decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
                                for _ in range(len(self.decoder_rnns))]
-        current_context_vec = inputs.data.new(B, 256).zero_()
+        current_context_vec = inputs.data.new(B, 128).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
         # attention states
         attention = inputs.data.new(B, T).zero_()
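
Usage note (not part of the commit): a minimal sketch of how the renamed
constructors are called after this patch. The dimensions below are
illustrative assumptions rather than the sizes the repository's Decoder
actually passes, and the snippet assumes it is run from the repository root
so that layers.attention is importable.

    # Illustrative dimensions only; the constructor signatures come from the diff above.
    from layers.attention import BahdanauAttention, LocationSensitiveAttention

    # "hidden_dim" is now named "attn_dim" in both constructors.
    bahdanau_attn = BahdanauAttention(annot_dim=256, query_dim=256, attn_dim=128)

    # The location-convolution defaults are now kernel_size=31 and filters=32
    # (previously kernel_size=7 and filters=20).
    location_attn = LocationSensitiveAttention(annot_dim=256, query_dim=256,
                                               attn_dim=128)
    print(location_attn.loc_conv)    # Conv1d(1, 32, kernel_size=(31,), ...)
    print(location_attn.loc_linear)  # Linear(in_features=32, out_features=128, bias=True)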