mirror of https://github.com/coqui-ai/TTS.git

gradual training with memory queue

This commit is contained in:
parent 0470dd49f2
commit d47ba5d310
@@ -1,6 +1,6 @@
 {
     "run_name": "libritts-360",
-    "run_description": "LibriTTS 360 clean with multi speaker embedding.",
+    "run_description": "LibriTTS 360 gradual traning with memory queue.",

     "audio":{
         // Audio processing parameters
@@ -31,13 +31,13 @@

     "reinit_layers": [],

-    "model": "Tacotron2", // one of the model in models/
+    "model": "Tacotron", // one of the model in models/
     "grad_clip": 1, // upper limit for gradients for clipping.
     "epochs": 1000, // total number of epochs to train.
     "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
     "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
     "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
-    "memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+    "memory_size": 7, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
     "prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.

@@ -52,9 +52,9 @@
     "separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
     "tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

-    "batch_size": 24, // Batch size for training. Lower values than 32 might cause hard to learn attention.
+    "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
-    "r": 1, // Number of frames to predict for step.
+    "r": 7, // Number of frames to predict for step.
     "wd": 0.000001, // Weight decay weight.
     "checkpoint": true, // If true, it saves checkpoints per "save_step"
     "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
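Quick note on the two knobs changed above: "r" is the decoder's reduction factor (how many mel frames are predicted per decoder step) and "memory_size" is how many past frames are queued for the autoregressive prenet input. Below is a minimal illustrative sketch, not code from this commit, of how targets get regrouped for a given r, which is conceptually what the decoder's _reshape_memory does; tensor names and sizes here are made up.

import torch

def reshape_for_r(mel, r):
    """Group a (batch, time, mel_dim) spectrogram into decoder steps of r frames."""
    batch, time, mel_dim = mel.shape
    assert time % r == 0, "pad so the frame count is divisible by r"
    return mel.view(batch, time // r, r * mel_dim)

mel = torch.randn(32, 70, 80)       # batch 32, 70 frames, 80 mel bins (made-up sizes)
print(reshape_for_r(mel, 7).shape)  # torch.Size([32, 10, 560]) -> 10 decoder steps at r=7
print(reshape_for_r(mel, 1).shape)  # torch.Size([32, 70, 80])  -> 70 decoder steps at r=1

Raising r from 1 to 7 shortens the decoded sequence by a factor of 7, which is what makes the early, coarse stage of the gradual schedule cheap.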
@@ -270,12 +270,14 @@ class Decoder(nn.Module):
         memory_size (int): size of the past window. if <= 0 memory_size = r
         TODO: arguments
     """

     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init

     def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
                  attn_norm, prenet_type, prenet_dropout, forward_attn,
-                 trans_agent, forward_attn_mask, location_attn, separate_stopnet):
+                 trans_agent, forward_attn_mask, location_attn,
+                 separate_stopnet):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -291,7 +293,8 @@ class Decoder(nn.Module):
                                out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = nn.GRUCell(in_features + 128, 256)
-        self.attention_layer = Attention(attention_rnn_dim=256,
+        self.attention_layer = Attention(
+            attention_rnn_dim=256,
             embedding_dim=in_features,
             attention_dim=128,
             location_attention=location_attn,
@@ -324,6 +327,9 @@ class Decoder(nn.Module):
             self.proj_to_mel.weight,
             gain=torch.nn.init.calculate_gain('linear'))

+    def _set_r(self, new_r):
+        self.r = new_r
+
     def _reshape_memory(self, memory):
         """
         Reshape the spectrograms for given 'r'
@@ -371,8 +377,11 @@ class Decoder(nn.Module):
         # Prenet
         processed_memory = self.prenet(self.memory_input)
         # Attention RNN
-        self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden)
-        self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
+        self.attention_rnn_hidden = self.attention_rnn(
+            torch.cat((processed_memory, self.current_context_vec), -1),
+            self.attention_rnn_hidden)
+        self.current_context_vec = self.attention_layer(
+            self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
         # Concat RNN output and attention context vector
         decoder_input = self.project_to_decoder_in(
             torch.cat((self.attention_rnn_hidden, self.current_context_vec),
@@ -395,17 +404,14 @@ class Decoder(nn.Module):
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
+        output = output[:, : self.r * self.memory_dim]
         return output, stop_token, self.attention_layer.attention_weights

     def _update_memory_queue(self, new_memory):
-        if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
         self.memory_input = torch.cat([
-            self.memory_input[:, self.r * self.memory_dim:].clone(),
-            new_memory
+            self.memory_input[:, self.r * self.memory_dim:].clone(), new_memory
         ],
                                       dim=-1)
-        else:
-            self.memory_input = new_memory

     def forward(self, inputs, memory, mask):
         """
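The _update_memory_queue change above drops the size check and the else branch, so the queue update is now unconditional: after every decoder step the slice discards the oldest r * memory_dim values and appends the newest r predicted frames, keeping the prenet input a fixed-width window of (nominally) memory_size frames. A minimal standalone sketch of that sliding-window update; shapes and names are illustrative, not the module's real state:

import torch

batch, memory_dim, memory_size, r = 2, 80, 7, 2

# the queue always holds memory_size frames, flattened to one vector per sample
memory_input = torch.zeros(batch, memory_size * memory_dim)

def update_memory_queue(memory_input, new_memory, r, memory_dim):
    # drop the oldest r frames, append the r frames just predicted
    return torch.cat(
        [memory_input[:, r * memory_dim:].clone(), new_memory], dim=-1)

new_memory = torch.randn(batch, r * memory_dim)   # one decoder step's output
memory_input = update_memory_queue(memory_input, new_memory, r, memory_dim)
assert memory_input.shape[-1] == memory_size * memory_dim

The companion change in the decode step, output = output[:, : self.r * self.memory_dim], keeps only the first r frames of the projection, so each step emits exactly r frames after _set_r lowers r during gradual training.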
train.py (21 changed lines)
@@ -81,6 +81,20 @@ def setup_loader(ap, is_val=False, verbose=False):
     return loader


+def gradual_training_scheduler(global_step):
+    if global_step < 10000:
+        r, batch_size = 7, 32
+    elif global_step < 50000:
+        r, batch_size = 5, 32
+    elif global_step < 130000:
+        r, batch_size = 3, 32
+    elif global_step < 290000:
+        r, batch_size = 2, 16
+    else:
+        r, batch_size = 1, 16
+    return r, batch_size
+
+
 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
           ap, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
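The new gradual_training_scheduler hard-codes the curriculum: start coarse and fast (r=7, batch 32) and progressively lower r, and eventually the batch size, as global_step grows. A quick sanity check of the thresholds, copying the function exactly as added above:

def gradual_training_scheduler(global_step):
    if global_step < 10000:
        r, batch_size = 7, 32
    elif global_step < 50000:
        r, batch_size = 5, 32
    elif global_step < 130000:
        r, batch_size = 3, 32
    elif global_step < 290000:
        r, batch_size = 2, 16
    else:
        r, batch_size = 1, 16
    return r, batch_size

for step in (0, 10000, 50000, 130000, 290000):
    print(step, gradual_training_scheduler(step))
# 0 (7, 32)
# 10000 (5, 32)
# 50000 (3, 32)
# 130000 (2, 16)
# 290000 (1, 16)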
@@ -524,7 +538,14 @@ def main(args): #pylint: disable=redefined-outer-name
     if 'best_loss' not in locals():
         best_loss = float('inf')

+    current_step = 0
     for epoch in range(0, c.epochs):
+        # set gradual training
+        r, c.batch_size = gradual_training_scheduler(current_step)
+        c.r = r
+        model.decoder._set_r(r)
+        print(" > Number of outputs per iteration:", model.decoder.r)
+
         train_loss, current_step = train(model, criterion, criterion_st,
                                          optimizer, optimizer_st, scheduler,
                                          ap, epoch)
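Note that main() re-evaluates the schedule only once per epoch, from the current_step returned by the previous epoch's train() call, so r and c.batch_size change on epoch boundaries rather than exactly at the step thresholds. A rough sketch of that behaviour, reusing gradual_training_scheduler from the sketch above and assuming a made-up steps_per_epoch:

current_step = 0
steps_per_epoch = 1000   # hypothetical; real epoch length depends on dataset and batch size
for epoch in range(3):
    r, batch_size = gradual_training_scheduler(current_step)
    print(f"epoch {epoch}: r={r}, batch_size={batch_size}")
    current_step += steps_per_epoch   # stands in for the steps train() would run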