gradual training with memory queue

Eren Golge 2019-07-20 12:33:21 +02:00
parent 0470dd49f2
commit d47ba5d310
4 changed files with 55 additions and 28 deletions

View File

@@ -1,6 +1,6 @@
{
"run_name": "libritts-360",
"run_description": "LibriTTS 360 clean with multi speaker embedding.",
"run_description": "LibriTTS 360 gradual traning with memory queue.",
"audio":{
// Audio processing parameters
@@ -31,13 +31,13 @@
"reinit_layers": [],
"model": "Tacotron2", // one of the model in models/
"model": "Tacotron", // one of the model in models/
"grad_clip": 1, // upper limit for gradients for clipping.
"epochs": 1000, // total number of epochs to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"memory_size": 7, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
"prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
@@ -52,9 +52,9 @@
"separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
"tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"batch_size": 24, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"eval_batch_size":16,
"r": 1, // Number of frames to predict for step.
"r": 7, // Number of frames to predict for step.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.

View File

@@ -270,12 +270,14 @@ class Decoder(nn.Module):
memory_size (int): size of the past window. If <= 0, memory_size = r
TODO: arguments
"""
# Pylint gets confused by PyTorch conventions here
#pylint: disable=attribute-defined-outside-init
def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
attn_norm, prenet_type, prenet_dropout, forward_attn,
trans_agent, forward_attn_mask, location_attn, separate_stopnet):
trans_agent, forward_attn_mask, location_attn,
separate_stopnet):
super(Decoder, self).__init__()
self.r = r
self.in_features = in_features
@@ -291,17 +293,18 @@ class Decoder(nn.Module):
out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
self.attention_rnn = nn.GRUCell(in_features + 128, 256)
self.attention_layer = Attention(attention_rnn_dim=256,
embedding_dim=in_features,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_windowing,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask)
self.attention_layer = Attention(
attention_rnn_dim=256,
embedding_dim=in_features,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_windowing,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask)
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
@@ -324,6 +327,9 @@ class Decoder(nn.Module):
self.proj_to_mel.weight,
gain=torch.nn.init.calculate_gain('linear'))
def _set_r(self, new_r):
self.r = new_r
def _reshape_memory(self, memory):
"""
Reshape the spectrograms for given 'r'
@@ -371,8 +377,11 @@
# Prenet
processed_memory = self.prenet(self.memory_input)
# Attention RNN
self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden)
self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
self.attention_rnn_hidden = self.attention_rnn(
torch.cat((processed_memory, self.current_context_vec), -1),
self.attention_rnn_hidden)
self.current_context_vec = self.attention_layer(
self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(
torch.cat((self.attention_rnn_hidden, self.current_context_vec),
@@ -395,17 +404,14 @@
stop_token = self.stopnet(stopnet_input.detach())
else:
stop_token = self.stopnet(stopnet_input)
output = output[:, : self.r * self.memory_dim]
return output, stop_token, self.attention_layer.attention_weights
def _update_memory_queue(self, new_memory):
if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
self.memory_input = torch.cat([
self.memory_input[:, self.r * self.memory_dim:].clone(),
new_memory
],
dim=-1)
else:
self.memory_input = new_memory
self.memory_input = torch.cat([
self.memory_input[:, self.r * self.memory_dim:].clone(), new_memory
],
dim=-1)
def forward(self, inputs, memory, mask):
"""

View File

@@ -39,7 +39,7 @@ class Tacotron(nn.Module):
self.last_linear = nn.Sequential(
nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
nn.Sigmoid())
def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
B = characters.size(0)
mask = sequence_mask(text_lengths).to(characters.device)

View File

@@ -81,6 +81,20 @@ def setup_loader(ap, is_val=False, verbose=False):
return loader
def gradual_training_scheduler(global_step):
if global_step < 10000:
r, batch_size = 7, 32
elif global_step < 50000:
r, batch_size = 5, 32
elif global_step < 130000:
r, batch_size = 3, 32
elif global_step < 290000:
r, batch_size = 2, 16
else:
r, batch_size = 1, 16
return r, batch_size
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
ap, epoch):
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
@@ -524,7 +538,14 @@ def main(args): #pylint: disable=redefined-outer-name
if 'best_loss' not in locals():
best_loss = float('inf')
current_step = 0
for epoch in range(0, c.epochs):
# set gradual training
r, c.batch_size = gradual_training_scheduler(current_step)
c.r = r
model.decoder._set_r(r)
print(" > Number of outputs per iteration:", model.decoder.r)
train_loss, current_step = train(model, criterion, criterion_st,
optimizer, optimizer_st, scheduler,
ap, epoch)
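For reference, a standalone sketch (not from the commit) that prints the stages implied by gradual_training_scheduler above. Note that the loop in main() re-evaluates the schedule once per epoch, so a stage change takes effect at the next epoch boundary rather than exactly at the listed step.

# Standalone copy of the schedule above, evaluated at its stage boundaries.
def gradual_training_scheduler(global_step):
    if global_step < 10000:
        r, batch_size = 7, 32
    elif global_step < 50000:
        r, batch_size = 5, 32
    elif global_step < 130000:
        r, batch_size = 3, 32
    elif global_step < 290000:
        r, batch_size = 2, 16
    else:
        r, batch_size = 1, 16
    return r, batch_size

for step in (0, 10_000, 50_000, 130_000, 290_000):
    r, batch_size = gradual_training_scheduler(step)
    print(f"from step {step:>7}: r={r}, batch_size={batch_size}")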