update tacotron_config.py for checking `r` and the docstring

This commit is contained in:
Eren Gölge 2021-05-17 11:35:30 +02:00
parent 12722501bb
commit 34a42d379f
3 changed files with 121 additions and 4 deletions

View File

@ -24,10 +24,10 @@ class TacotronConfig(BaseTTSConfig):
Path to the wav file used at inference to set the speech style through GST. If `GST` is enabled and
this is not defined, the model uses a zero vector as an input. Defaults to None.
r (int):
Number of output frames that the decoder computed per iteration. Larger values makes training and inference
faster but reduces the quality of the output frames. This needs to be tuned considering your own needs.
Defaults to 1.
gradual_trainin (List[List]):
Initial number of output frames that the decoder computes per iteration. Larger values make training and inference
faster but reduces the quality of the output frames. This must be equal to the largest `r` value used in
`gradual_training` schedule. Defaults to 1.
gradual_training (List[List]):
Parameters for the gradual training schedule. It is in the form `[[a, b, c], [d, e, f], ...]` where `a` is
the step number to start using the rest of the values, `b` is the `r` value and `c` is the batch size.
If set to None, no gradual training is used. Defaults to None.
@ -168,3 +168,8 @@ class TacotronConfig(BaseTTSConfig):
decoder_ssim_alpha: float = 0.25
postnet_ssim_alpha: float = 0.25
ga_alpha: float = 5.0
def check_values(self):
    """Validate interdependent config fields.

    When `gradual_training` is enabled, each schedule entry has the form
    `[step, r, batch_size]`, and the first entry's `r` must match the
    model-level `r` (the decoder's initial frames-per-step) — otherwise the
    model and the schedule start out disagreeing.

    Raises:
        ValueError: if the first scheduled `r` differs from `self.r`.
    """
    if self.gradual_training:
        first_r = self.gradual_training[0][1]
        if first_r != self.r:
            # Raise instead of `assert`: assertions are stripped under
            # `python -O`, which would silently skip this validation.
            raise ValueError(
                f"[!] the first scheduled gradual training `r` must be equal to the model's `r` value. {first_r} vs {self.r}"
            )

View File

@ -0,0 +1,21 @@
#!/bin/bash
# Download and prepare LJSpeech, then launch Tacotron2 training with the
# recipe config that lives next to this script.
RUN_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
echo "$RUN_DIR"
# Single source of truth for the recipe config. The original script used
# tacotron2-DCA.json for the statistics step but tacotron2-DDC.json for the
# training step; the shipped config enables double_decoder_consistency, so
# DDC is assumed here — TODO confirm the actual JSON filename in this folder.
CONFIG_JSON="$RUN_DIR/tacotron2-DDC.json"
# download LJSpeech dataset
wget http://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
# extract
tar -xjf LJSpeech-1.1.tar.bz2
# create train-val splits
shuf LJSpeech-1.1/metadata.csv > LJSpeech-1.1/metadata_shuf.csv
head -n 12000 LJSpeech-1.1/metadata_shuf.csv > LJSpeech-1.1/metadata_train.csv
tail -n 1100 LJSpeech-1.1/metadata_shuf.csv > LJSpeech-1.1/metadata_val.csv
mv LJSpeech-1.1 "$RUN_DIR/"
rm LJSpeech-1.1.tar.bz2
# compute dataset mean and variance for normalization
python TTS/bin/compute_statistics.py "$CONFIG_JSON" "$RUN_DIR/scale_stats.npy" --data_path "$RUN_DIR/LJSpeech-1.1/wavs/"
# training ....
# change the GPU id if needed
# NOTE: no trailing backslash on the final argument — the original ended the
# command with a dangling line continuation into a blank line.
CUDA_VISIBLE_DEVICES="0" python TTS/bin/train_tacotron.py --config_path "$CONFIG_JSON" \
    --output_path "$RUN_DIR" \
    --coqpit.datasets.0.path "$RUN_DIR/LJSpeech-1.1/" \
    --coqpit.audio.stats_path "$RUN_DIR/scale_stats.npy"

View File

@ -0,0 +1,91 @@
{
"datasets": [
{
"name": "ljspeech",
"path": "DEFINE THIS",
"meta_file_train": "metadata.csv",
"meta_file_val": null
}
],
"audio": {
"fft_size": 1024,
"win_length": 1024,
"hop_length": 256,
"frame_length_ms": null,
"frame_shift_ms": null,
"sample_rate": 22050,
"preemphasis": 0.0,
"ref_level_db": 20,
"do_trim_silence": true,
"trim_db": 60,
"power": 1.5,
"griffin_lim_iters": 60,
"num_mels": 80,
"mel_fmin": 50.0,
"mel_fmax": 7600.0,
"spec_gain": 1,
"signal_norm": true,
"min_level_db": -100,
"symmetric_norm": true,
"max_norm": 4.0,
"clip_norm": true,
"stats_path": "scale_stats.npy"
},
"gst":{
"gst_embedding_dim": 256,
"gst_num_heads": 4,
"gst_num_style_tokens": 10
},
"model": "Tacotron2",
"run_name": "ljspeech-dcattn",
"run_description": "tacotron2 with dynamic convolution attention.",
"batch_size": 64,
"eval_batch_size": 16,
"r": 2,
"mixed_precision": true,
"loss_masking": true,
"decoder_loss_alpha": 0.25,
"postnet_loss_alpha": 0.25,
"postnet_diff_spec_alpha": 0.25,
"decoder_diff_spec_alpha": 0.25,
"decoder_ssim_alpha": 0.25,
"postnet_ssim_alpha": 0.25,
"ga_alpha": 5.0,
"stopnet_pos_weight": 15.0,
"run_eval": true,
"test_delay_epochs": 10,
"test_sentences_file": null,
"noam_schedule": true,
"grad_clip": 0.05,
"epochs": 1000,
"lr": 0.001,
"wd": 1e-06,
"warmup_steps": 4000,
"memory_size": -1,
"prenet_type": "original",
"prenet_dropout": true,
"attention_type": "original",
"location_attn": true,
"double_decoder_consistency": true,
"ddc_r": 6,
"attention_norm": "sigmoid",
"gradual_training": [[0, 6, 64], [10000, 4, 32], [50000, 3, 32], [100000, 2, 32]],
"stopnet": true,
"separate_stopnet": true,
"print_step": 25,
"tb_plot_step": 100,
"print_eval": false,
"save_step": 10000,
"checkpoint": true,
"text_cleaner": "phoneme_cleaners",
"num_loader_workers": 4,
"num_val_loader_workers": 4,
"batch_group_size": 4,
"min_seq_len": 6,
"max_seq_len": 180,
"compute_input_seq_cache": true,
"output_path": "DEFINE THIS",
"phoneme_cache_path": "DEFINE THIS",
"use_phonemes": false,
"phoneme_language": "en-us"
}