diff --git a/.compute b/.compute
index 5ef7df1d..d2d37fc1 100644
--- a/.compute
+++ b/.compute
@@ -1,12 +1,14 @@
 #!/bin/bash
-ls ${SHARED_DIR}/data/mozilla/Judy/
 yes | apt-get install sox
 yes | apt-get install ffmpeg
-soxi /data/ro/shared/data/mozilla/Judy/batch6/wavs_no_processing/6_126.wav
-pip3 install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
 yes | apt-get install espeak
+yes | apt-get install tmux
+yes | apt-get install zsh
+pip3 install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
+# wget https://www.dropbox.com/s/m8waow6b3ydpf6h/MozillaDataset.tar.gz?dl=0 -O /data/rw/home/mozilla.tar
+wget https://www.dropbox.com/s/wqn5v3wkktw9lmo/install.sh?dl=0 -O install.sh
+sudo sh install.sh
 python3 setup.py develop
-# wget https://www.dropbox.com/s/evaouukiwb7krz8/MozillaDataset.tar.gz?dl=0 -O ${USER_DIR}/MozillaDataset.tar.gz
-# tar -xzvf ${USER_DIR}/MozillaDataset.tar.gz --no-same-owner -C ${USER_DIR}
-# python3 distribute.py --config_path config_cluster.json --data_path ${USER_DIR}/MozillaDataset/Mozilla/ --restore_path ${USER_DIR}/best_model_4583.pth.tar
-python3 distribute.py --config_path config_cluster.json --data_path ${SHARED_DIR}/data/mozilla/Judy/
+python3 distribute.py --config_path config_cluster.json --data_path ${USER_DIR}/MozillaAll2/Mozilla/ --restore_path ${USER_DIR}/checkpoint_123000_4761.pth.tar
+# python3 distribute.py --config_path config_cluster.json --data_path ${SHARED_DIR}/data/mozilla/Judy/
+# while true; do sleep 1000000; done
diff --git a/config_cluster.json b/config_cluster.json
index d8b066d7..11c2415f 100644
--- a/config_cluster.json
+++ b/config_cluster.json
@@ -1,6 +1,6 @@
 {
-    "run_name": "mozilla-nomask-fattn-bn",
-    "run_description": "Finetune 4702 orignal -> bn prenet - Mozilla with prenet bn, no mask, batch group size 0",
+    "run_name": "mozilla-fattn",
+    "run_description": "Finetune 4761 with BN + Dropout. It is to compare to 4780 and see how dropout behaves with BN.",

     "audio":{
         // Audio processing parameters
@@ -41,12 +41,14 @@
     "memory_size": 5,  // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "softmax",  // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "prenet_type": "bn",  // ONLY TACOTRON2 - "original" or "bn".
-    "use_forward_attn": false,  // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
+    "prenet_dropout": true,  // ONLY TACOTRON2 - enable/disable dropout at prenet.
+    "use_forward_attn": true,  // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
     "transition_agent": false,  // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
+    "location_attn": false,  // ONLY TACOTRON2 - enable/disable location sensitive attention. It is enabled for TACOTRON by default.
     "loss_masking": false,  // enable / disable loss masking against the sequence padding.
     "enable_eos_bos_chars": false,  // enable/disable beginning of sentence and end of sentence chars.
-    "batch_size": 24,  // Batch size for training. Lower values than 32 might cause hard to learn attention.
+    "batch_size": 32,  // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
     "r": 1,  // Number of frames to predict for step.
     "wd": 0.000001,  // Weight decay weight.
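
config_cluster.json above is JSON with // comments, so a plain json.loads() rejects it. A minimal standalone sketch (not the repo's own config loader) for stripping the comments and sanity-checking the keys this change adds or flips before launching distribute.py:

import json
import re


def load_commented_json(path):
    # Read a JSON file that carries // comments, as config_cluster.json does.
    # Standalone sketch; assumes no string value itself contains "//".
    with open(path, "r") as f:
        text = re.sub(r"//.*", "", f.read())
    return json.loads(text)


if __name__ == "__main__":
    c = load_commented_json("config_cluster.json")
    # Keys this change introduces or flips:
    for key in ("prenet_type", "prenet_dropout", "use_forward_attn",
                "location_attn", "batch_size"):
        print(key, "=", c[key])
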
@@ -59,13 +61,13 @@
     "run_eval": true,
     "test_delay_epochs": 1,  //Until attention is aligned, testing only wastes computation time.
     "data_path": "/media/erogol/data_ssd/Data/LJSpeech-1.1",  // DATASET-RELATED: can overwritten from command argument
-    "meta_file_train": "metadata.txt",  // DATASET-RELATED: metafile for training dataloader.
+    "meta_file_train": "metadata_train.txt",  // DATASET-RELATED: metafile for training dataloader.
     "meta_file_val": "metadata_val.txt",  // DATASET-RELATED: metafile for evaluation dataloader.
     "dataset": "mozilla",  // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
     "min_seq_len": 0,  // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 150,  // DATASET-RELATED: maximum text length
     "output_path": "../keep/",  // DATASET-RELATED: output path for all training outputs.
-    "num_loader_workers": 8,  // number of training data loader processes. Don't set it too big. 4-8 are good values.
+    "num_loader_workers": 4,  // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "num_val_loader_workers": 4,  // number of evaluation data loader processes.
     "phoneme_cache_path": "mozilla_us_phonemes",  // phoneme computation is slow, therefore, it caches results in the given folder.
     "use_phonemes": true,  // use phonemes instead of raw characters. It is suggested for better pronounciation.
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index df05e5ad..6bb08ce1 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -53,9 +53,10 @@ class LinearBN(nn.Module):


 class Prenet(nn.Module):
-    def __init__(self, in_features, prenet_type, out_features=[256, 256]):
+    def __init__(self, in_features, prenet_type, prenet_dropout, out_features=[256, 256]):
         super(Prenet, self).__init__()
         self.prenet_type = prenet_type
+        self.prenet_dropout = prenet_dropout
         in_features = [in_features] + out_features[:-1]
         if prenet_type == "bn":
             self.layers = nn.ModuleList([
@@ -70,9 +71,9 @@ class Prenet(nn.Module):

     def forward(self, x):
         for linear in self.layers:
-            if self.prenet_type == "original":
+            if self.prenet_dropout:
                 x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
-            elif self.prenet_type == "bn":
+            else:
                 x = F.relu(linear(x))
         return x

@@ -120,7 +121,7 @@ class LocationLayer(nn.Module):


 class Attention(nn.Module):
-    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, location_attention,
                  attention_location_n_filters, attention_location_kernel_size,
                  windowing, norm, forward_attn, trans_agent):
         super(Attention, self).__init__()
@@ -130,38 +131,65 @@ class Attention(nn.Module):
                                   embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
         if trans_agent:
-            self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)
-        self.location_layer = LocationLayer(attention_location_n_filters,
-                                            attention_location_kernel_size,
-                                            attention_dim)
+            self.ta = nn.Linear(attention_rnn_dim + embedding_dim, 1, bias=True)
+        if location_attention:
+            self.location_layer = LocationLayer(attention_location_n_filters,
+                                                attention_location_kernel_size,
+                                                attention_dim)
         self._mask_value = -float("inf")
         self.windowing = windowing
         self.win_idx = None
         self.norm = norm
         self.forward_attn = forward_attn
         self.trans_agent = trans_agent
+        self.location_attention = location_attention

     def init_win_idx(self):
         self.win_idx = -1
         self.win_back = 2
         self.win_front = 6

-    def init_forward_attn_state(self, inputs):
-        """
-        Init forward attention states
-        """
+    def init_forward_attn(self, inputs):
         B = inputs.shape[0]
         T = inputs.shape[1]
         self.alpha = torch.cat([torch.ones([B, 1]),
                                 torch.zeros([B, T])[:, :-1] + 1e-7
                                 ], dim=1).to(inputs.device)
         self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)

-    def get_attention(self, query, processed_inputs, attention_cat):
+    def init_location_attention(self, inputs):
+        B = inputs.shape[0]
+        T = inputs.shape[1]
+        self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())
+
+    def init_states(self, inputs):
+        B = inputs.shape[0]
+        T = inputs.shape[1]
+        self.attention_weights = Variable(inputs.data.new(B, T).zero_())
+        if self.location_attention:
+            self.init_location_attention(inputs)
+        if self.forward_attn:
+            self.init_forward_attn(inputs)
+        if self.windowing:
+            self.init_win_idx()
+
+    def update_location_attention(self, alignments):
+        self.attention_weights_cum += alignments
+
+    def get_location_attention(self, query, processed_inputs):
+        attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
+                                   self.attention_weights_cum.unsqueeze(1)),
+                                  dim=1)
         processed_query = self.query_layer(query.unsqueeze(1))
         processed_attention_weights = self.location_layer(attention_cat)
         energies = self.v(
             torch.tanh(processed_query + processed_attention_weights +
-                       processed_inputs))
+                       processed_inputs))
+        energies = energies.squeeze(-1)
+        return energies, processed_query

+    def get_attention(self, query, processed_inputs):
+        processed_query = self.query_layer(query.unsqueeze(1))
+        energies = self.v(
+            torch.tanh(processed_query + processed_inputs))
         energies = energies.squeeze(-1)
         return energies, processed_query
@@ -180,7 +208,7 @@ class Attention(nn.Module):
             self.win_idx = torch.argmax(attention, 1).long()[0].item()
         return attention

-    def apply_forward_attention(self, inputs, alignment, processed_query):
+    def apply_forward_attention(self, inputs, alignment, query):
         # forward attention
         prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
         alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-8) * alignment
@@ -190,15 +218,18 @@ class Attention(nn.Module):
         context = context.squeeze(1)
         # compute transition agent
         if self.trans_agent:
-            ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
+            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
             self.u = torch.sigmoid(self.ta(ta_input))
-        return context, self.alpha, alignment
+        return context, self.alpha

     def forward(self, attention_hidden_state, inputs, processed_inputs,
-                attention_cat, mask):
-        attention, processed_query = self.get_attention(
-            attention_hidden_state, processed_inputs, attention_cat)
-
+                mask):
+        if self.location_attention:
+            attention, processed_query = self.get_location_attention(
+                attention_hidden_state, processed_inputs)
+        else:
+            attention, processed_query = self.get_attention(
+                attention_hidden_state, processed_inputs)
         # apply masking
         if mask is not None:
             attention.data.masked_fill_(1 - mask, self._mask_value)
@@ -213,13 +244,16 @@ class Attention(nn.Module):
                 attention).sum(dim=1).unsqueeze(1)
         else:
             raise RuntimeError("Unknown value for attention norm type")
+        if self.location_attention:
+            self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
-            return self.apply_forward_attention(inputs, alignment, processed_query)
+            context, self.attention_weights = self.apply_forward_attention(inputs, alignment, attention_hidden_state)
         else:
             context = torch.bmm(alignment.unsqueeze(1), inputs)
             context = context.squeeze(1)
-        return context, alignment, alignment
+            self.attention_weights = alignment
+        return context


 class Postnet(nn.Module):
@@ -289,7 +323,7 @@ class Encoder(nn.Module):

 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, location_attn):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
@@ -302,14 +336,14 @@ class Decoder(nn.Module):
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1

-        self.prenet = Prenet(self.mel_channels * r, prenet_type,
+        self.prenet = Prenet(self.mel_channels * r, prenet_type, prenet_dropout,
                              [self.prenet_dim, self.prenet_dim])

         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                          self.attention_rnn_dim)

-        self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, 32, 31, attn_win, attn_norm, forward_attn, trans_agent)
+        self.attention_layer = Attention(self.attention_rnn_dim, in_features, 128, location_attn,
+                                         32, 31, attn_win, attn_norm, forward_attn, trans_agent)

         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
@@ -351,9 +385,6 @@ class Decoder(nn.Module):
         self.context = Variable(
             inputs.data.new(B, self.encoder_embedding_dim).zero_())
-
-        self.attention_weights = Variable(inputs.data.new(B, T).zero_())
-        self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())

         self.inputs = inputs
         self.processed_inputs = self.attention_layer.inputs_layer(inputs)

@@ -384,14 +415,10 @@ class Decoder(nn.Module):
         self.attention_cell = F.dropout(
             self.attention_cell, self.p_attention_dropout, self.training)

-        attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
-                                   self.attention_weights_cum.unsqueeze(1)),
-                                  dim=1)
-        self.context, self.attention_weights, alignments = self.attention_layer(
+        self.context = self.attention_layer(
             self.attention_hidden, self.inputs, self.processed_inputs,
-            attention_cat, self.mask)
+            self.mask)

-        self.attention_weights_cum += alignments
         memory = torch.cat(
             (self.attention_hidden, self.context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
@@ -410,7 +437,7 @@ class Decoder(nn.Module):

         stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
         gate_prediction = self.stopnet(stopnet_input)
-        return decoder_output, gate_prediction, self.attention_weights
+        return decoder_output, gate_prediction, self.attention_layer.attention_weights

     def forward(self, inputs, memories, mask):
         memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -419,8 +446,7 @@ class Decoder(nn.Module):
         memories = self.prenet(memories)

         self._init_states(inputs, mask=mask)
-        if self.attention_layer.forward_attn:
-            self.attention_layer.init_forward_attn_state(inputs)
+        self.attention_layer.init_states(inputs)

         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
@@ -441,8 +467,7 @@ class Decoder(nn.Module):

         self._init_states(inputs, mask=None)
         self.attention_layer.init_win_idx()
-        if self.attention_layer.forward_attn:
-            self.attention_layer.init_forward_attn_state(inputs)
+        self.attention_layer.init_states(inputs)

         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [False, False, False]
@@ -460,7 +485,7 @@
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 stop_count += 1
-                if stop_count > 2:
+                if stop_count > 5:
                     break
             elif len(outputs) == self.max_decoder_steps:
                 print("   | > Decoder stopped with 'max_decoder_steps")
@@ -485,8 +510,7 @@ class Decoder(nn.Module):

         self._init_states(inputs, mask=None, keep_states=True)
         self.attention_layer.init_win_idx()
-        if self.attention_layer.forward_attn:
-            self.attention_layer.init_forward_attn_state(inputs)
+        self.attention_layer.init_states(inputs)
         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [False, False, False]
         stop_count = 0
diff --git a/models/tacotron2.py b/models/tacotron2.py
index 2e7c857b..e9ce1a1b 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask

 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False, trans_agent=False):
+    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", prenet_dropout=True, forward_attn=False, trans_agent=False, location_attn=True):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, location_attn)
         self.postnet = Postnet(self.n_mel_channels)

     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index f22c4a3a..902affba 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -262,6 +262,8 @@ def setup_model(num_chars, c):
             attn_win=c.windowing,
             attn_norm=c.attention_norm,
             prenet_type=c.prenet_type,
+            prenet_dropout=c.prenet_dropout,
             forward_attn=c.use_forward_attn,
-            trans_agent=c.transition_agent)
+            trans_agent=c.transition_agent,
+            location_attn=c.location_attn)
     return model
\ No newline at end of file
diff --git a/utils/visual.py b/utils/visual.py
index b259bdd9..9fd7a790 100644
--- a/utils/visual.py
+++ b/utils/visual.py
@@ -30,14 +30,14 @@ def plot_spectrogram(linear_output, audio):
     return fig


-def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CONFIG, spectrogram=None):
+def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CONFIG, spectrogram=None, output_path=None):
     if spectrogram is not None:
         num_plot = 4
     else:
         num_plot = 3

     label_fontsize = 16
-    plt.figure(figsize=(8, 24))
+    fig = plt.figure(figsize=(8, 24))

     plt.subplot(num_plot, 1, 1)
     plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None)
@@ -46,6 +46,7 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
     if CONFIG.use_phonemes:
         seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars)
         text = sequence_to_phoneme(seq)
+        print(text)
     plt.yticks(range(len(text)), list(text))
     plt.colorbar()

@@ -69,3 +70,8 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
     plt.ylabel("Hz", fontsize=label_fontsize)
     plt.tight_layout()
     plt.colorbar()
+
+    if output_path:
+        print(output_path)
+        fig.savefig(output_path)
+        plt.close()
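
The Attention refactor above moves all alignment bookkeeping into the layer itself: state is created once per utterance with init_states(), forward() now returns only the context vector, and the current alignment is read back from attention_layer.attention_weights (as the Decoder now does in its decode step). A rough usage sketch of that API; the tensor sizes and flag values below are illustrative assumptions, and the import assumes the repo root is on PYTHONPATH:

import torch

from layers.tacotron2 import Attention

# Illustrative sizes only: batch of 2, 50 encoder steps, 512-dim encoder
# outputs, 1024-dim attention RNN query.
B, T_enc, enc_dim, rnn_dim = 2, 50, 512, 1024

# Argument order per this diff: attention_rnn_dim, embedding_dim, attention_dim,
# location_attention, n_filters, kernel_size, windowing, norm, forward_attn, trans_agent.
attn = Attention(rnn_dim, enc_dim, 128, True,
                 32, 31, False, "softmax",
                 False, False)

inputs = torch.rand(B, T_enc, enc_dim)      # encoder outputs
processed = attn.inputs_layer(inputs)       # precomputed once, as Decoder._init_states does
attn.init_states(inputs)                    # replaces the old per-flag init calls

query = torch.rand(B, rnn_dim)              # attention RNN hidden state
context = attn(query, inputs, processed, mask=None)
print(context.shape, attn.attention_weights.shape)  # torch.Size([2, 512]) torch.Size([2, 50])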