diff --git a/layers/attention.py b/layers/attention.py
index 4326a712..2a0cec3d 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -69,10 +69,18 @@ class LocationSensitiveAttention(nn.Module):
 
 
 class AttentionRNN(nn.Module):
-    def __init__(self, out_dim, annot_dim, memory_dim):
+    def __init__(self, out_dim, annot_dim, memory_dim, align_model):
         super(AttentionRNN, self).__init__()
         self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
-        self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
+        # pick Bahdanau or location-sensitive attention
+        if align_model == 'b':
+            self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
+        elif align_model == 'ls':
+            self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
+        else:
+            raise RuntimeError("Wrong alignment model name: {}. Use "
+                               "'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
+
 
     def forward(self, memory, context, rnn_state, annotations,
                 attention_vec, mask=None, annotations_lengths=None):
@@ -88,7 +96,10 @@ class AttentionRNN(nn.Module):
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
-        alignment = self.alignment_model(annotations, rnn_output, attention_vec)
+        if attention_vec is None:
+            alignment = self.alignment_model(annotations, rnn_output)
+        else:
+            alignment = self.alignment_model(annotations, rnn_output, attention_vec)
 
         # TODO: needs recheck.
         if mask is not None:
diff --git a/layers/tacotron.py b/layers/tacotron.py
index 809b78ae..4b4326fc 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -212,12 +212,10 @@ class Decoder(nn.Module):
         eps (float): threshold for detecting the end of a sentence.
     """
 
-    def __init__(self, in_features, memory_dim, r, eps=0, mode='train'):
+    def __init__(self, in_features, memory_dim, r):
         super(Decoder, self).__init__()
-        self.mode = mode
         self.max_decoder_steps = 200
         self.memory_dim = memory_dim
-        self.eps = eps
         self.r = r
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
@@ -268,8 +266,8 @@ class Decoder(nn.Module):
                                for _ in range(len(self.decoder_rnns))]
         current_context_vec = inputs.data.new(B, 256).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
-        attention_vec = inputs.data.new(B, T).zero_()
-        attention_vec_cum = inputs.data.new(B, T).zero_()
+        attention = inputs.data.new(B, T).zero_()
+        attention_cum = inputs.data.new(B, T).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
@@ -287,12 +285,12 @@ class Decoder(nn.Module):
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
-            attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
-                                           attention_vec_cum.unsqueeze(1) / (t + 1)),
-                                          dim=1)
+            attention_cat = torch.cat((attention.unsqueeze(1),
+                                       attention_cum.unsqueeze(1) / (t + 1)),
+                                      dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_vec_cat)
-            attention_vec_cum += attention_vec
+                processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_cat)
+            attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
                 torch.cat((attention_rnn_hidden, current_context_vec), -1))
@@ -333,6 +331,7 @@ class Decoder(nn.Module):
 
 class StopNet(nn.Module):
     def __init__(self, r, memory_dim):
+        """Predicts the stop token that ends decoding at inference time."""
         super(StopNet, self).__init__()
         self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
         self.relu = nn.ReLU()
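
For context, here is a minimal usage sketch of the new align_model switch in AttentionRNN. It is not part of the patch: the tensor sizes (B, T, out_dim, annot_dim, memory_dim) are illustrative assumptions rather than values from the project config, and it assumes BahdanauAttention, LocationSensitiveAttention, and the forward signature shown above are available from layers/attention.py as in the repo.

# Usage sketch for the align_model switch in AttentionRNN (illustrative sizes only).
import torch

from layers.attention import AttentionRNN

B, T = 2, 50                                  # batch size and encoder time steps (made up)
out_dim, annot_dim, memory_dim = 256, 256, 128  # assumed dimensions, not repo config values

# Location-sensitive attention ('ls') also consumes the concatenated attention weights
# (current and cumulative), shaped (B, 2, T) as built in Decoder.forward above.
attn_rnn_ls = AttentionRNN(out_dim, annot_dim, memory_dim, align_model='ls')

memory = torch.zeros(B, memory_dim)           # prenet output for the current decoder step
context = torch.zeros(B, annot_dim)           # previous context vector
rnn_state = torch.zeros(B, out_dim)           # previous attention-RNN hidden state
annotations = torch.zeros(B, T, annot_dim)    # encoder outputs
attention_cat = torch.cat((torch.zeros(B, 1, T),      # current attention weights
                           torch.zeros(B, 1, T)),     # cumulative attention / (t + 1)
                          dim=1)

rnn_state, context, alignment = attn_rnn_ls(
    memory, context, rnn_state, annotations, attention_cat)

# Bahdanau attention ('b') has no location features; pass attention_vec=None so
# forward() calls the alignment model with only (annotations, rnn_output).
attn_rnn_b = AttentionRNN(out_dim, annot_dim, memory_dim, align_model='b')
rnn_state, context, alignment = attn_rnn_b(
    memory, context, rnn_state, annotations, None)

# Any other align_model value raises RuntimeError("Wrong alignment model name: ...").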