From d47ba5d310377026f612bbe480c88be117c565a1 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Sat, 20 Jul 2019 12:33:21 +0200
Subject: [PATCH] gradual training with memory queue

---
 config_libritts.json | 10 ++++-----
 layers/tacotron.py   | 50 +++++++++++++++++++++++++-------------------
 models/tacotron.py   |  2 +-
 train.py             | 21 +++++++++++++++++++
 4 files changed, 55 insertions(+), 28 deletions(-)

diff --git a/config_libritts.json b/config_libritts.json
index f9a752ec..22f03dd4 100644
--- a/config_libritts.json
+++ b/config_libritts.json
@@ -1,6 +1,6 @@
 {
     "run_name": "libritts-360",
-    "run_description": "LibriTTS 360 clean with multi speaker embedding.",
+    "run_description": "LibriTTS 360 gradual training with memory queue.",
 
     "audio":{
         // Audio processing parameters
@@ -31,13 +31,13 @@
 
     "reinit_layers": [],
 
-    "model": "Tacotron2", // one of the model in models/
+    "model": "Tacotron", // one of the model in models/
     "grad_clip": 1, // upper limit for gradients for clipping.
     "epochs": 1000, // total number of epochs to train.
     "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
     "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
     "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
-    "memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+    "memory_size": 7, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
     "prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
@@ -52,9 +52,9 @@
     "separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
     "tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
 
-    "batch_size": 24, // Batch size for training. Lower values than 32 might cause hard to learn attention.
+    "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
-    "r": 1, // Number of frames to predict for step.
+    "r": 7, // Number of frames to predict for step.
     "wd": 0.000001, // Weight decay weight.
     "checkpoint": true, // If true, it saves checkpoints per "save_step"
     "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
diff --git a/layers/tacotron.py b/layers/tacotron.py
index b71ddbc3..aa6ca4a6 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -270,12 +270,14 @@ class Decoder(nn.Module):
         memory_size (int): size of the past window.
         if <= 0 memory_size = r
         TODO: arguments
     """
 
+    # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
     def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
                  attn_norm, prenet_type, prenet_dropout, forward_attn,
-                 trans_agent, forward_attn_mask, location_attn, separate_stopnet):
+                 trans_agent, forward_attn_mask, location_attn,
+                 separate_stopnet):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -291,17 +293,18 @@ class Decoder(nn.Module):
                              out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = nn.GRUCell(in_features + 128, 256)
-        self.attention_layer = Attention(attention_rnn_dim=256,
-                                         embedding_dim=in_features,
-                                         attention_dim=128,
-                                         location_attention=location_attn,
-                                         attention_location_n_filters=32,
-                                         attention_location_kernel_size=31,
-                                         windowing=attn_windowing,
-                                         norm=attn_norm,
-                                         forward_attn=forward_attn,
-                                         trans_agent=trans_agent,
-                                         forward_attn_mask=forward_attn_mask)
+        self.attention_layer = Attention(
+            attention_rnn_dim=256,
+            embedding_dim=in_features,
+            attention_dim=128,
+            location_attention=location_attn,
+            attention_location_n_filters=32,
+            attention_location_kernel_size=31,
+            windowing=attn_windowing,
+            norm=attn_norm,
+            forward_attn=forward_attn,
+            trans_agent=trans_agent,
+            forward_attn_mask=forward_attn_mask)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
@@ -324,6 +327,9 @@ class Decoder(nn.Module):
             self.proj_to_mel.weight,
             gain=torch.nn.init.calculate_gain('linear'))
 
+    def _set_r(self, new_r):
+        self.r = new_r
+
     def _reshape_memory(self, memory):
         """
         Reshape the spectrograms for given 'r'
@@ -371,8 +377,11 @@ class Decoder(nn.Module):
         # Prenet
         processed_memory = self.prenet(self.memory_input)
         # Attention RNN
-        self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden)
-        self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
+        self.attention_rnn_hidden = self.attention_rnn(
+            torch.cat((processed_memory, self.current_context_vec), -1),
+            self.attention_rnn_hidden)
+        self.current_context_vec = self.attention_layer(
+            self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
         # Concat RNN output and attention context vector
         decoder_input = self.project_to_decoder_in(
             torch.cat((self.attention_rnn_hidden, self.current_context_vec),
@@ -395,17 +404,14 @@ class Decoder(nn.Module):
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
+        output = output[:, : self.r * self.memory_dim]
         return output, stop_token, self.attention_layer.attention_weights
 
     def _update_memory_queue(self, new_memory):
-        if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
-            self.memory_input = torch.cat([
-                self.memory_input[:, self.r * self.memory_dim:].clone(),
-                new_memory
-            ],
-                                          dim=-1)
-        else:
-            self.memory_input = new_memory
+        self.memory_input = torch.cat([
+            self.memory_input[:, self.r * self.memory_dim:].clone(), new_memory
+        ],
+                                      dim=-1)
 
     def forward(self, inputs, memory, mask):
         """
diff --git a/models/tacotron.py b/models/tacotron.py
index b7f40683..18c74904 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -39,7 +39,7 @@ class Tacotron(nn.Module):
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
             nn.Sigmoid())
-
+
     def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
         B = characters.size(0)
         mask = sequence_mask(text_lengths).to(characters.device)
diff --git a/train.py b/train.py
index 815a0a32..730d7389 100644
--- a/train.py
+++ b/train.py
@@ -81,6 +81,20 @@ def setup_loader(ap, is_val=False, verbose=False):
     return loader
 
 
+def gradual_training_scheduler(global_step):
+    if global_step < 10000:
+        r, batch_size = 7, 32
+    elif global_step < 50000:
+        r, batch_size = 5, 32
+    elif global_step < 130000:
+        r, batch_size = 3, 32
+    elif global_step < 290000:
+        r, batch_size = 2, 16
+    else:
+        r, batch_size = 1, 16
+    return r, batch_size
+
+
 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
           ap, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
@@ -524,7 +538,14 @@ def main(args): #pylint: disable=redefined-outer-name
     if 'best_loss' not in locals():
         best_loss = float('inf')
 
+    current_step = 0
     for epoch in range(0, c.epochs):
+        # set gradual training
+        r, c.batch_size = gradual_training_scheduler(current_step)
+        c.r = r
+        model.decoder._set_r(r)
+        print(" > Number of outputs per iteration:", model.decoder.r)
+
         train_loss, current_step = train(model, criterion, criterion_st,
                                          optimizer, optimizer_st, scheduler,
                                          ap, epoch)
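
Note on the mechanics (illustrative sketch, not part of the patch): the snippet below restates what Decoder._update_memory_queue() does after this change, driven by the r values that gradual_training_scheduler() steps through. Only memory_size = 7 and the schedule (7, 5, 3, 2, 1) come from the patch; the batch size, mel dimension, and the helper name update_memory_queue are assumptions made for the example. Because the decoder step now truncates its output to r * memory_dim values (the added output = output[:, : self.r * self.memory_dim] line), each update drops the oldest r frames from the queue and appends exactly r new ones, so the queue width stays memory_size * memory_dim however far the scheduler lowers r.

import torch

B = 2            # assumed batch size, for illustration only
memory_dim = 80  # assumed number of mel channels
memory_size = 7  # "memory_size": 7 in config_libritts.json

# Queue of the last `memory_size` frames, flattened to (B, memory_size * memory_dim).
memory_input = torch.zeros(B, memory_size * memory_dim)

def update_memory_queue(memory_input, new_frames, r):
    # Drop the oldest r frames and append the r frames predicted at this step
    # (mirrors the patched Decoder._update_memory_queue()).
    new_memory = new_frames.reshape(new_frames.size(0), r * memory_dim)
    return torch.cat([memory_input[:, r * memory_dim:], new_memory], dim=-1)

for r in (7, 5, 3, 2, 1):  # the values gradual_training_scheduler() steps through
    new_frames = torch.randn(B, r, memory_dim)  # stand-in for one decoder step's output
    memory_input = update_memory_queue(memory_input, new_frames, r)
    # Queue width is independent of r, which is what makes lowering r mid-training safe.
    assert memory_input.shape == (B, memory_size * memory_dim)

Since memory_size is fixed at 7, the largest r in the schedule, the queue never needs to grow, which is presumably why the config raises "memory_size" from 5 to 7 in the same commit.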