mirror of https://github.com/coqui-ai/TTS.git
gradual traning with memory queue
This commit is contained in:
parent
0470dd49f2
commit
d47ba5d310
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"run_name": "libritts-360",
|
||||
"run_description": "LibriTTS 360 clean with multi speaker embedding.",
|
||||
"run_description": "LibriTTS 360 gradual traning with memory queue.",
|
||||
|
||||
"audio":{
|
||||
// Audio processing parameters
|
||||
|
@ -31,13 +31,13 @@
|
|||
|
||||
"reinit_layers": [],
|
||||
|
||||
"model": "Tacotron2", // one of the model in models/
|
||||
"model": "Tacotron", // one of the model in models/
|
||||
"grad_clip": 1, // upper limit for gradients for clipping.
|
||||
"epochs": 1000, // total number of epochs to train.
|
||||
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
||||
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
|
||||
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
||||
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
|
||||
"memory_size": 7, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
|
||||
"attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
|
||||
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
|
||||
"prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
|
||||
|
@ -52,9 +52,9 @@
|
|||
"separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
|
||||
"tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
|
||||
"batch_size": 24, // Batch size for training. Lower values than 32 might cause hard to learn attention.
|
||||
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
|
||||
"eval_batch_size":16,
|
||||
"r": 1, // Number of frames to predict for step.
|
||||
"r": 7, // Number of frames to predict for step.
|
||||
"wd": 0.000001, // Weight decay weight.
|
||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
|
||||
|
|
|
@ -270,12 +270,14 @@ class Decoder(nn.Module):
|
|||
memory_size (int): size of the past window. if <= 0 memory_size = r
|
||||
TODO: arguments
|
||||
"""
|
||||
|
||||
# Pylint gets confused by PyTorch conventions here
|
||||
#pylint: disable=attribute-defined-outside-init
|
||||
|
||||
def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
|
||||
attn_norm, prenet_type, prenet_dropout, forward_attn,
|
||||
trans_agent, forward_attn_mask, location_attn, separate_stopnet):
|
||||
trans_agent, forward_attn_mask, location_attn,
|
||||
separate_stopnet):
|
||||
super(Decoder, self).__init__()
|
||||
self.r = r
|
||||
self.in_features = in_features
|
||||
|
@ -291,17 +293,18 @@ class Decoder(nn.Module):
|
|||
out_features=[256, 128])
|
||||
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
||||
self.attention_rnn = nn.GRUCell(in_features + 128, 256)
|
||||
self.attention_layer = Attention(attention_rnn_dim=256,
|
||||
embedding_dim=in_features,
|
||||
attention_dim=128,
|
||||
location_attention=location_attn,
|
||||
attention_location_n_filters=32,
|
||||
attention_location_kernel_size=31,
|
||||
windowing=attn_windowing,
|
||||
norm=attn_norm,
|
||||
forward_attn=forward_attn,
|
||||
trans_agent=trans_agent,
|
||||
forward_attn_mask=forward_attn_mask)
|
||||
self.attention_layer = Attention(
|
||||
attention_rnn_dim=256,
|
||||
embedding_dim=in_features,
|
||||
attention_dim=128,
|
||||
location_attention=location_attn,
|
||||
attention_location_n_filters=32,
|
||||
attention_location_kernel_size=31,
|
||||
windowing=attn_windowing,
|
||||
norm=attn_norm,
|
||||
forward_attn=forward_attn,
|
||||
trans_agent=trans_agent,
|
||||
forward_attn_mask=forward_attn_mask)
|
||||
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
||||
self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
|
||||
# decoder_RNN_input -> |RNN| -> RNN_state
|
||||
|
@ -324,6 +327,9 @@ class Decoder(nn.Module):
|
|||
self.proj_to_mel.weight,
|
||||
gain=torch.nn.init.calculate_gain('linear'))
|
||||
|
||||
def _set_r(self, new_r):
|
||||
self.r = new_r
|
||||
|
||||
def _reshape_memory(self, memory):
|
||||
"""
|
||||
Reshape the spectrograms for given 'r'
|
||||
|
@ -371,8 +377,11 @@ class Decoder(nn.Module):
|
|||
# Prenet
|
||||
processed_memory = self.prenet(self.memory_input)
|
||||
# Attention RNN
|
||||
self.attention_rnn_hidden = self.attention_rnn(torch.cat((processed_memory, self.current_context_vec), -1), self.attention_rnn_hidden)
|
||||
self.current_context_vec = self.attention_layer(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
|
||||
self.attention_rnn_hidden = self.attention_rnn(
|
||||
torch.cat((processed_memory, self.current_context_vec), -1),
|
||||
self.attention_rnn_hidden)
|
||||
self.current_context_vec = self.attention_layer(
|
||||
self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
|
||||
# Concat RNN output and attention context vector
|
||||
decoder_input = self.project_to_decoder_in(
|
||||
torch.cat((self.attention_rnn_hidden, self.current_context_vec),
|
||||
|
@ -395,17 +404,14 @@ class Decoder(nn.Module):
|
|||
stop_token = self.stopnet(stopnet_input.detach())
|
||||
else:
|
||||
stop_token = self.stopnet(stopnet_input)
|
||||
output = output[:, : self.r * self.memory_dim]
|
||||
return output, stop_token, self.attention_layer.attention_weights
|
||||
|
||||
def _update_memory_queue(self, new_memory):
|
||||
if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
|
||||
self.memory_input = torch.cat([
|
||||
self.memory_input[:, self.r * self.memory_dim:].clone(),
|
||||
new_memory
|
||||
],
|
||||
dim=-1)
|
||||
else:
|
||||
self.memory_input = new_memory
|
||||
self.memory_input = torch.cat([
|
||||
self.memory_input[:, self.r * self.memory_dim:].clone(), new_memory
|
||||
],
|
||||
dim=-1)
|
||||
|
||||
def forward(self, inputs, memory, mask):
|
||||
"""
|
||||
|
|
|
@ -39,7 +39,7 @@ class Tacotron(nn.Module):
|
|||
self.last_linear = nn.Sequential(
|
||||
nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
|
||||
nn.Sigmoid())
|
||||
|
||||
|
||||
def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
|
||||
B = characters.size(0)
|
||||
mask = sequence_mask(text_lengths).to(characters.device)
|
||||
|
|
21
train.py
21
train.py
|
@ -81,6 +81,20 @@ def setup_loader(ap, is_val=False, verbose=False):
|
|||
return loader
|
||||
|
||||
|
||||
def gradual_training_scheduler(global_step):
|
||||
if global_step < 10000:
|
||||
r, batch_size = 7, 32
|
||||
elif global_step < 50000:
|
||||
r, batch_size = 5, 32
|
||||
elif global_step < 130000:
|
||||
r, batch_size = 3, 32
|
||||
elif global_step < 290000:
|
||||
r, batch_size = 2, 16
|
||||
else:
|
||||
r, batch_size = 1, 16
|
||||
return r, batch_size
|
||||
|
||||
|
||||
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||
ap, epoch):
|
||||
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
|
||||
|
@ -524,7 +538,14 @@ def main(args): #pylint: disable=redefined-outer-name
|
|||
if 'best_loss' not in locals():
|
||||
best_loss = float('inf')
|
||||
|
||||
current_step = 0
|
||||
for epoch in range(0, c.epochs):
|
||||
# set gradual training
|
||||
r, c.batch_size = gradual_training_scheduler(current_step)
|
||||
c.r = r
|
||||
model.decoder._set_r(r)
|
||||
print(" > Number of outputs per iteration:", model.decoder.r)
|
||||
|
||||
train_loss, current_step = train(model, criterion, criterion_st,
|
||||
optimizer, optimizer_st, scheduler,
|
||||
ap, epoch)
|
||||
|
|
Loading…
Reference in New Issue