diff --git a/layers/common_layers.py b/layers/common_layers.py
index 2edf0dab..77ce4f4a 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -108,19 +108,19 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
-    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+    def __init__(self, query_dim, embedding_dim, attention_dim,
                  location_attention, attention_location_n_filters,
                  attention_location_kernel_size, windowing, norm, forward_attn,
                  trans_agent, forward_attn_mask):
         super(Attention, self).__init__()
         self.query_layer = Linear(
-            attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
+            query_dim, attention_dim, bias=False, init_gain='tanh')
         self.inputs_layer = Linear(
             embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
         if trans_agent:
             self.ta = nn.Linear(
-                attention_rnn_dim + embedding_dim, 1, bias=True)
+                query_dim + embedding_dim, 1, bias=True)
         if location_attention:
             self.location_layer = LocationLayer(
                 attention_dim,
@@ -203,11 +203,12 @@ class Attention(nn.Module):
     def apply_forward_attention(self, inputs, alignment, query):
         # forward attention
-        prev_alpha = F.pad(self.alpha[:, :-1].clone(),
-                           (1, 0, 0, 0)).to(inputs.device)
+        prev_alpha = F.pad(self.alpha[:, :-1].clone().to(inputs.device),
+                           (1, 0, 0, 0))
         # compute transition potentials
-        alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
-                  self.u * prev_alpha) + 1e-8) * alignment
+        alpha = ((1 - self.u) * self.alpha
+                 + self.u * prev_alpha
+                 + 1e-8) * alignment
         # force incremental alignment
         if not self.training and self.forward_attn_mask:
             _, n = prev_alpha.max(1)
@@ -231,19 +232,20 @@ class Attention(nn.Module):
             self.u = torch.sigmoid(self.ta(ta_input))
         return context, self.alpha

-    def forward(self, attention_hidden_state, inputs, processed_inputs, mask):
+    def forward(self, query, inputs, processed_inputs, mask):
         if self.location_attention:
             attention, processed_query = self.get_location_attention(
-                attention_hidden_state, processed_inputs)
+                query, processed_inputs)
         else:
             attention, processed_query = self.get_attention(
-                attention_hidden_state, processed_inputs)
+                query, processed_inputs)
         # apply masking
         if mask is not None:
             attention.data.masked_fill_(1 - mask, self._mask_value)
         # apply windowing - only in eval mode
         if not self.training and self.windowing:
             attention = self.apply_windowing(attention, inputs)
+
         # normalize attention values
         if self.norm == "softmax":
             alignment = torch.softmax(attention, dim=-1)
@@ -258,7 +260,7 @@ class Attention(nn.Module):
         # apply forward attention if enabled
         if self.forward_attn:
             context, self.attention_weights = self.apply_forward_attention(
-                inputs, alignment, attention_hidden_state)
+                inputs, alignment, query)
         else:
             context = torch.bmm(alignment.unsqueeze(1), inputs)
             context = context.squeeze(1)
diff --git a/layers/tacotron.py b/layers/tacotron.py
index b71ddbc3..068ae7cc 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -283,6 +283,7 @@ class Decoder(nn.Module):
         self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
         self.separate_stopnet = separate_stopnet
+        self.query_dim = 256
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(
             memory_dim * self.memory_size,
             prenet_type,
             prenet_dropout,
             out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.attention_rnn = nn.GRUCell(in_features + 128, 256)
-        self.attention_layer = Attention(attention_rnn_dim=256,
-                                         embedding_dim=in_features,
-                                         attention_dim=128,
-                                         location_attention=location_attn,
-                                         attention_location_n_filters=32,
-                                         attention_location_kernel_size=31,
-                                         windowing=attn_windowing,
-                                         norm=attn_norm,
-                                         forward_attn=forward_attn,
-                                         trans_agent=trans_agent,
-                                         forward_attn_mask=forward_attn_mask)
+        self.query_rnn = nn.GRUCell(in_features + 128, self.query_dim)
+        self.attention = Attention(query_dim=self.query_dim,
+                                   embedding_dim=in_features,
+                                   attention_dim=128,
+                                   location_attention=location_attn,
+                                   attention_location_n_filters=32,
+                                   attention_location_kernel_size=31,
+                                   windowing=attn_windowing,
+                                   norm=attn_norm,
+                                   forward_attn=forward_attn,
+                                   trans_agent=trans_agent,
+                                   forward_attn_mask=forward_attn_mask)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
@@ -310,7 +311,7 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         # learn init values instead of zero init.
-        self.attention_rnn_init = nn.Embedding(1, 256)
+        self.query_rnn_init = nn.Embedding(1, 256)
         self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
         self.decoder_rnn_inits = nn.Embedding(2, 256)
         self.stopnet = StopNet(256 + memory_dim * r)
@@ -347,18 +348,18 @@ class Decoder(nn.Module):
         self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())

         # decoder states
-        self.attention_rnn_hidden = self.attention_rnn_init(
+        self.query = self.query_rnn_init(
             inputs.data.new_zeros(B).long())
         self.decoder_rnn_hiddens = [
             self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
             for idx in range(len(self.decoder_rnns))
         ]
-        self.current_context_vec = inputs.data.new(B, self.in_features).zero_()
+        self.context_vec = inputs.data.new(B, self.in_features).zero_()
         # attention states
         self.attention = inputs.data.new(B, T).zero_()
         self.attention_cum = inputs.data.new(B, T).zero_()
         # cache attention inputs
-        self.processed_inputs = self.attention_layer.inputs_layer(inputs)
+        self.processed_inputs = self.attention.inputs_layer(inputs)

     def _parse_outputs(self, outputs, attentions, stop_tokens):
         # Back to batch first
@@ -370,13 +371,15 @@ class Decoder(nn.Module):
     def decode(self, inputs, mask=None):
         # Prenet
         processed_memory = self.prenet(self.memory_input)
+
         # Attention RNN
-        self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden)
-        self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
-        # Concat RNN output and attention context vector
+        self.query = self.query_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
+        self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask)
+
+        # Concat query and attention context vector
         decoder_input = self.project_to_decoder_in(
-            torch.cat((self.attention_rnn_hidden, self.current_context_vec),
-                      -1))
+            torch.cat((self.query, self.context_vec), -1))
+
         # Pass through the decoder RNNs
         for idx in range(len(self.decoder_rnns)):
             self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
@@ -384,18 +387,18 @@ class Decoder(nn.Module):
                 decoder_input, self.decoder_rnn_hiddens[idx])
             # Residual connection
             decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
         decoder_output = decoder_input
-        del decoder_input
+
         # predict mel vectors from decoder vectors
         output = self.proj_to_mel(decoder_output)
         output = torch.sigmoid(output)
+
         # predict stop token
         stopnet_input = torch.cat([decoder_output, output], -1)
-        del decoder_output
         if self.separate_stopnet:
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        return output, stop_token, self.attention_layer.attention_weights
+        return output, stop_token, self.attention.attention_weights

     def _update_memory_queue(self, new_memory):
         if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
@@ -427,7 +430,7 @@ class Decoder(nn.Module):
         stop_tokens = []
         t = 0
         self._init_states(inputs)
-        self.attention_layer.init_states(inputs)
+        self.attention.init_states(inputs)
         while len(outputs) < memory.size(0):
             if t > 0:
                 new_memory = memory[t - 1]
@@ -453,8 +456,8 @@ class Decoder(nn.Module):
         stop_tokens = []
         t = 0
         self._init_states(inputs)
-        self.attention_layer.init_win_idx()
-        self.attention_layer.init_states(inputs)
+        self.attention.init_win_idx()
+        self.attention.init_states(inputs)
         while True:
             if t > 0:
                 new_memory = outputs[-1]
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 802f158e..ba52abe2 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -104,7 +104,7 @@ class Decoder(nn.Module):
         self.r = r
         self.encoder_embedding_dim = in_features
         self.separate_stopnet = separate_stopnet
-        self.attention_rnn_dim = 1024
+        self.query_dim = 1024
         self.decoder_rnn_dim = 1024
         self.prenet_dim = 256
         self.max_decoder_steps = 1000
@@ -116,22 +116,22 @@ class Decoder(nn.Module):
                              prenet_dropout, [self.prenet_dim, self.prenet_dim],
                              bias=False)

-        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
-                                         self.attention_rnn_dim)
+        self.query_rnn = nn.LSTMCell(self.prenet_dim + in_features,
+                                     self.query_dim)

-        self.attention_layer = Attention(attention_rnn_dim=self.attention_rnn_dim,
-                                         embedding_dim=in_features,
-                                         attention_dim=128,
-                                         location_attention=location_attn,
-                                         attention_location_n_filters=32,
-                                         attention_location_kernel_size=31,
-                                         windowing=attn_win,
-                                         norm=attn_norm,
-                                         forward_attn=forward_attn,
-                                         trans_agent=trans_agent,
-                                         forward_attn_mask=forward_attn_mask)
+        self.attention = Attention(query_dim=self.query_dim,
+                                   embedding_dim=in_features,
+                                   attention_dim=128,
+                                   location_attention=location_attn,
+                                   attention_location_n_filters=32,
+                                   attention_location_kernel_size=31,
+                                   windowing=attn_win,
+                                   norm=attn_norm,
+                                   forward_attn=forward_attn,
+                                   trans_agent=trans_agent,
+                                   forward_attn_mask=forward_attn_mask)

-        self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
+        self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features,
                                        self.decoder_rnn_dim, 1)

         self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
                    bias=True,
                    init_gain='sigmoid'))

-        self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
+        self.query_rnn_init = nn.Embedding(1, self.query_dim)
         self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
         self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
         self.memory_truncated = None
@@ -160,10 +160,10 @@ class Decoder(nn.Module):
         # T = inputs.size(1)
         if not keep_states:
-            self.attention_hidden = self.attention_rnn_init(
+            self.query = self.query_rnn_init(
                 inputs.data.new_zeros(B).long())
-            self.attention_cell = Variable(
-                inputs.data.new(B, self.attention_rnn_dim).zero_())
+            self.query_rnn_cell_state = Variable(
+                inputs.data.new(B, self.query_dim).zero_())
             self.decoder_hidden = self.decoder_rnn_inits(
                 inputs.data.new_zeros(B).long())
@@ -174,7 +174,7 @@ class Decoder(nn.Module):
             inputs.data.new(B, self.encoder_embedding_dim).zero_())
         self.inputs = inputs
-        self.processed_inputs = self.attention_layer.inputs_layer(inputs)
+        self.processed_inputs = self.attention.inputs_layer(inputs)
         self.mask = mask

     def _reshape_memory(self, memories):
@@ -193,18 +193,18 @@ class Decoder(nn.Module):
         return outputs, stop_tokens, alignments

     def decode(self, memory):
-        cell_input = torch.cat((memory, self.context), -1)
-        self.attention_hidden, self.attention_cell = self.attention_rnn(
-            cell_input, (self.attention_hidden, self.attention_cell))
-        self.attention_hidden = F.dropout(
-            self.attention_hidden, self.p_attention_dropout, self.training)
-        self.attention_cell = F.dropout(
-            self.attention_cell, self.p_attention_dropout, self.training)
+        query_input = torch.cat((memory, self.context), -1)
+        self.query, self.query_rnn_cell_state = self.query_rnn(
+            query_input, (self.query, self.query_rnn_cell_state))
+        self.query = F.dropout(
+            self.query, self.p_attention_dropout, self.training)
+        self.query_rnn_cell_state = F.dropout(
+            self.query_rnn_cell_state, self.p_attention_dropout, self.training)

-        self.context = self.attention_layer(self.attention_hidden, self.inputs,
-                                            self.processed_inputs, self.mask)
+        self.context = self.attention(self.query, self.inputs,
+                                      self.processed_inputs, self.mask)

-        memory = torch.cat((self.attention_hidden, self.context), -1)
+        memory = torch.cat((self.query, self.context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
             memory, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(self.decoder_hidden,
@@ -223,7 +223,7 @@ class Decoder(nn.Module):
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        return decoder_output, stop_token, self.attention_layer.attention_weights
+        return decoder_output, stop_token, self.attention.attention_weights

     def forward(self, inputs, memories, mask):
         memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -232,7 +232,7 @@ class Decoder(nn.Module):
         memories = self.prenet(memories)

         self._init_states(inputs, mask=mask)
-        self.attention_layer.init_states(inputs)
+        self.attention.init_states(inputs)

         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
@@ -251,8 +251,8 @@ class Decoder(nn.Module):
         memory = self.get_go_frame(inputs)
         self._init_states(inputs, mask=None)

-        self.attention_layer.init_win_idx()
-        self.attention_layer.init_states(inputs)
+        self.attention.init_win_idx()
+        self.attention.init_states(inputs)

         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [True, False, False]
@@ -295,8 +295,8 @@ class Decoder(nn.Module):
         else:
             self._init_states(inputs, mask=None, keep_states=True)

-        self.attention_layer.init_win_idx()
-        self.attention_layer.init_states(inputs)
+        self.attention.init_win_idx()
+        self.attention.init_states(inputs)
         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [True, False, False]
         stop_count = 0
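
For reference, the renamed module keeps the call pattern both decoders rely on in this patch: construct it with query_dim instead of attention_rnn_dim, cache inputs_layer(inputs) once per utterance, call init_states(inputs) before decoding starts, then query it once per decoder step and read attention_weights afterwards. A minimal usage sketch follows; the dummy tensor shapes and the specific constructor arguments are assumptions taken from the decoder code above, not part of the change itself:

    # Illustrative sketch of the renamed Attention API -- not part of the patch.
    import torch
    from layers.common_layers import Attention

    B, T, embedding_dim, query_dim = 2, 50, 256, 256  # assumed toy sizes
    attn = Attention(query_dim=query_dim, embedding_dim=embedding_dim,
                     attention_dim=128, location_attention=True,
                     attention_location_n_filters=32,
                     attention_location_kernel_size=31, windowing=False,
                     norm="softmax", forward_attn=False, trans_agent=False,
                     forward_attn_mask=False)

    inputs = torch.rand(B, T, embedding_dim)      # stand-in for encoder outputs
    processed_inputs = attn.inputs_layer(inputs)  # cached once per utterance
    attn.init_states(inputs)                      # reset attention state before decoding

    query = torch.rand(B, query_dim)              # stand-in for the query_rnn hidden state
    context = attn(query, inputs, processed_inputs, mask=None)
    print(context.shape)                          # [B, embedding_dim] context vector
    print(attn.attention_weights.shape)           # per-step alignment, as returned by decode()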