Merge branch 'tacotron-gst' into dev

Eren Golge 2019-07-11 15:32:32 +02:00
commit 5851c5d29b
23 changed files with 1068 additions and 341 deletions


@ -10,7 +10,7 @@ wget https://www.dropbox.com/s/wqn5v3wkktw9lmo/install.sh?dl=0 -O install.sh
sudo sh install.sh
python3 setup.py develop
# cp -R ${USER_DIR}/GermanData ../tmp/
# python3 distribute.py --config_path config_tacotron_de.json --data_path ../tmp/GermanData/karlsson/
cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
python3 distribute.py --config_path config_tacotron.json --data_path ../tmp/Mozilla_22050/ --restore_path /data/rw/home/4845.pth.tar
python3 distribute.py --config_path config_tacotron_de.json --data_path /data/rw/home/de_DE/
# cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
# python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/
while true; do sleep 1000000; done

.dockerignore Normal file

@ -0,0 +1 @@
.git/


@ -1,23 +1,17 @@
FROM nvidia/cuda:9.0-base-ubuntu16.04 as base
FROM pytorch/pytorch:1.0.1-cuda10.0-cudnn7-runtime
WORKDIR /srv/app
RUN apt-get update && \
apt-get install -y git software-properties-common wget vim build-essential libsndfile1 && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get install -y python3.6 python3.6-dev python3.6-tk && \
# Install pip manually
wget https://bootstrap.pypa.io/get-pip.py && \
python3.6 get-pip.py && \
rm get-pip.py && \
# Used by the server in server/synthesizer.py
pip install soundfile
apt-get install -y libsndfile1 espeak && \
apt-get clean && \
rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
ADD . /srv/app
# Copy Source later to enable dependency caching
COPY requirements.txt /srv/app/
RUN pip install -r requirements.txt
# Setup for development
RUN python3.6 setup.py develop
COPY . /srv/app
# http://bugs.python.org/issue19846
# > At the moment, setting "LANG=C" on a Linux system *fundamentally breaks Python 3*, and that's not OK.


@ -75,6 +75,7 @@
"phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder.
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners"
"text_cleaner": "phoneme_cleaners",
"use_speaker_embedding": false // whether to use additional embeddings for separate speakers
}

config_libritts.json Normal file

@ -0,0 +1,82 @@
{
"run_name": "libritts-360",
"run_description": "LibriTTS 360 clean with multi speaker embedding.",
"audio":{
// Audio processing parameters
"num_mels": 80, // size of the mel spec frame.
"num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 24000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"frame_length_ms": 50, // stft window length in ms.
"frame_shift_ms": 12.5, // stft window hop-lengh in ms.
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
// Normalization parameters
"signal_norm": true, // normalize the spec values in range [0, 1]
"symmetric_norm": false, // move normalization to range [-1, 1]
"max_norm": 1, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
"do_trim_silence": true // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
},
"distributed":{
"backend": "nccl",
"url": "tcp:\/\/localhost:54321"
},
"reinit_layers": [],
"model": "Tacotron", // one of the model in models/
"grad_clip": 1, // upper limit for gradients for clipping.
"epochs": 1000, // total number of epochs to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"windowing": false, // Enables attention windowing. Used only in eval mode.
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
"prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
"use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
"forward_attn_mask": false,
"transition_agent": false, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
"location_attn": false, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
"stopnet": true, // Train stopnet predicting the end of synthesis.
"separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"eval_batch_size":16,
"r": 5, // Number of frames to predict for step.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
"print_step": 10, // Number of steps to log traning on console.
"batch_group_size": 0, //Number of batches to shuffle after bucketing.
"run_eval": true,
"test_delay_epochs": 5, //Until attention is aligned, testing only wastes computation time.
"test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
"data_path": "/home/erogol/Data/Libri-TTS/train-clean-360/", // DATASET-RELATED: can overwritten from command argument
"meta_file_train": null, // DATASET-RELATED: metafile for training dataloader.
"meta_file_val": null, // DATASET-RELATED: metafile for evaluation dataloader.
"dataset": "libri_tts", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
"min_seq_len": 6, // DATASET-RELATED: minimum text length to use in training
"max_seq_len": 150, // DATASET-RELATED: maximum text length
"output_path": "/media/erogol/data_ssd/Models/libri_tts/", // DATASET-RELATED: output path for all training outputs.
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"num_val_loader_workers": 4, // number of evaluation data loader processes.
"phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder.
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners",
"use_speaker_embedding": true
}
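These configs carry //-style comments, which a strict JSON parser rejects. A minimal sketch of loading such a file, assuming a naive regex comment stripper (load_commented_json is a hypothetical helper, not necessarily the repo's own load_config; it relies on these files escaping slashes inside string values, e.g. "tcp:\/\/localhost"):

import json
import re

def load_commented_json(path):
    # Hypothetical helper: drop //-comments, then parse as plain JSON.
    with open(path, 'r') as f:
        text = f.read()
    # Naive strip: safe here only because string values escape slashes (\/\/).
    text = re.sub(r'//.*$', '', text, flags=re.MULTILINE)
    return json.loads(text)

config = load_commented_json("config_libritts.json")
print(config["model"], config["use_speaker_embedding"])  # Tacotron True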


@ -1,6 +1,6 @@
{
"run_name": "mozilla-tacotron-tagent-bn",
"run_description": "finetune 4845 with bn prenet.",
"run_description": "compare the attention with gst model which does not align with the same config",
"audio":{
// Audio processing parameters
@ -40,11 +40,12 @@
"windowing": false, // Enables attention windowing. Used only in eval mode.
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "bn", // ONLY TACOTRON2 - "original" or "bn".
"prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
"use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
"transition_agent": true, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
"location_attn": false, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"prenet_type": "original", // "original" or "bn".
"prenet_dropout": true, // enable/disable dropout at prenet.
"use_forward_attn": true, // if it uses forward attention. In general, it aligns faster.
"forward_attn_mask": false, // Apply forward attention mask af inference to prevent bad modes. Try it if your model does not align well.
"transition_agent": true, // enable/disable transition agent of forward attention.
"location_attn": false, // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
"stopnet": true, // Train stopnet predicting the end of synthesis.
@ -75,6 +76,7 @@
"phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder.
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners"
"text_cleaner": "phoneme_cleaners",
"use_speaker_embedding": false // whether to use additional embeddings for separate speakers
}


@ -42,6 +42,7 @@
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
"prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
"use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
"forward_attn_mask": false, // Apply forward attention mask af inference to prevent bad modes. Try it if your model does not align well.
"transition_agent": false, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
"location_attn": false, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
@ -77,6 +78,7 @@
"phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder.
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners"
"text_cleaner": "phoneme_cleaners",
"use_speaker_embedding": false // whether to use additional embeddings for separate speakers
}


@ -1,85 +1,82 @@
{
"run_name": "german-tacotron-tagent-bn",
"run_description": "train german",
"audio":{
// Audio processing parameters
"num_mels": 80, // size of the mel spec frame.
"num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"frame_length_ms": 50, // stft window length in ms.
"frame_shift_ms": 12.5, // stft window hop-lengh in ms.
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
// Normalization parameters
"signal_norm": true, // normalize the spec values in range [0, 1]
"symmetric_norm": false, // move normalization to range [-1, 1]
"max_norm": 1, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
"do_trim_silence": true // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
},
"distributed":{
"backend": "nccl",
"url": "tcp:\/\/localhost:54321"
},
"reinit_layers": [],
"model": "Tacotron", // one of the model in models/
"grad_clip": 1, // upper limit for gradients for clipping.
"epochs": 1000, // total number of epochs to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"windowing": false, // Enables attention windowing. Used only in eval mode.
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
"prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
"use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
"transition_agent": false, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
"location_attn": false, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
"stopnet": true, // Train stopnet predicting the end of synthesis.
"separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"eval_batch_size":16,
"r": 5, // Number of frames to predict for step.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
"print_step": 10, // Number of steps to log traning on console.
"batch_group_size": 0, //Number of batches to shuffle after bucketing.
"run_eval": false,
"test_sentences_file": "de_sentences.txt", // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
"test_delay_epochs": 5, //Until attention is aligned, testing only wastes computation time.
"data_path": "/media/erogol/data_ssd/Data/Mozilla/", // DATASET-RELATED: can overwritten from command argument
"meta_file_train": [
"grune_haus/metadata.csv",
"kleine_lord/metadata.csv",
"toten_seelen/metadata.csv",
"werde_die_du_bist/metadata.csv"
], // DATASET-RELATED: metafile for training dataloader.
"meta_file_val": "metadata_val.txt", // DATASET-RELATED: metafile for evaluation dataloader.
"dataset": "mailabs", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
"min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training
"max_seq_len": 200, // DATASET-RELATED: maximum text length
"output_path": "/media/erogol/data_ssd/Data/models/german/", // DATASET-RELATED: output path for all training outputs.
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"num_val_loader_workers": 4, // number of evaluation data loader processes.
"phoneme_cache_path": "phoneme_cache", // phoneme computation is slow, therefore, it caches results in the given folder.
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "de", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners"
}
"run_name": "german-all-tacotrongst",
"run_description": "train with all the german dataset using gst",
"audio":{
// Audio processing parameters
"num_mels": 80, // size of the mel spec frame.
"num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"frame_length_ms": 50, // stft window length in ms.
"frame_shift_ms": 12.5, // stft window hop-lengh in ms.
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
// Normalization parameters
"signal_norm": true, // normalize the spec values in range [0, 1]
"symmetric_norm": false, // move normalization to range [-1, 1]
"max_norm": 1, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
"do_trim_silence": true // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
},
"distributed":{
"backend": "nccl",
"url": "tcp:\/\/localhost:54321"
},
"reinit_layers": [],
"model": "TacotronGST", // one of the model in models/
"grad_clip": 1, // upper limit for gradients for clipping.
"epochs": 10000, // total number of epochs to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"windowing": false, // Enables attention windowing. Used only in eval mode.
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
"prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
"use_forward_attn": false, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
"transition_agent": false, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
"forward_attn_mask": false,
"location_attn": true, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
"stopnet": true, // Train stopnet predicting the end of synthesis.
"separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"eval_batch_size":32,
"r": 5, // Number of frames to predict for step.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
"print_step": 10, // Number of steps to log traning on console.
"batch_group_size": 0, //Number of batches to shuffle after bucketing.
"run_eval": false,
"test_sentences_file": "de_sentences.txt", // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
"test_delay_epochs": 5, //Until attention is aligned, testing only wastes computation time.
"data_path": "/home/erogol/Data/m-ai-labs/de_DE/by_book/" , // DATASET-RELATED: can overwritten from command argument
"meta_file_train": null, // DATASET-RELATED: metafile for training dataloader.
"meta_file_val": null, // DATASET-RELATED: metafile for evaluation dataloader.
"dataset": "mailabs", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
"min_seq_len": 15, // DATASET-RELATED: minimum text length to use in training
"max_seq_len": 200, // DATASET-RELATED: maximum text length
"output_path": "/home/erogol/Models/mozilla_models/", // DATASET-RELATED: output path for all training outputs.
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"num_val_loader_workers": 4, // number of evaluation data loader processes.
"phoneme_cache_path": "phoneme_cache", // phoneme computation is slow, therefore, it caches results in the given folder.
"use_phonemes": false, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "de", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners",
"use_speaker_embedding": false // whether to use additional embeddings for separate speakers
}

config_tacotron_gst.json Normal file

@ -0,0 +1,82 @@
{
"run_name": "mozilla-tacotron-gst",
"run_description": "GST with single speaker",
"audio":{
// Audio processing parameters
"num_mels": 80, // size of the mel spec frame.
"num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"frame_length_ms": 50, // stft window length in ms.
"frame_shift_ms": 12.5, // stft window hop-lengh in ms.
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
// Normalization parameters
"signal_norm": true, // normalize the spec values in range [0, 1]
"symmetric_norm": false, // move normalization to range [-1, 1]
"max_norm": 1, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
"do_trim_silence": true // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
},
"distributed":{
"backend": "nccl",
"url": "tcp:\/\/localhost:54321"
},
"reinit_layers": [],
"model": "TacotronGST", // one of the model in models/
"grad_clip": 1, // upper limit for gradients for clipping.
"epochs": 10000, // total number of epochs to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"windowing": false, // Enables attention windowing. Used only in eval mode.
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "original", // "original" or "bn".
"prenet_dropout": true, // enable/disable dropout at prenet.
"use_forward_attn": true, // if it uses forward attention. In general, it aligns faster.
"forward_attn_mask": false, // Apply forward attention mask af inference to prevent bad modes. Try it if your model does not align well.
"transition_agent": false, // enable/disable transition agent of forward attention.
"location_attn": false, // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
"stopnet": true, // Train stopnet predicting the end of synthesis.
"separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"eval_batch_size":16,
"r": 5, // Number of frames to predict for step.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
"print_step": 10, // Number of steps to log traning on console.
"batch_group_size": 0, //Number of batches to shuffle after bucketing.
"run_eval": true,
"test_delay_epochs": 5, //Until attention is aligned, testing only wastes computation time.
"test_sentences_file": null,
"data_path": "/media/erogol/data_ssd/Data/Mozilla/", // DATASET-RELATED: can overwritten from command argument
"meta_file_train": "metadata_train.txt", // DATASET-RELATED: metafile for training dataloader.
"meta_file_val": "metadata_val.txt", // DATASET-RELATED: metafile for evaluation dataloader.
"dataset": "mozilla", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
"min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training
"max_seq_len": 150, // DATASET-RELATED: maximum text length
"output_path": "../keep/", // DATASET-RELATED: output path for all training outputs.
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"num_val_loader_workers": 4, // number of evaluation data loader processes.
"phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder.
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners",
"use_speaker_embedding": false // whether to use additional embeddings for separate speakers
}


@ -37,6 +37,8 @@ class MyDataset(Dataset):
ap (TTS.utils.AudioProcessor): audio processor object.
preprocessor (dataset.preprocess.Class): preprocessor for the dataset.
Create your own if you need to run a new dataset.
speaker_id_cache_path (str): path where the speaker name to id
mapping is stored
batch_group_size (int): (0) range of batch randomization after sorting
sequences by length.
min_seq_len (int): (0) minimum sequence length to be processed
@ -105,7 +107,7 @@ class MyDataset(Dataset):
return text
def load_data(self, idx):
text, wav_file = self.items[idx]
text, wav_file, speaker_name = self.items[idx]
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
if self.use_phonemes:
@ -120,7 +122,8 @@ class MyDataset(Dataset):
sample = {
'text': text,
'wav': wav,
'item_idx': self.items[idx][1]
'item_idx': self.items[idx][1],
'speaker_name': speaker_name
}
return sample
@ -182,6 +185,8 @@ class MyDataset(Dataset):
batch[idx]['item_idx'] for idx in ids_sorted_decreasing
]
text = [batch[idx]['text'] for idx in ids_sorted_decreasing]
speaker_name = [batch[idx]['speaker_name']
for idx in ids_sorted_decreasing]
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
@ -219,7 +224,8 @@ class MyDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs
return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
stop_targets, item_idxs
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}".format(type(batch[0]))))
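The collate function now returns speaker names as strings, while the models consume integer speaker ids for their embedding layers. A minimal sketch of the name-to-id mapping this implies (build_speaker_mapping and its JSON cache are illustrative assumptions, not the repo's exact API):

import json

def build_speaker_mapping(items, cache_path=None):
    # items: [text, wav_file, speaker_name] triples from a preprocessor.
    # Assign each unique speaker name a stable integer id.
    speakers = sorted({item[2] for item in items})
    mapping = {name: idx for idx, name in enumerate(speakers)}
    if cache_path is not None:  # hypothetical use of speaker_id_cache_path
        with open(cache_path, 'w') as f:
            json.dump(mapping, f, indent=2)
    return mapping

# With a batch's speaker_name list from collate_fn:
# speaker_ids = torch.LongTensor([mapping[name] for name in speaker_name])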


@ -1,5 +1,13 @@
import os
from glob import glob
import re
import sys
def get_preprocessor_by_name(name):
"""Returns the respective preprocessing function."""
thismodule = sys.modules[__name__]
return getattr(thismodule, name.lower())
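For example (dataset path and metafile are placeholders):

preprocessor = get_preprocessor_by_name("ljspeech")  # resolves the config's "dataset" field
items = preprocessor("/data/LJSpeech-1.1", "metadata.csv")
# each item is [text, wav_file, speaker_name]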
def tweb(root_path, meta_file):
@ -8,12 +16,13 @@ def tweb(root_path, meta_file):
"""
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "tweb"
with open(txt_file, 'r') as ttf:
for line in ttf:
cols = line.split('\t')
wav_file = os.path.join(root_path, cols[0] + '.wav')
text = cols[1]
items.append([text, wav_file])
items.append([text, wav_file, speaker_name])
return items
@ -34,6 +43,7 @@ def mozilla_old(root_path, meta_file):
"""Normalizes Mozilla meta data files to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "mozilla_old"
with open(txt_file, 'r') as ttf:
for line in ttf:
cols = line.split('|')
@ -41,7 +51,7 @@ def mozilla_old(root_path, meta_file):
wav_folder = "batch{}".format(batch_no)
wav_file = os.path.join(root_path, wav_folder, "wavs_no_processing", cols[1].strip())
text = cols[0].strip()
items.append([text, wav_file])
items.append([text, wav_file, speaker_name])
return items
@ -49,37 +59,46 @@ def mozilla(root_path, meta_file):
"""Normalizes Mozilla meta data files to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "mozilla"
with open(txt_file, 'r') as ttf:
for line in ttf:
cols = line.split('|')
wav_file = cols[1].strip()
text = cols[0].strip()
wav_file = os.path.join(root_path, "wavs", wav_file)
items.append([text, wav_file])
items.append([text, wav_file, speaker_name])
return items
def mailabs(root_path, meta_files):
"""Normalizes M-AI-Labs meta data files to TTS format"""
speaker_regex = re.compile("by_book/(male|female|mix)/(?P<speaker_name>[^/]+)/")
if meta_files is None:
meta_files = glob(root_path+"/**/metadata.csv", recursive=True)
folders = [os.path.dirname(f.strip()) for f in meta_files]
csv_files = glob(root_path+"/**/metadata.csv", recursive=True)
folders = [os.path.dirname(f) for f in csv_files]
else:
csv_files = meta_files
folders = [f.strip().split("by_book")[1][1:] for f in csv_files]
# meta_files = [f.strip() for f in meta_files.split(",")]
items = []
for idx, meta_file in enumerate(meta_files):
print(" | > {}".format(meta_file))
for idx, csv_file in enumerate(csv_files):
# determine speaker based on folder structure...
speaker_name = speaker_regex.search(csv_file).group("speaker_name")
print(" | > {}".format(csv_file))
folder = folders[idx]
txt_file = os.path.join(root_path, meta_file)
txt_file = os.path.join(root_path, csv_file)
with open(txt_file, 'r') as ttf:
for line in ttf:
cols = line.split('|')
wav_file = os.path.join(root_path, folder, 'wavs',
cols[0] + '.wav')
if os.path.isfile(wav_file):
text = cols[1]
items.append([text, wav_file])
if meta_files is None:
wav_file = os.path.join(folder, 'wavs', cols[0] + '.wav')
else:
continue
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), 'wavs', cols[0] + '.wav')
if os.path.isfile(wav_file):
text = cols[1].strip()
items.append([text, wav_file, speaker_name])
else:
raise RuntimeError("> File %s does not exist!" % (wav_file))
return items
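As a sanity check, the speaker regex above pulls the speaker folder out of the M-AI-Labs by_book layout (the path below is illustrative):

import re

speaker_regex = re.compile("by_book/(male|female|mix)/(?P<speaker_name>[^/]+)/")
path = "/data/m-ai-labs/de_DE/by_book/female/eva_k/kleine_lord/metadata.csv"
print(speaker_regex.search(path).group("speaker_name"))  # -> eva_k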
@ -87,12 +106,13 @@ def ljspeech(root_path, meta_file):
"""Normalizes the Nancy meta data file to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "ljspeech"
with open(txt_file, 'r') as ttf:
for line in ttf:
cols = line.split('|')
wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav')
text = cols[1]
items.append([text, wav_file])
items.append([text, wav_file, speaker_name])
return items
@ -100,12 +120,13 @@ def nancy(root_path, meta_file):
"""Normalizes the Nancy meta data file to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "nancy"
with open(txt_file, 'r') as ttf:
for line in ttf:
id = line.split()[1]
text = line[line.find('"') + 1:line.rfind('"') - 1]
wav_file = os.path.join(root_path, "wavn", id + ".wav")
items.append([text, wav_file])
items.append([text, wav_file, speaker_name])
return items
@ -119,6 +140,28 @@ def common_voice(root_path, meta_file):
continue
cols = line.split("\t")
text = cols[2]
speaker_name = cols[0]
wav_file = os.path.join(root_path, "clips", cols[1] + ".wav")
items.append([text, wav_file])
items.append([text, wav_file, speaker_name])
return items
def libri_tts(root_path, meta_files=None):
"""https://ai.google/tools/datasets/libri-tts/"""
items = []
if meta_files is None:
meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True)
for meta_file in meta_files:
_meta_file = os.path.basename(meta_file).split('.')[0]
speaker_name = _meta_file.split('_')[0]
chapter_id = _meta_file.split('_')[1]
_root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
with open(meta_file, 'r') as ttf:
for line in ttf:
cols = line.split('\t')
wav_file = os.path.join(_root_path, cols[0] + '.wav')
text = cols[1]
items.append([text, wav_file, speaker_name])
for item in items:
assert os.path.exists(item[1]), f" [!] wav file does not exist - {item[1]}"
return items
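The loader assumes the standard LibriTTS layout, where speaker and chapter ids are encoded in the transcript filename; roughly (tree and call are illustrative):

# train-clean-360/
#   14/                     <- speaker id
#     208/                  <- chapter id
#       14_208.trans.tsv    <- tab-separated: utterance id \t text
#       14_208_000001_000000.wav
items = libri_tts("/home/erogol/Data/Libri-TTS/train-clean-360/")
print(len(items), items[0])  # [text, wav_file, speaker_name]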


@ -208,7 +208,7 @@ class Attention(nn.Module):
_, n = prev_alpha.max(1)
val, n2 = alpha.max(1)
for b in range(alignment.shape[0]):
alpha[b, n[b] + 2:] = 0
alpha[b, n[b] + 3:] = 0
alpha[b, :(n[b] - 1)] = 0 # ignore all previous states to prevent repetition.
alpha[b, (n[b] - 2)] = 0.01 * val[b] # smoothing factor for the prev step
# compute attention weights
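To see what the widened window does, a standalone sketch of the masking with the previous attention peak at index 5 (all values are illustrative):

import torch

prev_alpha = torch.zeros(1, 10); prev_alpha[0, 5] = 1.0  # previous peak at 5
alpha = torch.full((1, 10), 0.1)
_, n = prev_alpha.max(1)
val, _ = alpha.max(1)
b = 0
alpha[b, n[b] + 3:] = 0                # drop everything past a 3-step lookahead
alpha[b, :(n[b] - 1)] = 0              # drop positions more than one step before the peak
alpha[b, (n[b] - 2)] = 0.01 * val[b]   # small smoothing weight two steps back
print(alpha)  # indices 4..7 keep weight; index 3 holds the smoothing term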

layers/gst_layers.py Normal file

@ -0,0 +1,168 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class GST(nn.Module):
"""Global Style Token Module for factorizing prosody in speech.
See https://arxiv.org/pdf/1803.09017"""
def __init__(self, num_mel, num_heads, num_style_tokens, embedding_dim):
super().__init__()
self.encoder = ReferenceEncoder(num_mel, embedding_dim)
self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens,
embedding_dim)
def forward(self, inputs):
enc_out = self.encoder(inputs)
style_embed = self.style_token_layer(enc_out)
return style_embed
class ReferenceEncoder(nn.Module):
"""NN module creating a fixed size prosody embedding from a spectrogram.
inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
outputs: [batch_size, embedding_dim]
"""
def __init__(self, num_mel, embedding_dim):
super().__init__()
self.num_mel = num_mel
filters = [1] + [32, 32, 64, 64, 128, 128]
num_layers = len(filters) - 1
convs = [
nn.Conv2d(
in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1)) for i in range(num_layers)
]
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList([
nn.BatchNorm2d(num_features=filter_size)
for filter_size in filters[1:]
])
post_conv_height = self.calculate_post_conv_height(
num_mel, 3, 2, 1, num_layers)
self.recurrence = nn.GRU(
input_size=filters[-1] * post_conv_height,
hidden_size=embedding_dim // 2,
batch_first=True)
def forward(self, inputs):
batch_size = inputs.size(0)
x = inputs.view(batch_size, 1, -1, self.num_mel)
# x: 4D tensor [batch_size, num_channels==1, num_frames, num_mel]
for conv, bn in zip(self.convs, self.bns):
x = conv(x)
x = bn(x)
x = F.relu(x)
x = x.transpose(1, 2)
# x: 4D tensor [batch_size, post_conv_width,
# num_channels==128, post_conv_height]
post_conv_width = x.size(1)
x = x.contiguous().view(batch_size, post_conv_width, -1)
# x: 3D tensor [batch_size, post_conv_width,
# num_channels*post_conv_height]
self.recurrence.flatten_parameters()
memory, out = self.recurrence(x)
# out: GRU hidden state, 3D tensor [num_layers==1, batch_size, encoding_size==128]
return out.squeeze(0)
def calculate_post_conv_height(self, height, kernel_size, stride, pad,
n_convs):
"""Height of spec after n convolutions with fixed kernel/stride/pad."""
for i in range(n_convs):
height = (height - kernel_size + 2 * pad) // stride + 1
return height
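With the defaults used here (80 mel bins, six convolutions with kernel 3, stride 2, padding 1), the recurrence therefore sees a post-conv height of 2, so the GRU input size is 128 * 2 = 256. A quick check:

height = 80
for _ in range(6):  # kernel_size=3, stride=2, pad=1 per layer
    height = (height - 3 + 2 * 1) // 2 + 1
print(height)  # 80 -> 40 -> 20 -> 10 -> 5 -> 3 -> 2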
class StyleTokenLayer(nn.Module):
"""NN Module attending to style tokens based on prosody encodings."""
def __init__(self, num_heads, num_style_tokens,
embedding_dim):
super().__init__()
self.query_dim = embedding_dim // 2
self.key_dim = embedding_dim // num_heads
self.style_tokens = nn.Parameter(
torch.FloatTensor(num_style_tokens, self.key_dim))
nn.init.orthogonal_(self.style_tokens)
self.attention = MultiHeadAttention(
query_dim=self.query_dim,
key_dim=self.key_dim,
num_units=embedding_dim,
num_heads=num_heads)
def forward(self, inputs):
batch_size = inputs.size(0)
prosody_encoding = inputs.unsqueeze(1)
# prosody_encoding: 3D tensor [batch_size, 1, encoding_size==128]
tokens = torch.tanh(self.style_tokens) \
.unsqueeze(0) \
.expand(batch_size, -1, -1)
# tokens: 3D tensor [batch_size, num tokens, token embedding size]
style_embed = self.attention(prosody_encoding, tokens)
return style_embed
class MultiHeadAttention(nn.Module):
'''
input:
query --- [N, T_q, query_dim]
key --- [N, T_k, key_dim]
output:
out --- [N, T_q, num_units]
'''
def __init__(self, query_dim, key_dim, num_units, num_heads):
super().__init__()
self.num_units = num_units
self.num_heads = num_heads
self.key_dim = key_dim
self.W_query = nn.Linear(
in_features=query_dim, out_features=num_units, bias=False)
self.W_key = nn.Linear(
in_features=key_dim, out_features=num_units, bias=False)
self.W_value = nn.Linear(
in_features=key_dim, out_features=num_units, bias=False)
def forward(self, query, key):
queries = self.W_query(query) # [N, T_q, num_units]
keys = self.W_key(key) # [N, T_k, num_units]
values = self.W_value(key)
split_size = self.num_units // self.num_heads
queries = torch.stack(
torch.split(queries, split_size, dim=2),
dim=0) # [h, N, T_q, num_units/h]
keys = torch.stack(
torch.split(keys, split_size, dim=2),
dim=0) # [h, N, T_k, num_units/h]
values = torch.stack(
torch.split(values, split_size, dim=2),
dim=0) # [h, N, T_k, num_units/h]
# score = softmax(QK^T / (d_k ** 0.5))
scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k]
scores = scores / (self.key_dim**0.5)
scores = F.softmax(scores, dim=3)
# out = score * V
out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
out = torch.cat(
torch.split(out, 1, dim=0),
dim=3).squeeze(0) # [N, T_q, num_units]
return out
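A quick shape check for the module as instantiated in models/tacotrongst.py (random input, illustrative only):

import torch

gst = GST(num_mel=80, num_heads=4, num_style_tokens=10, embedding_dim=256)
mel = torch.randn(8, 120, 80)  # [batch_size, num_spec_frames, num_mel]
style = gst(mel)
print(style.shape)  # torch.Size([8, 1, 256]) -- one style vector per utterance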


@ -9,6 +9,7 @@ from utils.generic_utils import sequence_mask
class Tacotron(nn.Module):
def __init__(self,
num_chars,
num_speakers,
r=5,
linear_dim=1025,
mel_dim=80,
@ -28,6 +29,9 @@ class Tacotron(nn.Module):
self.linear_dim = linear_dim
self.embedding = nn.Embedding(num_chars, 256)
self.embedding.weight.data.normal_(0, 0.3)
if num_speakers > 1:
self.speaker_embedding = nn.Embedding(num_speakers, 256)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(256)
self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
attn_norm, prenet_type, prenet_dropout,
@ -38,11 +42,13 @@ class Tacotron(nn.Module):
nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
nn.Sigmoid())
def forward(self, characters, text_lengths, mel_specs):
def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
B = characters.size(0)
mask = sequence_mask(text_lengths).to(characters.device)
inputs = self.embedding(characters)
encoder_outputs = self.encoder(inputs)
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
speaker_ids)
mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, mask)
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
@ -50,13 +56,26 @@ class Tacotron(nn.Module):
linear_outputs = self.last_linear(linear_outputs)
return mel_outputs, linear_outputs, alignments, stop_tokens
def inference(self, characters):
def inference(self, characters, speaker_ids=None):
B = characters.size(0)
inputs = self.embedding(characters)
encoder_outputs = self.encoder(inputs)
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
speaker_ids)
mel_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
linear_outputs = self.postnet(mel_outputs)
linear_outputs = self.last_linear(linear_outputs)
return mel_outputs, linear_outputs, alignments, stop_tokens
return mel_outputs, linear_outputs, alignments, stop_tokens
def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
speaker_embeddings = self.speaker_embedding(speaker_ids)
speaker_embeddings.unsqueeze_(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
encoder_outputs.size(1),
-1)
encoder_outputs = encoder_outputs + speaker_embeddings
return encoder_outputs
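The broadcast in _add_speaker_embedding simply adds one speaker vector per utterance to every encoder timestep; a standalone sketch with dummy tensors:

import torch
import torch.nn as nn

speaker_embedding = nn.Embedding(4, 256)   # e.g. num_speakers == 4
encoder_outputs = torch.randn(2, 37, 256)  # [batch, text_len, hidden]
speaker_ids = torch.LongTensor([0, 3])

emb = speaker_embedding(speaker_ids).unsqueeze(1)  # [2, 1, 256]
emb = emb.expand(encoder_outputs.size(0), encoder_outputs.size(1), -1)
encoder_outputs = encoder_outputs + emb  # same vector at every timestep
print(encoder_outputs.shape)  # torch.Size([2, 37, 256])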


@ -11,6 +11,7 @@ from utils.generic_utils import sequence_mask
class Tacotron2(nn.Module):
def __init__(self,
num_chars,
num_speakers,
r,
attn_win=False,
attn_norm="softmax",
@ -28,6 +29,9 @@ class Tacotron2(nn.Module):
std = sqrt(2.0 / (num_chars + 512))
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
if num_speakers > 1:
self.speaker_embedding = nn.Embedding(num_speakers, 512)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(512)
self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
attn_norm, prenet_type, prenet_dropout,
@ -40,11 +44,13 @@ class Tacotron2(nn.Module):
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
return mel_outputs, mel_outputs_postnet, alignments
def forward(self, text, text_lengths, mel_specs=None):
def forward(self, text, text_lengths, mel_specs=None, speaker_ids=None):
# compute mask for padding
mask = sequence_mask(text_lengths).to(text.device)
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
speaker_ids)
mel_outputs, stop_tokens, alignments = self.decoder(
encoder_outputs, mel_specs, mask)
mel_outputs_postnet = self.postnet(mel_outputs)
@ -53,9 +59,11 @@ class Tacotron2(nn.Module):
mel_outputs, mel_outputs_postnet, alignments)
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
def inference(self, text):
def inference(self, text, speaker_ids=None):
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
speaker_ids)
mel_outputs, stop_tokens, alignments = self.decoder.inference(
encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs)
@ -64,16 +72,29 @@ class Tacotron2(nn.Module):
mel_outputs, mel_outputs_postnet, alignments)
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
def inference_truncated(self, text):
def inference_truncated(self, text, speaker_ids=None):
"""
Preserve model states for continuous inference
"""
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
speaker_ids)
mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(
encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
mel_outputs, mel_outputs_postnet, alignments)
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
speaker_embeddings = self.speaker_embedding(speaker_ids)
speaker_embeddings.unsqueeze_(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
encoder_outputs.size(1),
-1)
encoder_outputs = encoder_outputs + speaker_embeddings
return encoder_outputs

models/tacotrongst.py Normal file

@ -0,0 +1,90 @@
# coding: utf-8
import torch
from torch import nn
from math import sqrt
from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG
from layers.gst_layers import GST
from utils.generic_utils import sequence_mask
class TacotronGST(nn.Module):
def __init__(self,
num_chars,
num_speakers,
r=5,
linear_dim=1025,
mel_dim=80,
memory_size=5,
attn_win=False,
attn_norm="sigmoid",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
separate_stopnet=True):
super(TacotronGST, self).__init__()
self.r = r
self.mel_dim = mel_dim
self.linear_dim = linear_dim
self.embedding = nn.Embedding(num_chars, 256)
self.embedding.weight.data.normal_(0, 0.3)
if num_speakers > 1:
self.speaker_embedding = nn.Embedding(num_speakers, 256)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(256)
self.gst = GST(num_mel=80, num_heads=4, num_style_tokens=10, embedding_dim=256)
self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, separate_stopnet)
self.postnet = PostCBHG(mel_dim)
self.last_linear = nn.Sequential(
nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
nn.Sigmoid())
def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
B = characters.size(0)
mask = sequence_mask(text_lengths).to(characters.device)
inputs = self.embedding(characters)
encoder_outputs = self.encoder(inputs)
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
speaker_ids)
gst_outputs = self.gst(mel_specs)
gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1)
encoder_outputs = encoder_outputs + gst_outputs
mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, mask)
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
linear_outputs = self.postnet(mel_outputs)
linear_outputs = self.last_linear(linear_outputs)
return mel_outputs, linear_outputs, alignments, stop_tokens
def inference(self, characters, speaker_ids=None, style_mel=None):
B = characters.size(0)
inputs = self.embedding(characters)
encoder_outputs = self.encoder(inputs)
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
speaker_ids)
if style_mel is not None:
gst_outputs = self.gst(style_mel)
gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1)
encoder_outputs = encoder_outputs + gst_outputs
mel_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
linear_outputs = self.postnet(mel_outputs)
linear_outputs = self.last_linear(linear_outputs)
return mel_outputs, linear_outputs, alignments, stop_tokens
def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
speaker_embeddings = self.speaker_embedding(speaker_ids)
speaker_embeddings.unsqueeze_(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
encoder_outputs.size(1),
-1)
encoder_outputs = encoder_outputs + speaker_embeddings
return encoder_outputs
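A minimal inference sketch for the new model; the character ids, sizes, and style mel below are placeholders, not values from the repo:

import torch

model = TacotronGST(num_chars=61, num_speakers=1)
model.eval()

characters = torch.LongTensor([[12, 5, 9, 2, 7]])  # dummy symbol ids, [1, T_text]
style_mel = torch.randn(1, 100, 80)  # reference mel whose prosody to copy
with torch.no_grad():
    mel, linear, alignments, stop_tokens = model.inference(
        characters, speaker_ids=None, style_mel=style_mel)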


@ -20,7 +20,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"TTS_PATH = \"/home/erogol/projects/\"\n",
@ -31,6 +33,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
@ -76,12 +79,14 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n",
"def tts(model, text, CONFIG, use_cuda, ap, use_gl, speaker_id=None, figures=True):\n",
" t_1 = time.time()\n",
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, False, CONFIG.enable_eos_bos_chars)\n",
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, speaker_id=speaker_id, enable_eos_bos_chars=CONFIG.enable_eos_bos_chars)\n",
" if CONFIG.model == \"Tacotron\" and not use_gl:\n",
" mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
" if not use_gl:\n",
@ -101,14 +106,16 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Set constants\n",
"ROOT_PATH = '/media/erogol/data_ssd/Data/models/mozilla_models/4845/'\n",
"MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n",
"MODEL_PATH = ROOT_PATH + 'best_model.pth.tar'\n",
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
"OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n",
"OUT_FOLDER = \"/home/erogol/Dropbox/AudioSamples/benchmark_samples/\"\n",
"CONFIG = load_config(CONFIG_PATH)\n",
"VOCODER_MODEL_PATH = \"/media/erogol/data_ssd/Data/models/wavernn/mozilla/mozilla-May24-4763/model_checkpoints/best_model.pth.tar\"\n",
"VOCODER_CONFIG_PATH = \"/media/erogol/data_ssd/Data/models/wavernn/mozilla/mozilla-May24-4763/config.json\"\n",
@ -129,7 +136,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# LOAD TTS MODEL\n",
@ -160,7 +169,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# LOAD WAVERNN\n",
@ -202,58 +213,66 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"model.eval()\n",
"model.decoder.max_decoder_steps = 2000\n",
"speaker_id = 0\n",
"sentence = \"Bill got in the habit of asking himself “Is that thought true?” And if he wasnt absolutely certain it was, he just let it go.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"sentence = \"Be a voice, not an echo.\" # 'echo' is not in training set. \n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"The human voice is the most perfect instrument of all.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"The human voice is the most perfect instrument of all.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"sentence = \"This cake is great. It's so delicious and moist.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
@ -266,51 +285,61 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"Heres a way to measure the acute emotional intelligence that has never gone out of style.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"The buses aren't the problem, they actually provide a solution.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
@ -323,105 +352,123 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \" He has read the whole thing.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"He reads books.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \" He has read the whole thing.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"He reads books.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"sentence = \"Thisss isrealy awhsome.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"sentence = \"This is your internet browser, Firefox.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"This is your internet browser Firefox.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"The quick brown fox jumps over the lazy dog.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"Does the quick brown fox jump over the lazy dog?\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"Eren, how are you?\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
@ -434,70 +481,83 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"Encouraged, he started with a minute a day.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe .\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning . \"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"If he decided to watch TV he really watched it.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe .\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning . \"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sentence = \"If he decided to watch TV he really watched it.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"sentence = \"Often we try to bring about change through sheer effort and we put all of our energy into a new initiative .\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# for twb dataset\n",
"sentence = \"In our preparation for Easter, God in his providence offers us each year the season of Lent as a sacramental sign of our conversion.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, speaker_id=speaker_id, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# !zip benchmark_samples/samples.zip benchmark_samples/*"
@ -506,9 +566,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3(mztts)",
"language": "python",
"name": "python3"
"name": "mztts"
},
"language_info": {
"codemirror_mode": {
@ -520,7 +580,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
"version": "3.6.8"
}
},
"nbformat": 4,

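For context, every cell above calls the same tts helper defined earlier in the notebook; a minimal sketch of how such a wrapper around utils.synthesis.synthesis might look (the exact signature and the use_gl handling are assumptions, not the notebook's verbatim code):

from utils.synthesis import synthesis

def tts(model, text, CONFIG, use_cuda, ap, speaker_id=None, use_gl=True, figures=False):
    # use_gl / figures are accepted for parity with the cells above;
    # this sketch always decodes with Griffin-Lim via synthesis() and skips plotting.
    wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
        model, text, CONFIG, use_cuda, ap, speaker_id=speaker_id)
    return alignment, postnet_output, stop_tokens, wav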
requirements.txt
View File

@ -10,3 +10,4 @@ flask
scipy==0.19.0
tqdm
git+git://github.com/bootphon/phonemizer@master
soundfile

111 train.py
View File

@ -1,7 +1,5 @@
import argparse
import importlib
import os
import shutil
import sys
import time
import traceback
@ -25,9 +23,12 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
save_checkpoint, sequence_mask, weight_decay,
set_init_dict, copy_config_file, setup_model)
from utils.logger import Logger
from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
get_speakers
from utils.synthesis import synthesis
from utils.text.symbols import phonemes, symbols
from utils.visual import plot_alignment, plot_spectrogram
from datasets.preprocess import get_preprocessor_by_name
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
@ -48,7 +49,7 @@ def setup_loader(is_val=False, verbose=False):
c.meta_file_val if is_val else c.meta_file_train,
c.r,
c.text_cleaner,
preprocessor=preprocessor,
preprocessor=get_preprocessor_by_name(c.dataset),
ap=ap,
batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
min_seq_len=0 if is_val else c.min_seq_len,
@ -75,6 +76,8 @@ def setup_loader(is_val=False, verbose=False):
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
ap, epoch):
data_loader = setup_loader(is_val=False, verbose=(epoch==0))
if c.use_speaker_embedding:
speaker_mapping = load_speaker_mapping(OUT_PATH)
model.train()
epoch_time = 0
avg_postnet_loss = 0
@ -89,13 +92,21 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
# setup input data
text_input = data[0]
text_lengths = data[1]
linear_input = data[2] if c.model == "Tacotron" else None
mel_input = data[3]
mel_lengths = data[4]
stop_targets = data[5]
speaker_names = data[2]
linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
mel_input = data[4]
mel_lengths = data[5]
stop_targets = data[6]
avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float())
if c.use_speaker_embedding:
speaker_ids = [speaker_mapping[speaker_name]
for speaker_name in speaker_names]
speaker_ids = torch.LongTensor(speaker_ids)
else:
speaker_ids = None
# set stop targets view, we predict a single stop token per r frames prediction
stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // c.r, -1)
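The speaker-conditioning addition above boils down to a dictionary lookup per batch item; a small self-contained sketch, with an illustrative mapping in place of the one loaded from speakers.json:

import torch

speaker_mapping = {"p225": 0, "p226": 1}  # illustrative; normally loaded from speakers.json
speaker_names = ["p226", "p225", "p226"]  # one name per sample in the batch

speaker_ids = torch.LongTensor([speaker_mapping[name] for name in speaker_names])
print(speaker_ids)  # tensor([1, 0, 1]) -> indices into the speaker embedding layer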
@ -116,24 +127,26 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
text_lengths = text_lengths.cuda(non_blocking=True)
mel_input = mel_input.cuda(non_blocking=True)
mel_lengths = mel_lengths.cuda(non_blocking=True)
linear_input = linear_input.cuda(non_blocking=True) if c.model == "Tacotron" else None
linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None
stop_targets = stop_targets.cuda(non_blocking=True)
if speaker_ids is not None:
speaker_ids = speaker_ids.cuda(non_blocking=True)
# forward pass model
decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input)
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
# loss computation
stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
if c.loss_masking:
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
if c.model == "Tacotron":
if c.model in ["Tacotron", "TacotronGST"]:
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
else:
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
else:
decoder_loss = criterion(decoder_output, mel_input)
if c.model == "Tacotron":
if c.model in ["Tacotron", "TacotronGST"]:
postnet_loss = criterion(postnet_output, linear_input)
else:
postnet_loss = criterion(postnet_output, mel_input)
@ -178,7 +191,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
if args.rank == 0:
avg_postnet_loss += float(postnet_loss.item())
avg_decoder_loss += float(decoder_loss.item())
avg_stop_loss += stop_loss if type(stop_loss) is float else float(stop_loss.item())
avg_stop_loss += stop_loss if type(stop_loss) is float else float(stop_loss.item())
avg_step_time += step_time
# Plot Training Iter Stats
@ -199,7 +212,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
# Diagnostic visualizations
const_spec = postnet_output[0].data.cpu().numpy()
gt_spec = linear_input[0].data.cpu().numpy() if c.model == "Tacotron" else mel_input[0].data.cpu().numpy()
gt_spec = linear_input[0].data.cpu().numpy() if c.model in ["Tacotron", "TacotronGST"] else mel_input[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy()
figures = {
@ -210,7 +223,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
tb_logger.tb_train_figures(current_step, figures)
# Sample audio
if c.model == "Tacotron":
if c.model in ["Tacotron", "TacotronGST"]:
train_audio = ap.inv_spectrogram(const_spec.T)
else:
train_audio = ap.inv_mel_spectrogram(const_spec.T)
@ -243,12 +256,15 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
"epoch_time": epoch_time}
tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
if c.tb_model_param_stats:
tb_logger.tb_model_weights(model, current_step)
tb_logger.tb_model_weights(model, current_step)
return avg_postnet_loss, current_step
def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
data_loader = setup_loader(is_val=True)
if c.use_speaker_embedding:
speaker_mapping = load_speaker_mapping(OUT_PATH)
model.eval()
epoch_time = 0
avg_postnet_loss = 0
@ -273,10 +289,18 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
# setup input data
text_input = data[0]
text_lengths = data[1]
linear_input = data[2] if c.model == "Tacotron" else None
mel_input = data[3]
mel_lengths = data[4]
stop_targets = data[5]
speaker_names = data[2]
linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
mel_input = data[4]
mel_lengths = data[5]
stop_targets = data[6]
if c.use_speaker_embedding:
speaker_ids = [speaker_mapping[speaker_name]
for speaker_name in speaker_names]
speaker_ids = torch.LongTensor(speaker_ids)
else:
speaker_ids = None
# set stop targets view, we predict a single stop token per r frames prediction
stop_targets = stop_targets.view(text_input.shape[0],
@ -289,24 +313,27 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
text_input = text_input.cuda()
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda() if c.model == "Tacotron" else None
linear_input = linear_input.cuda() if c.model in ["Tacotron", "TacotronGST"] else None
stop_targets = stop_targets.cuda()
if speaker_ids is not None:
speaker_ids = speaker_ids.cuda()
# forward pass
decoder_output, postnet_output, alignments, stop_tokens =\
model.forward(text_input, text_lengths, mel_input)
model.forward(text_input, text_lengths, mel_input,
speaker_ids=speaker_ids)
# loss computation
stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
if c.loss_masking:
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
if c.model == "Tacotron":
if c.model in ["Tacotron", "TacotronGST"]:
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
else:
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
else:
decoder_loss = criterion(decoder_output, mel_input)
if c.model == "Tacotron":
if c.model in ["Tacotron", "TacotronGST"]:
postnet_loss = criterion(postnet_output, linear_input)
else:
postnet_loss = criterion(postnet_output, mel_input)
@ -339,7 +366,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
# Diagnostic visualizations
idx = np.random.randint(mel_input.shape[0])
const_spec = postnet_output[idx].data.cpu().numpy()
gt_spec = linear_input[idx].data.cpu().numpy() if c.model == "Tacotron" else mel_input[idx].data.cpu().numpy()
gt_spec = linear_input[idx].data.cpu().numpy() if c.model in ["Tacotron", "TacotronGST"] else mel_input[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy()
eval_figures = {
@ -350,7 +377,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
tb_logger.tb_eval_figures(current_step, eval_figures)
# Sample audio
if c.model == "Tacotron":
if c.model in ["Tacotron", "TacotronGST"]:
eval_audio = ap.inv_spectrogram(const_spec.T)
else:
eval_audio = ap.inv_mel_spectrogram(const_spec.T)
@ -372,10 +399,12 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
test_audios = {}
test_figures = {}
print(" | > Synthesizing test sentences")
speaker_id = 0 if c.num_speakers > 1 else None
for idx, test_sentence in enumerate(test_sentences):
try:
wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
model, test_sentence, c, use_cuda, ap)
model, test_sentence, c, use_cuda, ap,
speaker_id=speaker_id)
file_path = os.path.join(AUDIO_PATH, str(current_step))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path,
@ -398,7 +427,27 @@ def main(args):
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, c)
if c.use_speaker_embedding:
speakers = get_speakers(c.data_path, c.meta_file_train, c.dataset)
if args.restore_path:
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
assert all([speaker in speaker_mapping
for speaker in speakers]), "As of now you, you cannot " \
"introduce new speakers to " \
"a previously trained model."
else:
speaker_mapping = {name: i
for i, name in enumerate(speakers)}
save_speaker_mapping(OUT_PATH, speaker_mapping)
num_speakers = len(speaker_mapping)
print("Training with {} speakers: {}".format(num_speakers,
", ".join(speakers)))
else:
num_speakers = 0
model = setup_model(num_chars, num_speakers, c)
print(" | > Num output units : {}".format(ap.num_freq), flush=True)
@ -410,9 +459,9 @@ def main(args):
optimizer_st = None
if c.loss_masking:
criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"] else MSELossMasked()
else:
criterion = nn.L1Loss() if c.model == "Tacotron" else nn.MSELoss()
criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"] else nn.MSELoss()
criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None
if args.restore_path:
@ -558,10 +607,6 @@ if __name__ == '__main__':
LOG_DIR = OUT_PATH
tb_logger = Logger(LOG_DIR)
# Conditional imports
preprocessor = importlib.import_module('datasets.preprocess')
preprocessor = getattr(preprocessor, c.dataset.lower())
# Audio processor
ap = AudioProcessor(**c.audio)

utils/audio.py
View File

@ -10,7 +10,6 @@ from scipy import signal, io
class AudioProcessor(object):
def __init__(self,
bits=None,
sample_rate=None,
num_mels=None,
min_level_db=None,
@ -32,7 +31,6 @@ class AudioProcessor(object):
print(" > Setting up Audio Processor...")
self.bits = bits
self.sample_rate = sample_rate
self.num_mels = num_mels
self.min_level_db = min_level_db
@ -231,23 +229,27 @@ class AudioProcessor(object):
def mulaw_decode(wav, qc):
"""Recovers waveform from quantized values."""
mu = 2 ** qc - 1
mu = mu - 1
signal = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return signal
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return x
def load_wav(self, filename, encode=False):
x, sr = sf.read(filename)
def load_wav(self, filename, sr=None):
if sr is None:
x, sr = sf.read(filename)
else:
x, sr = librosa.load(filename, sr=sr)
if self.do_trim_silence:
x = self.trim_silence(x)
# sr, x = io.wavfile.read(filename)
try:
x = self.trim_silence(x)
except ValueError as e:
print(f' [!] File cannot be trimmed for silence - {filename}')
assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
return x
def encode_16bits(self, x):
return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)
def quantize(self, x):
return (x + 1.) * (2**self.bits - 1) / 2
def quantize(self, x, bits):
return (x + 1.) * (2**bits - 1) / 2
def dequantize(self, x):
return 2 * x / (2**self.bits - 1) - 1
def dequantize(self, x, bits):
return 2 * x / (2**bits - 1) - 1
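With bits now passed explicitly, quantize and dequantize form an exact float round trip; a quick check using standalone versions of the two methods above:

import numpy as np

def quantize(x, bits):
    return (x + 1.) * (2**bits - 1) / 2   # [-1, 1] -> [0, 2**bits - 1]

def dequantize(x, bits):
    return 2 * x / (2**bits - 1) - 1      # inverse mapping

x = np.linspace(-1.0, 1.0, 5)
assert np.allclose(dequantize(quantize(x, 9), 9), x)  # lossless in float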

utils/generic_utils.py
View File

@ -33,9 +33,13 @@ def load_config(config_path):
def get_git_branch():
out = subprocess.check_output(["git", "branch"]).decode("utf8")
current = next(line for line in out.split("\n") if line.startswith("*"))
return current.replace("* ", "")
try:
out = subprocess.check_output(["git", "branch"]).decode("utf8")
current = next(line for line in out.split("\n") if line.startswith("*"))
current.replace("* ", "")
except subprocess.CalledProcessError:
current = "inside_docker"
return current
def get_commit_hash():
@ -46,8 +50,12 @@ def get_commit_hash():
# except:
# raise RuntimeError(
# " !! Commit before training to get the commit hash.")
commit = subprocess.check_output(['git', 'rev-parse', '--short',
'HEAD']).decode().strip()
try:
commit = subprocess.check_output(['git', 'rev-parse', '--short',
'HEAD']).decode().strip()
# Not copying .git folder into docker container
except subprocess.CalledProcessError:
commit = "0000000"
print(' > Git Hash: {}'.format(commit))
return commit
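Both helpers above follow the same fallback pattern: attempt the git call and substitute a sentinel when it fails, which is exactly what happens inside the Docker image now that .dockerignore excludes .git/. Condensed into one function (name is illustrative):

import subprocess

def git_output_or(args, fallback):
    # git exits non-zero outside a repository, which check_output
    # surfaces as CalledProcessError.
    try:
        return subprocess.check_output(args).decode().strip()
    except subprocess.CalledProcessError:
        return fallback

commit = git_output_or(['git', 'rev-parse', '--short', 'HEAD'], '0000000')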
@ -243,13 +251,14 @@ def set_init_dict(model_dict, checkpoint, c):
return model_dict
def setup_model(num_chars, c):
def setup_model(num_chars, num_speakers, c):
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module('models.' + c.model.lower())
MyModel = getattr(MyModel, c.model)
if c.model.lower() == "tacotron":
if c.model.lower() in ["tacotron", "tacotrongst"]:
model = MyModel(
num_chars=num_chars,
num_speakers=num_speakers,
r=c.r,
linear_dim=1025,
mel_dim=80,
@ -266,6 +275,7 @@ def setup_model(num_chars, c):
elif c.model.lower() == "tacotron2":
model = MyModel(
num_chars=num_chars,
num_speakers=num_speakers,
r=c.r,
attn_win=c.windowing,
attn_norm=c.attention_norm,
@ -276,4 +286,4 @@ def setup_model(num_chars, c):
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
separate_stopnet=c.separate_stopnet)
return model
return model
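setup_model resolves the class by name, so "TacotronGST" maps to models/tacotrongst.py without an explicit import; the lookup on its own (the constructor call is elided since its arguments come from the config):

import importlib

model_name = "TacotronGST"                               # e.g. c.model
module = importlib.import_module('models.' + model_name.lower())
MyModel = getattr(module, model_name)                    # the class object
# model = MyModel(num_chars=..., num_speakers=..., r=..., ...) as above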

33 utils/speakers.py Normal file
View File

@ -0,0 +1,33 @@
import os
import json
from datasets.preprocess import get_preprocessor_by_name
def make_speakers_json_path(out_path):
"""Returns conventional speakers.json location."""
return os.path.join(out_path, "speakers.json")
def load_speaker_mapping(out_path):
"""Loads speaker mapping if already present."""
try:
with open(make_speakers_json_path(out_path)) as f:
return json.load(f)
except FileNotFoundError:
return {}
def save_speaker_mapping(out_path, speaker_mapping):
"""Saves speaker mapping if not yet present."""
speakers_json_path = make_speakers_json_path(out_path)
with open(speakers_json_path, "w") as f:
json.dump(speaker_mapping, f, indent=4)
def get_speakers(data_root, meta_file, dataset_type):
"""Returns a sorted, unique list of speakers in a given dataset."""
preprocessor = get_preprocessor_by_name(dataset_type)
items = preprocessor(data_root, meta_file)
speakers = {e[2] for e in items}
return sorted(speakers)
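Together these give a small persistence layer for speaker indices; a usage sketch (paths and dataset name are illustrative):

from utils.speakers import (get_speakers, load_speaker_mapping,
                            save_speaker_mapping)

speakers = get_speakers("../tmp/LibriTTS", "metadata_train.csv", "libri_tts")
# load_speaker_mapping returns {} when speakers.json does not exist yet
mapping = load_speaker_mapping("out") or {name: i for i, name in enumerate(speakers)}
save_speaker_mapping("out", mapping)   # writes out/speakers.json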

utils/synthesis.py
View File

@ -8,7 +8,83 @@ from .visual import visualize
from matplotlib import pylab as plt
def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos_chars=False, trim_silence=False):
def text_to_seqvec(text, CONFIG, use_cuda):
text_cleaner = [CONFIG.text_cleaner]
# text or phonemes to sequence vector
if CONFIG.use_phonemes:
seq = np.asarray(
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language,
CONFIG.enable_eos_bos_chars),
dtype=np.int32)
else:
seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
# torch tensor
chars_var = torch.from_numpy(seq).unsqueeze(0)
if use_cuda:
chars_var = chars_var.cuda()
return chars_var.long()
def compute_style_mel(style_wav, ap, use_cuda):
print(style_wav)
style_mel = torch.FloatTensor(ap.melspectrogram(
ap.load_wav(style_wav))).unsqueeze(0)
if use_cuda:
return style_mel.cuda()
else:
return style_mel
def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
if CONFIG.model == "TacotronGST" and style_mel is not None:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, style_mel=style_mel, speaker_ids=speaker_id)
else:
if truncated:
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
inputs, speaker_ids=speaker_id)
else:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, speaker_ids=speaker_id)
return decoder_output, postnet_output, alignments, stop_tokens
def parse_outputs(postnet_output, decoder_output, alignments):
postnet_output = postnet_output[0].data.cpu().numpy()
decoder_output = decoder_output[0].data.cpu().numpy()
alignment = alignments[0].cpu().data.numpy()
return postnet_output, decoder_output, alignment
def apply_trim_silence(wav, ap):
    # named so it is not shadowed by synthesis()'s trim_silence flag below
    return wav[:ap.find_endpoint(wav)]
def inv_spectrogram(postnet_output, ap, CONFIG):
if CONFIG.model in ["Tacotron", "TacotronGST"]:
wav = ap.inv_spectrogram(postnet_output.T)
else:
wav = ap.inv_mel_spectrogram(postnet_output.T)
return wav
def id_to_torch(speaker_id):
if speaker_id is not None:
speaker_id = np.asarray(speaker_id)
speaker_id = torch.from_numpy(speaker_id).unsqueeze(0)
return speaker_id
def synthesis(model,
text,
CONFIG,
use_cuda,
ap,
speaker_id=None,
style_wav=None,
truncated=False,
enable_eos_bos_chars=False,
trim_silence=False):
"""Synthesize voice for the given text.
Args:
@ -18,39 +94,31 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos
use_cuda (bool): enable cuda.
ap (TTS.utils.audio.AudioProcessor): audio processor to process
model outputs.
speaker_id (int): speaker id for multi-speaker models.
style_wav (str): path to a reference wav used for the GST style embedding.
truncated (bool): keep model states after inference. It can be used
for continuous inference of long texts.
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
trim_silence (bool): trim silence after synthesis.
"""
# GST processing
style_mel = None
if CONFIG.model == "TacotronGST" and style_wav is not None:
style_mel = compute_style_mel(style_wav, ap, use_cuda)
# preprocess the given text
text_cleaner = [CONFIG.text_cleaner]
if CONFIG.use_phonemes:
seq = np.asarray(
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, enable_eos_bos_chars),
dtype=np.int32)
else:
seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
chars_var = torch.from_numpy(seq).unsqueeze(0)
# synthesize voice
inputs = text_to_seqvec(text, CONFIG, use_cuda)
speaker_id = id_to_torch(speaker_id)
if use_cuda:
chars_var = chars_var.cuda()
if truncated:
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
chars_var.long())
else:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
chars_var.long())
speaker_id = speaker_id.cuda() if speaker_id is not None else None
# synthesize voice
decoder_output, postnet_output, alignments, stop_tokens = run_model(
model, inputs, CONFIG, truncated, speaker_id, style_mel)
# convert outputs to numpy
postnet_output = postnet_output[0].data.cpu().numpy()
decoder_output = decoder_output[0].data.cpu().numpy()
alignment = alignments[0].cpu().data.numpy()
postnet_output, decoder_output, alignment = parse_outputs(
postnet_output, decoder_output, alignments)
# plot results
if CONFIG.model == "Tacotron":
wav = ap.inv_spectrogram(postnet_output.T)
else:
wav = ap.inv_mel_spectrogram(postnet_output.T)
wav = inv_spectrogram(postnet_output, ap, CONFIG)
# trim silence
if trim_silence:
wav = wav[:ap.find_endpoint(wav)]
return wav, alignment, decoder_output, postnet_output, stop_tokens
wav = apply_trim_silence(wav, ap)
return wav, alignment, decoder_output, postnet_output, stop_tokens
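End to end, the refactored synthesis keeps its old call shape and adds the two optional conditioning arguments; a hedged usage example (model, CONFIG, ap and use_cuda are assumed to be set up as in the notebook above):

from utils.synthesis import synthesis

wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
    model, "The human voice is the most perfect instrument of all.",
    CONFIG, use_cuda, ap,
    speaker_id=0,            # row in the trained speaker embedding, or None
    style_wav="ref.wav",     # only consumed when CONFIG.model == "TacotronGST"
    trim_silence=True)
ap.save_wav(wav, "output.wav")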