diff --git a/TTS/tts/configs/config.json b/TTS/tts/configs/config.json
index 766aca40..54fcada7 100644
--- a/TTS/tts/configs/config.json
+++ b/TTS/tts/configs/config.json
@@ -99,7 +99,7 @@
     "prenet_dropout": false,       // enable/disable dropout at prenet.
 
     // TACOTRON ATTENTION
-    "attention_type": "original",  // 'original' or 'graves'
+    "attention_type": "original",  // 'original', 'graves' or 'dynamic_convolution'
     "attention_heads": 4,          // number of attention heads (only for 'graves')
     "attention_norm": "sigmoid",   // softmax or sigmoid.
     "windowing": false,            // Enables attention windowing. Used only in eval mode.
diff --git a/TTS/tts/layers/common_layers.py b/TTS/tts/layers/common_layers.py
index d197bb86..c242a89a 100644
--- a/TTS/tts/layers/common_layers.py
+++ b/TTS/tts/layers/common_layers.py
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
+from scipy.stats import betabinom
 
 
 class Linear(nn.Module):
@@ -371,6 +372,90 @@ class OriginalAttention(nn.Module):
         self.u = torch.sigmoid(self.ta(ta_input))
         return context
 
+class MonotonicDynamicConvolutionAttention(nn.Module):
+    """Dynamic convolution attention from
+    https://arxiv.org/pdf/1910.10288.pdf
+    """
+    def __init__(
+        self,
+        query_dim,
+        embedding_dim,
+        attention_dim,
+        static_filter_dim,
+        static_kernel_size,
+        dynamic_filter_dim,
+        dynamic_kernel_size,
+        prior_filter_len=11,
+        alpha=0.1,
+        beta=0.9,
+    ):
+        super().__init__()
+        self._mask_value = 1e-8
+        self.dynamic_filter_dim = dynamic_filter_dim
+        self.dynamic_kernel_size = dynamic_kernel_size
+        self.prior_filter_len = prior_filter_len
+        self.attention_weights = None
+        # setup key and query layers
+        self.query_layer = nn.Linear(query_dim, attention_dim)
+        self.key_layer = nn.Linear(
+            attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False
+        )
+        self.static_filter_conv = nn.Conv1d(
+            1,
+            static_filter_dim,
+            static_kernel_size,
+            padding=(static_kernel_size - 1) // 2,
+            bias=False,
+        )
+        self.static_filter_layer = nn.Linear(static_filter_dim, attention_dim, bias=False)
+        self.dynamic_filter_layer = nn.Linear(dynamic_filter_dim, attention_dim)
+        self.v = nn.Linear(attention_dim, 1, bias=False)
+
+        prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1,
+                              alpha, beta)
+        self.register_buffer("prior", torch.FloatTensor(prior).flip(0))
+
+    # pylint: disable=unused-argument
+    def forward(self, query, inputs, processed_inputs, mask):
+        # compute prior filters
+        prior_filter = F.conv1d(
+            F.pad(self.attention_weights.unsqueeze(1),
+                  (self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1))
+        prior_filter = torch.log(prior_filter.clamp_min_(1e-6)).squeeze(1)
+        G = self.key_layer(torch.tanh(self.query_layer(query)))
+        # compute dynamic filters
+        dynamic_filter = F.conv1d(
+            self.attention_weights.unsqueeze(0),
+            G.view(-1, 1, self.dynamic_kernel_size),
+            padding=(self.dynamic_kernel_size - 1) // 2,
+            groups=query.size(0),
+        )
+        dynamic_filter = dynamic_filter.view(query.size(0), self.dynamic_filter_dim, -1).transpose(1, 2)
+        # compute static filters
+        static_filter = self.static_filter_conv(self.attention_weights.unsqueeze(1)).transpose(1, 2)
+        alignment = self.v(
+            torch.tanh(
+                self.static_filter_layer(static_filter) +
+                self.dynamic_filter_layer(dynamic_filter))).squeeze(-1) + prior_filter
+        # compute attention weights
+        attention_weights = F.softmax(alignment, dim=-1)
+        # apply masking
+        if mask is not None:
+            attention_weights.data.masked_fill_(~mask, self._mask_value)
+        self.attention_weights = attention_weights
+        # compute context
+        context = torch.bmm(attention_weights.unsqueeze(1), inputs).squeeze(1)
+        return context
+
+    def preprocess_inputs(self, inputs):
+        return None
+
+    def init_states(self, inputs):
+        B = inputs.size(0)
+        T = inputs.size(1)
+        self.attention_weights = torch.zeros([B, T], device=inputs.device)
+        self.attention_weights[:, 0] = 1.
+
 
 def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
               location_attention, attention_location_n_filters,
@@ -385,5 +470,17 @@ def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
                                  forward_attn_mask)
     if attn_type == "graves":
         return GravesAttention(query_dim, attn_K)
+    if attn_type == "dynamic_convolution":
+        return MonotonicDynamicConvolutionAttention(query_dim,
+                                                    embedding_dim,
+                                                    attention_dim,
+                                                    static_filter_dim=8,
+                                                    static_kernel_size=21,
+                                                    dynamic_filter_dim=8,
+                                                    dynamic_kernel_size=21,
+                                                    prior_filter_len=11,
+                                                    alpha=0.1,
+                                                    beta=0.9)
+
     raise RuntimeError(
-        " [!] Given Attention Type '{attn_type}' is not exist.")
+        f" [!] Given Attention Type '{attn_type}' does not exist.")
\ No newline at end of file
diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py
index 928c9dfc..94cd5668 100644
--- a/TTS/tts/utils/generic_utils.py
+++ b/TTS/tts/utils/generic_utils.py
@@ -211,7 +211,7 @@ def check_config_tts(c):
     check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
 
     # attention
-    check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original'])
+    check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original', 'dynamic_convolution'])
     check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
     check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
     check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
diff --git a/tests/inputs/test_train_config.json b/tests/inputs/test_train_config.json
index 6be85314..ee0680e3 100644
--- a/tests/inputs/test_train_config.json
+++ b/tests/inputs/test_train_config.json
@@ -100,7 +100,7 @@
     "prenet_dropout": false,       // enable/disable dropout at prenet.
 
     // TACOTRON ATTENTION
-    "attention_type": "original",  // 'original' or 'graves'
+    "attention_type": "original",  // 'original', 'graves' or 'dynamic_convolution'
     "attention_heads": 4,          // number of attention heads (only for 'graves')
     "attention_norm": "sigmoid",   // softmax or sigmoid.
     "windowing": false,            // Enables attention windowing. Used only in eval mode.
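
For quick sanity-checking, a minimal smoke-test sketch (not part of the diff) is shown below. It instantiates the new module with the same hyperparameters that `init_attn` hard-codes and runs a single decoder step; the batch size, encoder length, and feature dimensions are arbitrary assumptions.

```python
# Smoke test for MonotonicDynamicConvolutionAttention (assumes the patched
# TTS.tts.layers.common_layers above is importable; scipy must be installed).
import torch

from TTS.tts.layers.common_layers import MonotonicDynamicConvolutionAttention

B, T_en, D_en, D_query, D_attn = 4, 120, 512, 256, 128  # arbitrary example sizes

attn = MonotonicDynamicConvolutionAttention(
    query_dim=D_query,
    embedding_dim=D_en,      # accepted for API parity with the other attentions, unused internally
    attention_dim=D_attn,
    static_filter_dim=8,
    static_kernel_size=21,
    dynamic_filter_dim=8,
    dynamic_kernel_size=21,
    prior_filter_len=11,
    alpha=0.1,
    beta=0.9,
)

inputs = torch.rand(B, T_en, D_en)  # encoder outputs
query = torch.rand(B, D_query)      # decoder query for one step

attn.init_states(inputs)            # alignment starts fully on the first encoder frame
processed = attn.preprocess_inputs(inputs)  # returns None for this attention type
context = attn(query, inputs, processed, mask=None)

print(context.shape)                 # torch.Size([4, 512])
print(attn.attention_weights.shape)  # torch.Size([4, 120])
```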