diff --git a/config.json b/config.json
index 39b69833..c958ebf3 100644
--- a/config.json
+++ b/config.json
@@ -60,6 +60,8 @@
     "prenet_dropout": true,       // enable/disable dropout at prenet.
 
     // ATTENTION
+    "attention_type": "original", // 'original' or 'graves'
+    "attention_heads": 5,         // number of attention heads (only for 'graves')
     "attention_norm": "sigmoid",  // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "windowing": false,           // Enables attention windowing. Used only in eval mode.
     "use_forward_attn": false,    // if it uses forward attention. In general, it aligns faster.
diff --git a/layers/common_layers.py b/layers/common_layers.py
index acc9b4df..07f97588 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -106,25 +106,33 @@ class LocationLayer(nn.Module):
 
 
 class GravesAttention(nn.Module):
+    """Graves attention as described here:
+        - https://arxiv.org/abs/1910.10288
+    """
     COEF = 0.3989422917366028  # numpy.sqrt(1/(2*numpy.pi))
 
-    def __init__(self, query_dim, K, attention_alignment=0.05):
+    def __init__(self, query_dim, K):
         super(GravesAttention, self).__init__()
-        self._mask_value = -float("inf")
+        self._mask_value = 0.0
         self.K = K
-        self.attention_alignment = attention_alignment
+        # self.attention_alignment = 0.05
         self.epsilon = 1e-5
         self.J = None
         self.N_a = nn.Sequential(
             nn.Linear(query_dim, query_dim//2),
             nn.Tanh(),
             nn.Linear(query_dim//2, 3*K))
-        self.mu_tm1 = None
-
+        self.attention_weights = None
+        self.mu_prev = None
+
     def init_states(self, inputs):
         if self.J is None or inputs.shape[1] > self.J.shape[-1]:
             self.J = torch.arange(0, inputs.shape[1]).expand_as(torch.Tensor(inputs.shape[0], self.K, inputs.shape[1])).to(inputs.device)
-        self.mu_tm1 = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
+        self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
+        self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
+
+    def preprocess_inputs(self, inputs):
+        return None
 
-    def forward(self, query, inputs, mask):
+    def forward(self, query, inputs, processed_inputs, mask):
         """
@@ -143,9 +151,12 @@ class GravesAttention(nn.Module):
         k_t = gbk_t[:, 2, :]
 
         # attention GMM parameters
-        g_t = torch.softmax(g_t, dim=-1) + self.epsilon  # distribution weight
-        sig_t = torch.exp(b_t) + self.epsilon  # variance
-        mu_t = self.mu_tm1 + self.attention_alignment * torch.exp(k_t)  # mean
+        # g_t = torch.softmax(g_t, dim=-1) + self.epsilon  # distribution weight
+        # sig_t = torch.exp(b_t) + self.epsilon  # variance
+        # mu_t = self.mu_prev + self.attention_alignment * torch.exp(k_t)  # mean
+        sig_t = torch.pow(torch.nn.functional.softplus(b_t), 2)
+        mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
+        g_t = (torch.softmax(g_t, dim=-1) / sig_t) * self.COEF
 
         g_t = g_t.unsqueeze(2).expand(g_t.size(0),
                                       g_t.size(1),
@@ -156,27 +167,32 @@ class GravesAttention(nn.Module):
 
         # attention weights
         phi_t = g_t * torch.exp(-0.5 * sig_t * (mu_t_ - j)**2)
-        alpha_t = self.COEF * torch.sum(phi_t, 1)
+        alpha_t = torch.sum(phi_t, 1)
 
         # apply masking
-        # if mask is not None:
-        #     alpha_t.data.masked_fill_(~mask, self._mask_value)
-
+        if mask is not None:
+            alpha_t.data.masked_fill_(~mask, self._mask_value)
+
+        context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
+        self.attention_weights = alpha_t
+        self.mu_prev = mu_t
-        breakpoint()
-
-        c_t = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
-        self.mu_tm1 = mu_t
-        return c_t, mu_t, alpha_t
+        return context
 
 
-class Attention(nn.Module):
+class OriginalAttention(nn.Module):
+    """Following the methods proposed here:
+        - https://arxiv.org/abs/1712.05884
+        - https://arxiv.org/abs/1807.06736 + state masking at inference
+        - Using sigmoid instead of softmax normalization
+        - Attention windowing at inference time
+    """
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
     def __init__(self, query_dim, embedding_dim, attention_dim,
                  location_attention, attention_location_n_filters,
                  attention_location_kernel_size, windowing, norm, forward_attn,
                  trans_agent, forward_attn_mask):
-        super(Attention, self).__init__()
+        super(OriginalAttention, self).__init__()
         self.query_layer = Linear(
             query_dim, attention_dim, bias=False, init_gain='tanh')
         self.inputs_layer = Linear(
@@ -229,6 +245,9 @@ class Attention(nn.Module):
         if self.windowing:
             self.init_win_idx()
 
+    def preprocess_inputs(self, inputs):
+        return self.inputs_layer(inputs)
+
     def update_location_attention(self, alignments):
         self.attention_weights_cum += alignments
 
@@ -337,3 +356,21 @@ class Attention(nn.Module):
             ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
             self.u = torch.sigmoid(self.ta(ta_input))
         return context
+
+
+def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
+              location_attention, attention_location_n_filters,
+              attention_location_kernel_size, windowing, norm, forward_attn,
+              trans_agent, forward_attn_mask, attn_K):
+    if attn_type == "original":
+        return OriginalAttention(query_dim, embedding_dim, attention_dim,
+                                 location_attention,
+                                 attention_location_n_filters,
+                                 attention_location_kernel_size, windowing,
+                                 norm, forward_attn, trans_agent,
+                                 forward_attn_mask)
+    elif attn_type == "graves":
+        return GravesAttention(query_dim, attn_K)
+    else:
+        raise RuntimeError(
+            f" [!] Given Attention Type '{attn_type}' does not exist.")
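
Note: for readers of the patch, the new GravesAttention parameterization amounts to a K-component GMM over encoder positions whose component means can only move forward. Below is a self-contained sketch of the weight computation; the function name, argument names and shapes are illustrative only and not part of the patch.

    import math
    import torch
    import torch.nn.functional as F

    def gmm_attention_weights(gbk_t, mu_prev, T):
        # Sketch of the patched GravesAttention math.
        # gbk_t: (B, 3, K) raw mixture params from N_a(query); mu_prev: (B, K).
        coef = 1.0 / math.sqrt(2.0 * math.pi)             # same constant as COEF
        g_t, b_t, k_t = gbk_t[:, 0], gbk_t[:, 1], gbk_t[:, 2]
        sig_t = F.softplus(b_t) ** 2                      # replaces exp(b_t)
        mu_t = mu_prev + F.softplus(k_t)                  # monotone mean updates
        g_t = (torch.softmax(g_t, dim=-1) / sig_t) * coef # normalized mixture weights
        j = torch.arange(T, device=gbk_t.device).float()  # encoder positions
        phi_t = g_t.unsqueeze(-1) * torch.exp(
            -0.5 * sig_t.unsqueeze(-1) * (mu_t.unsqueeze(-1) - j) ** 2)
        alpha_t = phi_t.sum(1)                            # (B, T) attention weights
        return alpha_t, mu_t
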
diff --git a/layers/tacotron.py b/layers/tacotron.py
index 6a474abb..20fd1e52 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -1,7 +1,7 @@
 # coding: utf-8
 import torch
 from torch import nn
-from .common_layers import Prenet, Attention, Linear, GravesAttention
+from .common_layers import Prenet, init_attn, Linear
 
 
 class BatchNormConv1d(nn.Module):
@@ -263,9 +263,9 @@ class Decoder(nn.Module):
 
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
-    def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
+    def __init__(self, in_features, memory_dim, r, memory_size, attn_type, attn_windowing,
                  attn_norm, prenet_type, prenet_dropout, forward_attn,
-                 trans_agent, forward_attn_mask, location_attn,
+                 trans_agent, forward_attn_mask, location_attn, attn_K,
                  separate_stopnet, speaker_embedding_dim):
         super(Decoder, self).__init__()
         self.r_init = r
@@ -288,18 +288,19 @@ class Decoder(nn.Module):
         # attention_rnn generates queries for the attention mechanism
         self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)
 
-        # self.attention = Attention(query_dim=self.query_dim,
-        #                            embedding_dim=in_features,
-        #                            attention_dim=128,
-        #                            location_attention=location_attn,
-        #                            attention_location_n_filters=32,
-        #                            attention_location_kernel_size=31,
-        #                            windowing=attn_windowing,
-        #                            norm=attn_norm,
-        #                            forward_attn=forward_attn,
-        #                            trans_agent=trans_agent,
-        #                            forward_attn_mask=forward_attn_mask)
-        self.attention = GravesAttention(self.query_dim, 5)
+        self.attention = init_attn(attn_type=attn_type,
+                                   query_dim=self.query_dim,
+                                   embedding_dim=in_features,
+                                   attention_dim=128,
+                                   location_attention=location_attn,
+                                   attention_location_n_filters=32,
+                                   attention_location_kernel_size=31,
+                                   windowing=attn_windowing,
+                                   norm=attn_norm,
+                                   forward_attn=forward_attn,
+                                   trans_agent=trans_agent,
+                                   forward_attn_mask=forward_attn_mask,
+                                   attn_K=attn_K)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
@@ -343,7 +344,7 @@ class Decoder(nn.Module):
         ]
         self.context_vec = inputs.data.new(B, self.in_features).zero_()
         # cache attention inputs
-        # self.processed_inputs = self.attention.inputs_layer(inputs)
+        self.processed_inputs = self.attention.preprocess_inputs(inputs)
 
     def _parse_outputs(self, outputs, attentions, stop_tokens):
         # Back to batch first
@@ -363,7 +364,7 @@ class Decoder(nn.Module):
             torch.cat((processed_memory, self.context_vec), -1),
             self.attention_rnn_hidden)
         self.context_vec = self.attention(
-            self.attention_rnn_hidden, inputs, mask)
+            self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
         # Concat RNN output and attention context vector
         decoder_input = self.project_to_decoder_in(
             torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
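
Note: both attention classes now go through the same three-step interface the decoders rely on: init_states() once per utterance, preprocess_inputs() once to cache projected encoder outputs (None for Graves attention), then one call per decoder step. A rough usage sketch, with placeholder tensors and dimensions chosen only for illustration:

    import torch
    from layers.common_layers import init_attn

    B, T, D = 2, 37, 256                           # batch, encoder steps, encoder dim
    inputs = torch.randn(B, T, D)                  # encoder outputs
    mask = torch.ones(B, T, dtype=torch.bool)      # no padding in this toy example

    attention = init_attn(attn_type="graves", query_dim=256, embedding_dim=D,
                          attention_dim=128, location_attention=True,
                          attention_location_n_filters=32,
                          attention_location_kernel_size=31, windowing=False,
                          norm="sigmoid", forward_attn=False, trans_agent=False,
                          forward_attn_mask=False, attn_K=5)

    attention.init_states(inputs)                            # once per utterance
    processed_inputs = attention.preprocess_inputs(inputs)   # None for 'graves'
    query = torch.randn(B, 256)                              # attention_rnn hidden state
    context = attention(query, inputs, processed_inputs, mask)  # (B, D) context vector
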
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index ecc44a25..1472bcff 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -2,7 +2,7 @@ import torch
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as F
-from .common_layers import Attention, Prenet, Linear
+from .common_layers import init_attn, Prenet, Linear
 
 
 class ConvBNBlock(nn.Module):
@@ -98,9 +98,9 @@ class Encoder(nn.Module):
 class Decoder(nn.Module):
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
-    def __init__(self, in_features, memory_dim, r, attn_win, attn_norm,
+    def __init__(self, in_features, memory_dim, r, attn_type, attn_win, attn_norm,
                  prenet_type, prenet_dropout, forward_attn, trans_agent,
-                 forward_attn_mask, location_attn, separate_stopnet,
+                 forward_attn_mask, location_attn, attn_K, separate_stopnet,
                  speaker_embedding_dim):
         super(Decoder, self).__init__()
         self.memory_dim = memory_dim
@@ -128,7 +128,8 @@ class Decoder(nn.Module):
         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                          self.query_dim)
 
-        self.attention = Attention(query_dim=self.query_dim,
+        self.attention = init_attn(attn_type=attn_type,
+                                   query_dim=self.query_dim,
                                    embedding_dim=in_features,
                                    attention_dim=128,
                                    location_attention=location_attn,
@@ -138,7 +139,8 @@ class Decoder(nn.Module):
                                    norm=attn_norm,
                                    forward_attn=forward_attn,
                                    trans_agent=trans_agent,
-                                   forward_attn_mask=forward_attn_mask)
+                                   forward_attn_mask=forward_attn_mask,
+                                   attn_K=attn_K)
 
         self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features,
                                        self.decoder_rnn_dim, 1)
diff --git a/models/tacotron.py b/models/tacotron.py
index d726ac03..a2d9e1c4 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -15,6 +15,7 @@ class Tacotron(nn.Module):
                  postnet_output_dim=1025,
                  decoder_output_dim=80,
                  memory_size=5,
+                 attn_type='original',
                  attn_win=False,
                  gst=False,
                  attn_norm="sigmoid",
@@ -24,6 +25,7 @@ class Tacotron(nn.Module):
                  trans_agent=False,
                  forward_attn_mask=False,
                  location_attn=True,
+                 attn_K=5,
                  separate_stopnet=True,
                  bidirectional_decoder=False):
         super(Tacotron, self).__init__()
@@ -41,10 +43,10 @@ class Tacotron(nn.Module):
         self.embedding.weight.data.normal_(0, 0.3)
         # boilerplate model
         self.encoder = Encoder(encoder_dim)
-        self.decoder = Decoder(decoder_dim, decoder_output_dim, r, memory_size, attn_win,
+        self.decoder = Decoder(decoder_dim, decoder_output_dim, r, memory_size, attn_type, attn_win,
                                attn_norm, prenet_type, prenet_dropout, forward_attn,
                                trans_agent, forward_attn_mask,
-                               location_attn, separate_stopnet,
+                               location_attn, attn_K, separate_stopnet,
                                proj_speaker_dim)
         if self.bidirectional_decoder:
             self.decoder_backward = copy.deepcopy(self.decoder)
diff --git a/models/tacotron2.py b/models/tacotron2.py
index 70dd31a5..c8fd9242 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -14,6 +14,7 @@ class Tacotron2(nn.Module):
                  r,
                  postnet_output_dim=80,
                  decoder_output_dim=80,
+                 attn_type='original',
                  attn_win=False,
                  attn_norm="softmax",
                  prenet_type="original",
@@ -22,6 +23,7 @@ class Tacotron2(nn.Module):
                  trans_agent=False,
                  forward_attn_mask=False,
                  location_attn=True,
+                 attn_K=5,
                  separate_stopnet=True,
                  bidirectional_decoder=False):
         super(Tacotron2, self).__init__()
@@ -42,10 +44,10 @@ class Tacotron2(nn.Module):
         self.speaker_embeddings = None
         self.speaker_embeddings_projected = None
         self.encoder = Encoder(encoder_dim)
-        self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_win,
+        self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_type, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
                                forward_attn, trans_agent, forward_attn_mask,
-                               location_attn, separate_stopnet, proj_speaker_dim)
+                               location_attn, attn_K, separate_stopnet, proj_speaker_dim)
         if self.bidirectional_decoder:
             self.decoder_backward = copy.deepcopy(self.decoder)
         self.postnet = Postnet(self.decoder_output_dim)
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index bc292edd..972513d6 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -287,6 +287,7 @@ def setup_model(num_chars, num_speakers, c):
            decoder_output_dim=c.audio['num_mels'],
            gst=c.use_gst,
            memory_size=c.memory_size,
+           attn_type=c.attention_type,
            attn_win=c.windowing,
            attn_norm=c.attention_norm,
            prenet_type=c.prenet_type,
@@ -295,6 +296,7 @@ def setup_model(num_chars, num_speakers, c):
            trans_agent=c.transition_agent,
            forward_attn_mask=c.forward_attn_mask,
            location_attn=c.location_attn,
+           attn_K=c.attention_heads,
            separate_stopnet=c.separate_stopnet,
            bidirectional_decoder=c.bidirectional_decoder)
    elif c.model.lower() == "tacotron2":
@@ -303,6 +305,7 @@ def setup_model(num_chars, num_speakers, c):
            r=c.r,
            postnet_output_dim=c.audio['num_mels'],
            decoder_output_dim=c.audio['num_mels'],
+           attn_type=c.attention_type,
            attn_win=c.windowing,
            attn_norm=c.attention_norm,
            prenet_type=c.prenet_type,
@@ -311,6 +314,7 @@ def setup_model(num_chars, num_speakers, c):
            trans_agent=c.transition_agent,
            forward_attn_mask=c.forward_attn_mask,
            location_attn=c.location_attn,
+           attn_K=c.attention_heads,
            separate_stopnet=c.separate_stopnet,
            bidirectional_decoder=c.bidirectional_decoder)
    return model
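
Note: with the plumbing above, switching an existing setup to Graves attention is a config-level change; setup_model() forwards c.attention_type and c.attention_heads as attn_type and attn_K. A hypothetical sketch of the equivalent direct construction (argument values are illustrative, not taken from the patch):

    from models.tacotron2 import Tacotron2

    model = Tacotron2(num_chars=61,
                      num_speakers=1,
                      r=7,
                      attn_type="graves",  # new: 'original' (default) or 'graves'
                      attn_K=5)            # new: attention heads / GMM components for 'graves'
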