From e2439fde9abfc36b823f4d88b093180b20e7d42d Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Mon, 29 Apr 2019 11:37:01 +0200
Subject: [PATCH] make location attention optional and keep all attention
 weights in attention class

---
 layers/tacotron2.py    | 95 ++++++++++++++++++++++++++----------------
 models/tacotron2.py    |  4 +-
 utils/generic_utils.py |  3 +-
 3 files changed, 62 insertions(+), 40 deletions(-)

diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index df05e5ad..2fa6d06f 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -120,7 +120,7 @@ class LocationLayer(nn.Module):
 
 
 class Attention(nn.Module):
-    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, location_attention,
                  attention_location_n_filters, attention_location_kernel_size,
                  windowing, norm, forward_attn, trans_agent):
         super(Attention, self).__init__()
@@ -131,37 +131,64 @@ class Attention(nn.Module):
         self.v = Linear(attention_dim, 1, bias=True)
         if trans_agent:
             self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)
-        self.location_layer = LocationLayer(attention_location_n_filters,
-                                            attention_location_kernel_size,
-                                            attention_dim)
+        if location_attention:
+            self.location_layer = LocationLayer(attention_location_n_filters,
+                                                attention_location_kernel_size,
+                                                attention_dim)
         self._mask_value = -float("inf")
         self.windowing = windowing
         self.win_idx = None
         self.norm = norm
         self.forward_attn = forward_attn
         self.trans_agent = trans_agent
+        self.location_attention = location_attention
 
     def init_win_idx(self):
         self.win_idx = -1
         self.win_back = 2
         self.win_front = 6
 
-    def init_forward_attn_state(self, inputs):
-        """
-        Init forward attention states
-        """
+    def init_forward_attn(self, inputs):
         B = inputs.shape[0]
         T = inputs.shape[1]
         self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7
                                 ], dim=1).to(inputs.device)
         self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)
 
-    def get_attention(self, query, processed_inputs, attention_cat):
+    def init_location_attention(self, inputs):
+        B = inputs.shape[0]
+        T = inputs.shape[1]
+        self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())
+
+    def init_states(self, inputs):
+        B = inputs.shape[0]
+        T = inputs.shape[1]
+        self.attention_weights = Variable(inputs.data.new(B, T).zero_())
+        if self.location_attention:
+            self.init_location_attention(inputs)
+        if self.forward_attn:
+            self.init_forward_attn(inputs)
+        if self.windowing:
+            self.init_win_idx()
+
+    def update_location_attention(self, alignments):
+        self.attention_weights_cum += alignments
+
+    def get_location_attention(self, query, processed_inputs):
+        attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
+                                   self.attention_weights_cum.unsqueeze(1)),
+                                  dim=1)
         processed_query = self.query_layer(query.unsqueeze(1))
         processed_attention_weights = self.location_layer(attention_cat)
         energies = self.v(
             torch.tanh(processed_query + processed_attention_weights +
-                       processed_inputs))
+                       processed_inputs))
+        energies = energies.squeeze(-1)
+        return energies, processed_query
 
+    def get_attention(self, query, processed_inputs):
+        processed_query = self.query_layer(query.unsqueeze(1))
+        energies = self.v(
+            torch.tanh(processed_query + processed_inputs))
         energies = energies.squeeze(-1)
         return energies, processed_query
@@ -192,13 +219,16 @@ class Attention(nn.Module):
         if self.trans_agent:
             ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
             self.u = torch.sigmoid(self.ta(ta_input))
-        return context, self.alpha, alignment
+        return context, self.alpha
 
     def forward(self, attention_hidden_state, inputs, processed_inputs,
-                attention_cat, mask):
-        attention, processed_query = self.get_attention(
-            attention_hidden_state, processed_inputs, attention_cat)
-
+                mask):
+        if self.location_attention:
+            attention, processed_query = self.get_location_attention(
+                attention_hidden_state, processed_inputs)
+        else:
+            attention, processed_query = self.get_attention(
+                attention_hidden_state, processed_inputs)
         # apply masking
         if mask is not None:
             attention.data.masked_fill_(1 - mask, self._mask_value)
@@ -213,13 +243,15 @@ class Attention(nn.Module):
                 attention).sum(dim=1).unsqueeze(1)
         else:
             raise RuntimeError("Unknown value for attention norm type")
+        if self.location_attention:
+            self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
-            return self.apply_forward_attention(inputs, alignment, processed_query)
+            context, self.attention_weights = self.apply_forward_attention(inputs, alignment, processed_query)
         else:
             context = torch.bmm(alignment.unsqueeze(1), inputs)
             context = context.squeeze(1)
-        return context, alignment, alignment
+        return context
 
 
 class Postnet(nn.Module):
@@ -289,7 +321,7 @@ class Encoder(nn.Module):
 
 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent, location_attn):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
@@ -308,8 +340,8 @@ class Decoder(nn.Module):
         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                          self.attention_rnn_dim)
 
-        self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, 32, 31, attn_win, attn_norm, forward_attn, trans_agent)
+        self.attention_layer = Attention(self.attention_rnn_dim, in_features, 128, location_attn,
+                                         32, 31, attn_win, attn_norm, forward_attn, trans_agent)
 
         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
@@ -351,9 +383,6 @@ class Decoder(nn.Module):
 
             self.context = Variable(
                 inputs.data.new(B, self.encoder_embedding_dim).zero_())
-
-        self.attention_weights = Variable(inputs.data.new(B, T).zero_())
-        self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())
         self.inputs = inputs
         self.processed_inputs = self.attention_layer.inputs_layer(inputs)
 
@@ -384,14 +413,10 @@ class Decoder(nn.Module):
         self.attention_cell = F.dropout(
             self.attention_cell, self.p_attention_dropout, self.training)
 
-        attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
-                                   self.attention_weights_cum.unsqueeze(1)),
-                                  dim=1)
-        self.context, self.attention_weights, alignments = self.attention_layer(
+        self.context = self.attention_layer(
             self.attention_hidden, self.inputs, self.processed_inputs,
-            attention_cat, self.mask)
+            self.mask)
 
-        self.attention_weights_cum += alignments
         memory = torch.cat(
             (self.attention_hidden, self.context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
@@ -410,7 +435,7 @@ class Decoder(nn.Module):
         stopnet_input = torch.cat((self.decoder_hidden, decoder_output),
                                   dim=1)
         gate_prediction = self.stopnet(stopnet_input)
-        return decoder_output, gate_prediction, self.attention_weights
+        return decoder_output, gate_prediction, self.attention_layer.attention_weights
 
     def forward(self, inputs, memories, mask):
         memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -419,8 +444,7 @@ class Decoder(nn.Module):
         memories = self.prenet(memories)
 
         self._init_states(inputs, mask=mask)
-        if self.attention_layer.forward_attn:
-            self.attention_layer.init_forward_attn_state(inputs)
+        self.attention_layer.init_states(inputs)
 
         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
@@ -441,8 +465,7 @@ class Decoder(nn.Module):
 
         self._init_states(inputs, mask=None)
         self.attention_layer.init_win_idx()
-        if self.attention_layer.forward_attn:
-            self.attention_layer.init_forward_attn_state(inputs)
+        self.attention_layer.init_states(inputs)
 
         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [False, False, False]
@@ -484,9 +507,7 @@ class Decoder(nn.Module):
         else:
             self._init_states(inputs, mask=None, keep_states=True)
 
-        self.attention_layer.init_win_idx()
-        if self.attention_layer.forward_attn:
-            self.attention_layer.init_forward_attn_state(inputs)
+        self.attention_layer.init_states(inputs)
         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [False, False, False]
         stop_count = 0
diff --git a/models/tacotron2.py b/models/tacotron2.py
index 2e7c857b..c492c7b1 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask
 
 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False, trans_agent=False):
+    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False, trans_agent=False, location_attn=True):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent, location_attn)
         self.postnet = Postnet(self.n_mel_channels)
 
     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index f22c4a3a..19d93888 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -263,5 +263,6 @@ def setup_model(num_chars, c):
                      attn_norm=c.attention_norm,
                      prenet_type=c.prenet_type,
                      forward_attn=c.use_forward_attn,
-                     trans_agent=c.transition_agent)
+                     trans_agent=c.transition_agent,
+                     location_attn=c.location_attn)
     return model
\ No newline at end of file
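
For reference, a minimal, hypothetical driver sketch (not part of the patch) showing how the refactored Attention interface is used after this change: init_states() replaces init_forward_attn_state() and the decoder-side weight buffers, the attention_cat argument is gone from forward(), and alignments are read back from attention_layer.attention_weights, as Decoder.decode() now does. The import path and class come from this repository; the batch size and dimensions below are illustrative assumptions.

    # Hypothetical usage sketch; assumes the TTS repo is on PYTHONPATH.
    # All sizes below are illustrative, not taken from any config.
    import torch
    from layers.tacotron2 import Attention

    B, T_enc, D_enc, D_query = 2, 50, 512, 1024  # batch, encoder steps, embedding/query dims

    attn = Attention(attention_rnn_dim=D_query, embedding_dim=D_enc, attention_dim=128,
                     location_attention=True,           # new flag introduced by this patch
                     attention_location_n_filters=32,
                     attention_location_kernel_size=31,
                     windowing=False, norm="softmax",
                     forward_attn=False, trans_agent=False)

    inputs = torch.rand(B, T_enc, D_enc)          # encoder outputs
    processed_inputs = attn.inputs_layer(inputs)  # precomputed once, as in Decoder._init_states()
    query = torch.rand(B, D_query)                # attention RNN hidden state

    attn.init_states(inputs)                      # single init call instead of per-flag setup
    context = attn(query, inputs, processed_inputs, mask=None)  # no attention_cat argument anymore
    print(context.shape, attn.attention_weights.shape)  # (B, D_enc), (B, T_enc); decode() reads the weights here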
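
At the model level the flag is threaded through Tacotron2 -> Decoder -> Attention and exposed to setup_model() as c.location_attn, so purely content-based attention can be selected from the config. A small hypothetical check (num_chars and r values are illustrative):

    # Hypothetical: with location_attn=False the LocationLayer is never constructed.
    from models.tacotron2 import Tacotron2

    model = Tacotron2(num_chars=61, r=1, location_attn=False)
    assert not hasattr(model.decoder.attention_layer, "location_layer")
    assert model.decoder.attention_layer.location_attention is False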