mirror of https://github.com/coqui-ai/TTS.git
Update tacotron2 for gradual training and change the indexing of prenet inputs to pick the last frame
This commit is contained in:
parent 1f4ec804b6
commit 8d3775a7d6
.compute (2 changed lines)
@@ -10,7 +10,7 @@ wget https://www.dropbox.com/s/wqn5v3wkktw9lmo/install.sh?dl=0 -O install.sh
 sudo sh install.sh
 python3 setup.py develop
 # cp -R ${USER_DIR}/GermanData ../tmp/
-# python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
+python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
 # cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
 # python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/
 python3 distribute.py --config_path config.json --data_path /data/rw/home/LibriTTS/train-clean-360

@@ -1,6 +1,6 @@
 {
     "run_name": "ljspeech",
-    "run_description": "gradual training with prenet frame size 1 + no maxout for cbhg + symmetric norm.",
+    "run_description": "Tacotron2",
 
     "audio":{
         // Audio processing parameters
@@ -31,7 +31,7 @@
 
     "reinit_layers": [],
 
-    "model": "Tacotron", // one of the model in models/
+    "model": "Tacotron2", // one of the model in models/
     "grad_clip": 1, // upper limit for gradients for clipping.
     "epochs": 1000, // total number of epochs to train.
     "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
@@ -55,10 +55,10 @@
     "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
     "eval_batch_size":16,
     "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
-    "gradual_training": [[0, 7, 32], [10000, 5, 32], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
+    "gradual_training": [[0, 7, 32], [1, 5, 32], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
     "wd": 0.000001, // Weight decay weight.
     "checkpoint": true, // If true, it saves checkpoints per "save_step"
     "save_step": 10000, // Number of training steps expected to save traning stats and checkpoints.
     "print_step": 25, // Number of steps to log traning on console.
     "batch_group_size": 0, //Number of batches to shuffle after bucketing.
 
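For reference, gradual_training is a list of [first_step, r, batch_size] entries, and the trainer is expected to apply the last entry whose first_step has been reached at the current global step, presumably switching the reduction factor via the decoder's new set_r() shown further below. A minimal sketch of that lookup, assuming a schedule sorted by first_step; the helper name is illustrative, not the repository's API:

def resolve_gradual_schedule(global_step, schedule):
    # schedule: [[first_step, r, batch_size], ...] sorted by first_step
    r, batch_size = schedule[0][1], schedule[0][2]
    for first_step, new_r, new_batch_size in schedule:
        if global_step >= first_step:
            r, batch_size = new_r, new_batch_size
    return r, batch_size

# with the schedule committed above:
schedule = [[0, 7, 32], [1, 5, 32], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]]
assert resolve_gradual_schedule(0, schedule) == (7, 32)
assert resolve_gradual_schedule(60000, schedule) == (3, 32)
assert resolve_gradual_schedule(300000, schedule) == (1, 8)
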
@@ -406,7 +406,7 @@ class Decoder(nn.Module):
             self.memory_input = new_memory[:, :self.memory_size * self.memory_dim]
         else:
             # use only the last frame prediction
-            self.memory_input = new_memory[:, :self.memory_dim]
+            self.memory_input = new_memory[:, self.memory_dim * (self.r - 1):]
 
     def forward(self, inputs, memory, mask, speaker_embeddings=None):
         """
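The hunk above changes which predicted frame is fed back through the prenet when memory_size is not used: the decoder now keeps the last of the r frames produced in the previous step instead of the first. A small shape sketch of that indexing (the batch size, r and memory_dim values below are illustrative assumptions):

import torch

B, r, memory_dim = 2, 7, 80                        # assumed sizes for illustration
new_memory = torch.randn(B, r * memory_dim)        # r stacked frames from the previous decoder step

first_frame = new_memory[:, :memory_dim]           # old indexing
last_frame = new_memory[:, memory_dim * (r - 1):]  # new indexing
assert first_frame.shape == (B, memory_dim)
assert last_frame.shape == (B, memory_dim)
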
@@ -101,6 +101,7 @@ class Decoder(nn.Module):
                  forward_attn_mask, location_attn, separate_stopnet):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
+        self.r_init = r
         self.r = r
         self.encoder_embedding_dim = in_features
         self.separate_stopnet = separate_stopnet
@@ -111,8 +112,7 @@
         self.gate_threshold = 0.5
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1
-
-        self.prenet = Prenet(self.mel_channels * r, prenet_type,
+        self.prenet = Prenet(self.mel_channels, prenet_type,
                              prenet_dropout,
                              [self.prenet_dim, self.prenet_dim], bias=False)
 
@@ -135,44 +135,34 @@ class Decoder(nn.Module):
             self.decoder_rnn_dim, 1)
 
         self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
-                                        self.mel_channels * r)
+                                        self.mel_channels * self.r_init)
 
         self.stopnet = nn.Sequential(
             nn.Dropout(0.1),
             Linear(
-                self.decoder_rnn_dim + self.mel_channels * r,
+                self.decoder_rnn_dim + self.mel_channels * self.r_init,
                 1,
                 bias=True,
                 init_gain='sigmoid'))
 
-        self.attention_rnn_init = nn.Embedding(1, self.query_dim)
-        self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
-        self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
         self.memory_truncated = None
 
+    def set_r(self, new_r):
+        self.r = new_r
+
     def get_go_frame(self, inputs):
         B = inputs.size(0)
-        memory = self.go_frame_init(inputs.data.new_zeros(B).long())
+        memory = torch.zeros(B, self.mel_channels * self.r, device=inputs.device)
         return memory
 
     def _init_states(self, inputs, mask, keep_states=False):
         B = inputs.size(0)
         # T = inputs.size(1)
 
         if not keep_states:
-            self.query = self.attention_rnn_init(
-                inputs.data.new_zeros(B).long())
-            self.attention_rnn_cell_state = Variable(
-                inputs.data.new(B, self.query_dim).zero_())
-            self.decoder_hidden = self.decoder_rnn_inits(
-                inputs.data.new_zeros(B).long())
-            self.decoder_cell = Variable(
-                inputs.data.new(B, self.decoder_rnn_dim).zero_())
-            self.context = Variable(
-                inputs.data.new(B, self.encoder_embedding_dim).zero_())
+            self.query = torch.zeros(B, self.query_dim, device=inputs.device)
+            self.attention_rnn_cell_state = torch.zeros(B, self.query_dim, device=inputs.device)
+            self.decoder_hidden = torch.zeros(B, self.decoder_rnn_dim, device=inputs.device)
+            self.decoder_cell = torch.zeros(B, self.decoder_rnn_dim, device=inputs.device)
+            self.context = torch.zeros(B, self.encoder_embedding_dim, device=inputs.device)
         self.inputs = inputs
         self.processed_inputs = self.attention.inputs_layer(inputs)
         self.mask = mask
@@ -192,6 +182,9 @@ class Decoder(nn.Module):
         outputs = outputs.transpose(1, 2)
         return outputs, stop_tokens, alignments
 
+    def _update_memory(self, memory):
+        return memory[:, :, self.mel_channels * (self.r - 1):]
+
     def decode(self, memory):
         query_input = torch.cat((memory, self.context), -1)
         self.query, self.attention_rnn_cell_state = self.attention_rnn(
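_update_memory above keeps only the last mel frame of every r-frame block before the prenet; in forward() further down it is applied to the whole teacher-forcing tensor at once. A minimal sketch of the slicing, assuming a (time, batch, r * mel_channels) layout for the reshaped memories (sizes are illustrative):

import torch

T, B, r, mel_channels = 5, 2, 7, 80                    # illustrative sizes
memories = torch.randn(T, B, r * mel_channels)         # assumed teacher-forcing layout

last_frames = memories[:, :, mel_channels * (r - 1):]  # what _update_memory returns
assert last_frames.shape == (T, B, mel_channels)
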
@@ -223,13 +216,14 @@
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
+        decoder_output = decoder_output[:, :self.r * self.mel_channels]
         return decoder_output, stop_token, self.attention.attention_weights
 
     def forward(self, inputs, memories, mask):
         memory = self.get_go_frame(inputs).unsqueeze(0)
         memories = self._reshape_memory(memories)
         memories = torch.cat((memory, memories), dim=0)
-        memories = self.prenet(memories)
+        memories = self.prenet(self._update_memory(memories))
 
         self._init_states(inputs, mask=mask)
         self.attention.init_states(inputs)
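Since linear_projection and stopnet are now sized with r_init (the largest reduction factor the schedule will use), decode() always produces r_init * mel_channels values and, as the added line above shows, returns only the first self.r * mel_channels of them; set_r() can therefore lower r mid-training without changing any layer shapes. A rough sketch of that trimming (the sizes below are assumptions):

import torch

B, mel_channels, r_init = 2, 80, 7
decoder_output = torch.randn(B, mel_channels * r_init)  # full projection width

r = 5                                                   # current reduction factor after set_r(5)
trimmed = decoder_output[:, :r * mel_channels]          # frames actually emitted this step
assert trimmed.shape == (B, r * mel_channels)
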
@@ -277,7 +271,7 @@ class Decoder(nn.Module):
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
 
-            memory = mel_output
+            memory = self._update_memory(mel_output)
             t += 1
 
         outputs, stop_tokens, alignments = self._parse_outputs(

train.py (2 changed lines)
@@ -62,7 +62,7 @@ def setup_loader(ap, is_val=False, verbose=False):
         dataset = MyDataset(
             c.r,
             c.text_cleaner,
-            meta_data=meta_data_eval if is_val else meta_data_train,
+            meta_data=meta_data_eval if is_val else meta_data_train[:64],
             ap=ap,
             batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
             min_seq_len=c.min_seq_len,