diff --git a/config.json b/config.json
index 3e054e72..3a412485 100644
--- a/config.json
+++ b/config.json
@@ -41,7 +41,9 @@
     "memory_size": 5,           // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "softmax",   // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "prenet_type": "bn",        // ONLY TACOTRON2 - "original" or "bn".
-    "use_forward_attn": false,  // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
+    "use_forward_attn": true,   // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
+    "transition_agent": true,   // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
+    "loss_masking": false,      // enable / disable loss masking against the sequence padding.
 
     "batch_size": 16,       // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
diff --git a/config_cluster.json b/config_cluster.json
index 4f4248f3..a9a36f2a 100644
--- a/config_cluster.json
+++ b/config_cluster.json
@@ -42,6 +42,8 @@
     "attention_norm": "softmax",   // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "prenet_type": "original",  // ONLY TACOTRON2 - "original" or "bn".
     "use_forward_attn": true,   // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
+    "transition_agent": true,   // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
+    "loss_masking": false,      // enable / disable loss masking against the sequence padding.
 
     "batch_size": 16,       // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index e1aa02dd..0a6982e2 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -122,13 +122,15 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                  attention_location_n_filters, attention_location_kernel_size,
-                 windowing, norm, forward_attn):
+                 windowing, norm, forward_attn, trans_agent):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
         self.inputs_layer = Linear(
             embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
+        if trans_agent:
+            self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)
         self.location_layer = LocationLayer(attention_location_n_filters,
                                             attention_location_kernel_size,
                                             attention_dim)
@@ -137,6 +139,7 @@ class Attention(nn.Module):
         self.win_idx = None
         self.norm = norm
         self.forward_attn = forward_attn
+        self.trans_agent = trans_agent
 
     def init_win_idx(self):
         self.win_idx = -1
@@ -160,29 +163,46 @@
                        processed_inputs))
 
         energies = energies.squeeze(-1)
-        return energies
+        return energies, processed_query
+
+    def apply_windowing(self, attention, inputs):
+        back_win = self.win_idx - self.win_back
+        front_win = self.win_idx + self.win_front
+        if back_win > 0:
+            attention[:, :back_win] = -float("inf")
+        if front_win < inputs.shape[1]:
+            attention[:, front_win:] = -float("inf")
+        # this is a trick to solve a special problem.
+        # but it does not hurt.
+        if self.win_idx == -1:
+            attention[:, 0] = attention.max()
+        # Update the window
+        self.win_idx = torch.argmax(attention, 1).long()[0].item()
+        return attention
+
+    def apply_forward_attention(self, inputs, alignment, processed_query):
+        # forward attention
+        prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
+        self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
+        alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
+        # compute context
+        context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
+        context = context.squeeze(1)
+        # compute transition agent
+        if self.trans_agent:
+            ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
+            self.u = torch.sigmoid(self.ta(ta_input))
+        return context, alpha_norm, alignment
 
     def forward(self, attention_hidden_state, inputs, processed_inputs,
                 attention_cat, mask):
-        attention = self.get_attention(
+        attention, processed_query = self.get_attention(
            attention_hidden_state, processed_inputs, attention_cat)
 
         if mask is not None:
             attention.data.masked_fill_(1 - mask, self._mask_value)
-        # Windowing
         if not self.training and self.windowing:
-            back_win = self.win_idx - self.win_back
-            front_win = self.win_idx + self.win_front
-            if back_win > 0:
-                attention[:, :back_win] = -float("inf")
-            if front_win < inputs.shape[1]:
-                attention[:, front_win:] = -float("inf")
-            # this is a trick to solve a special problem.
-            # but it does not hurt.
-            if self.win_idx == -1:
-                attention[:, 0] = attention.max()
-            # Update the window
-            self.win_idx = torch.argmax(attention, 1).long()[0].item()
+            attention = self.apply_windowing(attention, inputs)
         if self.norm == "softmax":
             alignment = torch.softmax(attention, dim=-1)
         elif self.norm == "sigmoid":
@@ -191,14 +211,7 @@ class Attention(nn.Module):
         else:
             raise RuntimeError("Unknown value for attention norm type")
         if self.forward_attn:
-            # forward attention
-            prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
-            self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
-            alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
-            # compute context
-            context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
-            context = context.squeeze(1)
-            return context, alpha_norm, alignment
+            return self.apply_forward_attention(inputs, alignment, processed_query)
         else:
             context = torch.bmm(alignment.unsqueeze(1), inputs)
             context = context.squeeze(1)
@@ -272,7 +285,7 @@ class Encoder(nn.Module):
 
 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
@@ -292,7 +305,7 @@ class Decoder(nn.Module):
                                           self.attention_rnn_dim)
 
         self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, 32, 31, attn_win, attn_norm, forward_attn)
+                                         128, 32, 31, attn_win, attn_norm, forward_attn, trans_agent)
 
         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
diff --git a/models/tacotron2.py b/models/tacotron2.py
index dadc4a24..2e7c857b 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask
 
 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False):
+    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False, trans_agent=False):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent)
         self.postnet = Postnet(self.n_mel_channels)
 
     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
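
For reference, below is a minimal, self-contained sketch of what the new apply_forward_attention path computes: the forward-attention recursion over the previous alignment weights, plus a sigmoid transition agent fed with the context vector and the processed query. The class name ForwardAttentionSketch, the init_states helper, and the toy shapes are illustrative assumptions, not part of this patch or the repo's API.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ForwardAttentionSketch(nn.Module):
    def __init__(self, attention_dim, embedding_dim):
        super().__init__()
        # transition agent: predicts u in (0, 1) from [context; processed_query]
        self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)

    def init_states(self, inputs):
        # illustrative initialization: alpha starts fully on the first encoder
        # step, the transition probability u starts at 0.5
        B, T, _ = inputs.shape
        self.alpha = torch.cat(
            [torch.ones(B, 1), torch.zeros(B, T - 1)], dim=1).to(inputs.device)
        self.u = 0.5 * torch.ones(B, 1).to(inputs.device)

    def forward(self, inputs, alignment, processed_query):
        # inputs: (B, T, embedding_dim), alignment: (B, T),
        # processed_query: (B, 1, attention_dim)
        prev_alpha = F.pad(self.alpha[:, :-1], (1, 0, 0, 0))  # shift right by one step
        # forward recursion: stay with prob (1 - u) or advance with prob u,
        # then reweight by the current alignment
        self.alpha = (((1 - self.u) * self.alpha + self.u * prev_alpha) + 1e-7) * alignment
        alpha_norm = self.alpha / self.alpha.sum(dim=1, keepdim=True)
        # context vector from the normalized forward weights
        context = torch.bmm(alpha_norm.unsqueeze(1), inputs).squeeze(1)
        # transition agent updates u for the next decoder step
        ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
        self.u = torch.sigmoid(self.ta(ta_input))
        return context, alpha_norm


# Toy usage with random tensors
B, T, E, A = 2, 7, 512, 128
attn = ForwardAttentionSketch(attention_dim=A, embedding_dim=E)
inputs = torch.randn(B, T, E)
attn.init_states(inputs)
alignment = torch.softmax(torch.randn(B, T), dim=-1)
processed_query = torch.randn(B, 1, A)
context, alpha = attn(inputs, alignment, processed_query)  # (B, E), (B, T)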