diff --git a/config.json b/config.json index b98d5646..91863c4c 100644 --- a/config.json +++ b/config.json @@ -1,7 +1,7 @@ { "model": "Tacotron2", // one of the model in models/ - "run_name": "ljspeech", - "run_description": "tacotron2 without bidirectional decoder", + "run_name": "ljspeech-graves", + "run_description": "tacotron2 wuth graves attention", // AUDIO PARAMETERS "audio":{ @@ -38,7 +38,7 @@ "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. "eval_batch_size":16, "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. - "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. + "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. "loss_masking": true, // enable / disable loss masking against the sequence padding. // VALIDATION @@ -47,6 +47,7 @@ "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences. // OPTIMIZER + "noam_schedule": false, "grad_clip": 1, // upper limit for gradients for clipping. "epochs": 1000, // total number of epochs to train. "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. @@ -60,7 +61,7 @@ "prenet_dropout": true, // enable/disable dropout at prenet. // ATTENTION - "attention_type": "original", // 'original' or 'graves' + "attention_type": "graves", // 'original' or 'graves' "attention_heads": 5, // number of attention heads (only for 'graves') "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron. "windowing": false, // Enables attention windowing. Used only in eval mode. @@ -90,8 +91,8 @@ "max_seq_len": 150, // DATASET-RELATED: maximum text length // PATHS - // "output_path": "../keep/", // DATASET-RELATED: output path for all training outputs. - "output_path": "/media/erogol/data_ssd/Models/runs/", + "output_path": "/data5/rw/pit/keep/", // DATASET-RELATED: output path for all training outputs. + // "output_path": "/media/erogol/data_ssd/Models/runs/", // PHONEMES "phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder. @@ -108,8 +109,8 @@ [ { "name": "ljspeech", - // "path": "/data/ro/shared/data/keithito/LJSpeech-1.1/", - "path": "/home/erogol/Data/LJSpeech-1.1", + "path": "/data5/ro/shared/data/keithito/LJSpeech-1.1/", + // "path": "/home/erogol/Data/LJSpeech-1.1", "meta_file_train": "metadata_train.csv", "meta_file_val": "metadata_val.csv" } diff --git a/layers/common_layers.py b/layers/common_layers.py index 006aa57a..c2b042b0 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -82,6 +82,11 @@ class Prenet(nn.Module): return x +#################### +# ATTENTION MODULES +#################### + + class LocationLayer(nn.Module): def __init__(self, attention_dim, @@ -105,86 +110,6 @@ class LocationLayer(nn.Module): return processed_attention -class GravesAttention(nn.Module): - """ Graves attention as described here: - - https://arxiv.org/abs/1910.10288 - """ - COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi)) - - def __init__(self, query_dim, K): - super(GravesAttention, self).__init__() - self._mask_value = 0.0 - self.K = K - # self.attention_alignment = 0.05 - self.eps = 1e-5 - self.J = None - self.N_a = nn.Sequential( - nn.Linear(query_dim, query_dim, bias=True), - nn.ReLU(), - nn.Linear(query_dim, 3*K, bias=True)) - self.attention_weights = None - self.mu_prev = None - self.init_layers() - - def init_layers(self): - torch.nn.init.constant_(self.N_a[2].bias[10:15], 0.5) - torch.nn.init.constant_(self.N_a[2].bias[5:10], 10) - - def init_states(self, inputs): - if self.J is None or inputs.shape[1] > self.J.shape[-1]: - self.J = torch.arange(0, inputs.shape[1]).to(inputs.device).expand([inputs.shape[0], self.K, inputs.shape[1]]) - self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device) - self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device) - - # pylint: disable=R0201 - # pylint: disable=unused-argument - def preprocess_inputs(self, inputs): - return None - - def forward(self, query, inputs, processed_inputs, mask): - """ - shapes: - query: B x D_attention_rnn - inputs: B x T_in x D_encoder - processed_inputs: place_holder - mask: B x T_in - """ - gbk_t = self.N_a(query) - gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K) - - # attention model parameters - # each B x K - g_t = gbk_t[:, 0, :] - b_t = gbk_t[:, 1, :] - k_t = gbk_t[:, 2, :] - - # attention GMM parameters - inv_sig_t = torch.exp(-torch.clamp(b_t, min=-6, max=9)) # variance - mu_t = self.mu_prev + torch.nn.functional.softplus(k_t) - g_t = torch.softmax(g_t, dim=-1) * inv_sig_t + self.eps - - # each B x K x T_in - g_t = g_t.unsqueeze(2).expand(g_t.size(0), - g_t.size(1), - inputs.size(1)) - inv_sig_t = inv_sig_t.unsqueeze(2).expand_as(g_t) - mu_t_ = mu_t.unsqueeze(2).expand_as(g_t) - j = self.J[:g_t.size(0), :, :inputs.size(1)] - - # attention weights - phi_t = g_t * torch.exp(-0.5 * inv_sig_t * (mu_t_ - j)**2) - alpha_t = self.COEF * torch.sum(phi_t, 1) - - # apply masking - if mask is not None: - alpha_t.data.masked_fill_(~mask, self._mask_value) - - context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1) - self.attention_weights = alpha_t - self.mu_prev = mu_t - return context - - class OriginalAttention(nn.Module): """Following the methods proposed here: - https://arxiv.org/abs/1712.05884 @@ -364,6 +289,82 @@ class OriginalAttention(nn.Module): return context +class GravesAttention(nn.Module): + """ Graves attention as described here: + - https://arxiv.org/abs/1910.10288 + """ + COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi)) + + def __init__(self, query_dim, K): + super(GravesAttention, self).__init__() + self._mask_value = 0.0 + self.K = K + # self.attention_alignment = 0.05 + self.eps = 1e-5 + self.J = None + self.N_a = nn.Sequential( + nn.Linear(query_dim, query_dim, bias=True), + nn.ReLU(), + nn.Linear(query_dim, 3*K, bias=True)) + self.attention_weights = None + self.mu_prev = None + self.init_layers() + + def init_layers(self): + torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.) + torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10) + + def init_states(self, inputs): + if self.J is None or inputs.shape[1] > self.J.shape[-1]: + self.J = torch.arange(0, inputs.shape[1]).to(inputs.device) + self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device) + self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device) + + # pylint: disable=R0201 + # pylint: disable=unused-argument + def preprocess_inputs(self, inputs): + return None + + def forward(self, query, inputs, processed_inputs, mask): + """ + shapes: + query: B x D_attention_rnn + inputs: B x T_in x D_encoder + processed_inputs: place_holder + mask: B x T_in + """ + gbk_t = self.N_a(query) + gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K) + + # attention model parameters + # each B x K + g_t = gbk_t[:, 0, :] + b_t = gbk_t[:, 1, :] + k_t = gbk_t[:, 2, :] + + # attention GMM parameters + sig_t = torch.nn.functional.softplus(b_t) + self.eps + + mu_t = self.mu_prev + torch.nn.functional.softplus(k_t) + g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps + + # each B x K x T_in + j = self.J[:inputs.size(1)] + + # attention weights + phi_t = g_t.unsqueeze(-1) * torch.exp(-0.5 * (mu_t.unsqueeze(-1) - j)**2 / (sig_t.unsqueeze(-1)**2)) + alpha_t = self.COEF * torch.sum(phi_t, 1) + + # apply masking + if mask is not None: + alpha_t.data.masked_fill_(~mask, self._mask_value) + + context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1) + self.attention_weights = alpha_t + self.mu_prev = mu_t + return context + + def init_attn(attn_type, query_dim, embedding_dim, attention_dim, location_attention, attention_location_n_filters, attention_location_kernel_size, windowing, norm, forward_attn, diff --git a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb index a0d0be60..2f188620 100644 --- a/notebooks/ExtractTTSpectrogram.ipynb +++ b/notebooks/ExtractTTSpectrogram.ipynb @@ -303,9 +303,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.2" + "version": "3.7.4" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/server/server.py b/server/server.py index 2831c754..d40e2427 100644 --- a/server/server.py +++ b/server/server.py @@ -55,8 +55,11 @@ def tts(): if __name__ == '__main__': - if not config or not synthesizer: - args = create_argparser().parse_args() + args = create_argparser().parse_args() + + # Setup synthesizer from CLI args if they're specified or no embedded model + # is present. + if not config or not synthesizer or args.tts_checkpoint or args.tts_config: synthesizer = Synthesizer(args) app.run(debug=config.debug, host='0.0.0.0', port=config.port) diff --git a/server/synthesizer.py b/server/synthesizer.py index 14de9411..d8852a3e 100644 --- a/server/synthesizer.py +++ b/server/synthesizer.py @@ -53,15 +53,14 @@ class Synthesizer(object): num_speakers = 0 self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config) # load model state - map_location = None if use_cuda else torch.device('cpu') - cp = torch.load(tts_checkpoint, map_location=map_location) + cp = torch.load(tts_checkpoint, map_location=torch.device('cpu')) # load the model self.tts_model.load_state_dict(cp['model']) if use_cuda: self.tts_model.cuda() self.tts_model.eval() self.tts_model.decoder.max_decoder_steps = 3000 - if 'r' in cp and self.tts_config.model in ["Tacotron", "TacotronGST"]: + if 'r' in cp: self.tts_model.decoder.set_r(cp['r']) def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):