Update tacotron2 for gradual training and change the indexing of prenet inputs to pick the last frame

Eren Golge 2019-09-17 18:43:11 +02:00
parent 1f4ec804b6
commit 8d3775a7d6
5 changed files with 26 additions and 32 deletions

View File

@@ -10,7 +10,7 @@ wget https://www.dropbox.com/s/wqn5v3wkktw9lmo/install.sh?dl=0 -O install.sh
 sudo sh install.sh
 python3 setup.py develop
 # cp -R ${USER_DIR}/GermanData ../tmp/
-# python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
+python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
 # cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
 # python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/
 python3 distribute.py --config_path config.json --data_path /data/rw/home/LibriTTS/train-clean-360

View File

@@ -1,6 +1,6 @@
 {
     "run_name": "ljspeech",
-    "run_description": "gradual training with prenet frame size 1 + no maxout for cbhg + symmetric norm.",
+    "run_description": "Tacotron2",
     "audio":{
         // Audio processing parameters
@@ -31,7 +31,7 @@
     "reinit_layers": [],
-    "model": "Tacotron", // one of the model in models/
+    "model": "Tacotron2", // one of the model in models/
     "grad_clip": 1, // upper limit for gradients for clipping.
     "epochs": 1000, // total number of epochs to train.
     "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
@@ -55,10 +55,10 @@
     "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
     "eval_batch_size":16,
     "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
-    "gradual_training": [[0, 7, 32], [10000, 5, 32], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
+    "gradual_training": [[0, 7, 32], [1, 5, 32], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // ONLY TACOTRON - set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
     "wd": 0.000001, // Weight decay weight.
     "checkpoint": true, // If true, it saves checkpoints per "save_step"
-    "save_step": 10000, // Number of training steps expected to save traning stats and checkpoints.
+    "save_step": 10000, // Number of training steps expected to save traninpg stats and checkpoints.
     "print_step": 25, // Number of steps to log traning on console.
     "batch_group_size": 0, //Number of batches to shuffle after bucketing.

View File

@@ -406,7 +406,7 @@ class Decoder(nn.Module):
             self.memory_input = new_memory[:, :self.memory_size * self.memory_dim]
         else:
             # use only the last frame prediction
-            self.memory_input = new_memory[:, :self.memory_dim]
+            self.memory_input = new_memory[:, self.memory_dim * (self.r - 1):]

    def forward(self, inputs, memory, mask, speaker_embeddings=None):
        """

View File

@@ -101,6 +101,7 @@ class Decoder(nn.Module):
                  forward_attn_mask, location_attn, separate_stopnet):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
+        self.r_init = r
         self.r = r
         self.encoder_embedding_dim = in_features
         self.separate_stopnet = separate_stopnet
@@ -111,8 +112,7 @@
         self.gate_threshold = 0.5
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1
-        self.prenet = Prenet(self.mel_channels * r, prenet_type,
+        self.prenet = Prenet(self.mel_channels, prenet_type,
                              prenet_dropout,
                              [self.prenet_dim, self.prenet_dim], bias=False)
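With last-frame feedback, the prenet always consumes exactly one frame, so its input width becomes mel_channels rather than mel_channels * r, and the layer no longer has to be rebuilt when r changes during gradual training. A toy shape check with a stand-in prenet (the real Prenet also applies dropout and a configurable type):

```python
import torch
import torch.nn as nn

mel_channels, prenet_dim = 80, 256
# Stand-in for the repo's Prenet: two bias-free linear layers with ReLU.
prenet = nn.Sequential(
    nn.Linear(mel_channels, prenet_dim, bias=False), nn.ReLU(),
    nn.Linear(prenet_dim, prenet_dim, bias=False), nn.ReLU())

last_frame = torch.randn(4, mel_channels)  # (B, mel_channels): a single frame
assert prenet(last_frame).shape == (4, prenet_dim)
```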
@@ -135,44 +135,34 @@
                                self.decoder_rnn_dim, 1)
         self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
-                                        self.mel_channels * r)
+                                        self.mel_channels * self.r_init)
         self.stopnet = nn.Sequential(
             nn.Dropout(0.1),
             Linear(
-                self.decoder_rnn_dim + self.mel_channels * r,
+                self.decoder_rnn_dim + self.mel_channels * self.r_init,
                 1,
                 bias=True,
                 init_gain='sigmoid'))
-        self.attention_rnn_init = nn.Embedding(1, self.query_dim)
-        self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
-        self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
         self.memory_truncated = None

+    def set_r(self, new_r):
+        self.r = new_r
+
     def get_go_frame(self, inputs):
         B = inputs.size(0)
-        memory = self.go_frame_init(inputs.data.new_zeros(B).long())
+        memory = torch.zeros(B, self.mel_channels * self.r, device=inputs.device)
         return memory

     def _init_states(self, inputs, mask, keep_states=False):
         B = inputs.size(0)
         # T = inputs.size(1)
         if not keep_states:
-            self.query = self.attention_rnn_init(
-                inputs.data.new_zeros(B).long())
-            self.attention_rnn_cell_state = Variable(
-                inputs.data.new(B, self.query_dim).zero_())
-            self.decoder_hidden = self.decoder_rnn_inits(
-                inputs.data.new_zeros(B).long())
-            self.decoder_cell = Variable(
-                inputs.data.new(B, self.decoder_rnn_dim).zero_())
-            self.context = Variable(
-                inputs.data.new(B, self.encoder_embedding_dim).zero_())
+            self.query = torch.zeros(B, self.query_dim, device=inputs.device)
+            self.attention_rnn_cell_state = torch.zeros(B, self.query_dim, device=inputs.device)
+            self.decoder_hidden = torch.zeros(B, self.decoder_rnn_dim, device=inputs.device)
+            self.decoder_cell = torch.zeros(B, self.decoder_rnn_dim, device=inputs.device)
+            self.context = torch.zeros(B, self.encoder_embedding_dim, device=inputs.device)
         self.inputs = inputs
         self.processed_inputs = self.attention.inputs_layer(inputs)
         self.mask = mask
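Two things happen in this hunk: the learned embedding initializers are replaced with plain zeros, and the output projection and stopnet are sized for r_init (the largest reduction factor) while set_r changes the effective r between schedule stages. decode then keeps only the first r * mel_channels outputs (see the decoder_output slice two hunks below). A minimal sketch of that fixed-projection / variable-r pattern, stripped of attention and RNN state:

```python
import torch
import torch.nn as nn

class VariableRProjection(nn.Module):
    """Sketch: project once at r_init width, slice down to the active r."""
    def __init__(self, hidden_dim, mel_channels, r_init):
        super().__init__()
        self.mel_channels = mel_channels
        self.r_init = r_init
        self.r = r_init
        self.linear = nn.Linear(hidden_dim, mel_channels * r_init)

    def set_r(self, new_r):
        assert 1 <= new_r <= self.r_init
        self.r = new_r

    def forward(self, hidden):
        out = self.linear(hidden)                   # (B, mel_channels * r_init)
        return out[:, :self.r * self.mel_channels]  # keep the first r frames

proj = VariableRProjection(hidden_dim=1024, mel_channels=80, r_init=7)
proj.set_r(2)
assert proj(torch.randn(3, 1024)).shape == (3, 160)
```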
@@ -192,6 +182,9 @@ class Decoder(nn.Module):
         outputs = outputs.transpose(1, 2)
         return outputs, stop_tokens, alignments

+    def _update_memory(self, memory):
+        return memory[:, :, self.mel_channels * (self.r - 1):]
+
     def decode(self, memory):
         query_input = torch.cat((memory, self.context), -1)
         self.query, self.attention_rnn_cell_state = self.attention_rnn(
@@ -223,13 +216,14 @@
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
+        decoder_output = decoder_output[:, :self.r * self.mel_channels]
         return decoder_output, stop_token, self.attention.attention_weights

     def forward(self, inputs, memories, mask):
         memory = self.get_go_frame(inputs).unsqueeze(0)
         memories = self._reshape_memory(memories)
         memories = torch.cat((memory, memories), dim=0)
-        memories = self.prenet(memories)
+        memories = self.prenet(self._update_memory(memories))

         self._init_states(inputs, mask=mask)
         self.attention.init_states(inputs)
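In the teacher-forced forward pass, _update_memory trims each group of r ground-truth frames down to its last frame before the prenet, matching what the model will see at inference. A quick check of the slice, assuming the (T, B, mel_channels * r) layout produced by _reshape_memory:

```python
import torch

T, B, mel_channels, r = 6, 2, 80, 7
memories = torch.randn(T, B, mel_channels * r)

last_frames = memories[:, :, mel_channels * (r - 1):]
assert last_frames.shape == (T, B, mel_channels)
assert torch.equal(last_frames, memories.view(T, B, r, mel_channels)[:, :, -1])
```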
@@ -277,7 +271,7 @@
                 print("   | > Decoder stopped with 'max_decoder_steps")
                 break

-            memory = mel_output
+            memory = self._update_memory(mel_output)
             t += 1

         outputs, stop_tokens, alignments = self._parse_outputs(
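At inference the same feedback applies step by step: each decode emits r frames and only the last one is carried into the next step. (As rendered here, _update_memory indexes three dimensions while mel_output from decode is 2-D, so this call would need a matching 2-D slice; the actual commit may guard on the tensor rank.) A toy loop with a stand-in decode:

```python
import torch

B, mel_channels, r = 1, 80, 2

def decode_step(last_frame):
    # Stand-in for Decoder.decode: emits r frames per step.
    return torch.randn(B, r * mel_channels)

memory = torch.zeros(B, mel_channels)  # go frame, last-frame width
outputs = []
for _ in range(3):
    mel_output = decode_step(memory)
    outputs.append(mel_output)
    memory = mel_output[:, mel_channels * (r - 1):]  # keep only the last frame

assert memory.shape == (B, mel_channels)
assert len(outputs) == 3
```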

View File

@@ -62,7 +62,7 @@ def setup_loader(ap, is_val=False, verbose=False):
         dataset = MyDataset(
             c.r,
             c.text_cleaner,
-            meta_data=meta_data_eval if is_val else meta_data_train,
+            meta_data=meta_data_eval if is_val else meta_data_train[:64],
             ap=ap,
             batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
             min_seq_len=c.min_seq_len,
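One more flag for readers of this revision: meta_data_train[:64] caps the training split at its first 64 items, so a run from this exact commit would train on 64 samples only. Like the [1, 5, 32] schedule entry in the config above, it reads as a debugging leftover rather than something to keep for real runs.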