diff --git a/layers/tacotron.py b/layers/tacotron.py
index 2e945844..496a86c8 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -223,7 +223,7 @@ class Decoder(nn.Module):
         self.r = r
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
-        # processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
+        # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = AttentionRNN(256, in_features, 128)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256+in_features, 256)
@@ -252,6 +252,7 @@ class Decoder(nn.Module):
            - memory: batch x #mels_pecs x mel_spec_dim
         """
         B = inputs.size(0)
+        T = inputs.size(1)
         # Run greedy decoding if memory is None
         greedy = not self.training
         if memory is not None:
@@ -269,11 +270,13 @@ class Decoder(nn.Module):
                                for _ in range(len(self.decoder_rnns))]
         current_context_vec = inputs.data.new(B, 256).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
+        attention_vec = memory.data.new(B, T).zero_()
+        attention_vec_cum = memory.data.new(B, T).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
         outputs = []
-        alignments = []
+        attentions = []
         stop_tokens = []
         t = 0
         memory_input = initial_memory
@@ -286,8 +289,13 @@ class Decoder(nn.Module):
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
-            attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden, inputs)
+            attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
+                                           attention_vec_cum.unsqueeze(1)),
+                                          dim=1)
+            attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
+                processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_vec_cat)
+            attention_vec_cum += attention_vec
+            attention_vec_cum /= (t + 1)
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
                 torch.cat((attention_rnn_hidden, current_context_vec), -1))
@@ -304,14 +312,14 @@ class Decoder(nn.Module):
             # predict stop token
             stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
             outputs += [output]
-            alignments += [alignment]
+            attentions += [attention]
             stop_tokens += [stop_token]
             t += 1
             if (not greedy and self.training) or (greedy and memory is not None):
                 if t >= T_decoder:
                     break
             else:
-                if t > inputs.shape[1]/2 and stop_token > 0.8:
+                if t > inputs.shape[1]/2 and stop_token > 0.6:
                     break
                 elif t > self.max_decoder_steps:
                     print(" !! Decoder stopped with 'max_decoder_steps'. \
@@ -319,10 +327,10 @@ class Decoder(nn.Module):
                     break
         assert greedy or len(outputs) == T_decoder
         # Back to batch first
-        alignments = torch.stack(alignments).transpose(0, 1)
+        attentions = torch.stack(attentions).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
-        return outputs, alignments, stop_tokens
+        return outputs, attentions, stop_tokens
 
 
 class StopNet(nn.Module):
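
Note on the change: the new fifth argument to self.attention_rnn, attention_vec_cat, stacks the previous attention weights and their running accumulation into a (B, 2, T) tensor, which is the usual input to location-sensitive attention. The sketch below is only a minimal illustration, assuming the attention module consumes this history through a 1-D convolution over time (Tacotron 2 / Chorowski-style location attention); it is not the repo's AttentionRNN, and the names LocationAttention, loc_conv, and loc_proj are hypothetical.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class LocationAttention(nn.Module):
        """Additive attention with location features built from the (B, 2, T)
        previous/cumulative attention stack (hypothetical sketch)."""

        def __init__(self, query_dim, annot_dim, attn_dim, filters=32, kernel_size=31):
            super().__init__()
            # 1-D conv over the time axis extracts "location" features
            # from the previous/cumulative attention weights.
            self.loc_conv = nn.Conv1d(2, filters, kernel_size,
                                      padding=(kernel_size - 1) // 2, bias=False)
            self.loc_proj = nn.Linear(filters, attn_dim, bias=False)
            self.query_proj = nn.Linear(query_dim, attn_dim, bias=False)
            self.annot_proj = nn.Linear(annot_dim, attn_dim, bias=False)
            self.v = nn.Linear(attn_dim, 1, bias=False)

        def forward(self, query, annots, attention_cat):
            # query:         (B, query_dim)     attention-RNN state
            # annots:        (B, T, annot_dim)  encoder outputs ("inputs" above)
            # attention_cat: (B, 2, T)          previous + cumulative attention
            loc_feats = self.loc_proj(self.loc_conv(attention_cat).transpose(1, 2))
            scores = self.v(torch.tanh(self.query_proj(query).unsqueeze(1)
                                       + self.annot_proj(annots)
                                       + loc_feats)).squeeze(-1)      # (B, T)
            alignment = F.softmax(scores, dim=-1)                     # (B, T)
            # Weighted sum of encoder outputs -> context vector (B, annot_dim)
            context = torch.bmm(alignment.unsqueeze(1), annots).squeeze(1)
            return context, alignment

    # Example with the sizes used by the decoder above (assumed values):
    # attn = LocationAttention(query_dim=256, annot_dim=256, attn_dim=128)
    # context, alignment = attn(attention_rnn_hidden, inputs, attention_vec_cat)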