diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 3222c278..57a17704 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -14,6 +14,7 @@ from TTS.speaker_encoder.dataset import MyDataset from TTS.speaker_encoder.generic_utils import save_best_model from TTS.speaker_encoder.losses import GE2ELoss, AngleProtoLoss from TTS.speaker_encoder.model import SpeakerEncoder +from TTS.speaker_encoder.utils import check_config_speaker_encoder from TTS.speaker_encoder.visual import plot_embeddings from TTS.tts.datasets.preprocess import load_meta_data from TTS.utils.generic_utils import ( @@ -235,6 +236,7 @@ if __name__ == '__main__': # setup output paths and read configs c = load_config(args.config_path) + check_config_speaker_encoder(c) _ = os.path.dirname(os.path.realpath(__file__)) if args.data_path != '': c.data_path = args.data_path diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index f2641f9d..1b7351d4 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -17,7 +17,7 @@ from TTS.tts.layers.losses import TacotronLoss from TTS.tts.utils.distribute import (DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor) -from TTS.tts.utils.generic_utils import check_config, setup_model +from TTS.tts.utils.generic_utils import setup_model, check_config_tts from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping, @@ -670,7 +670,7 @@ if __name__ == '__main__': # setup output paths and read configs c = load_config(args.config_path) - check_config(c) + check_config_tts(c) _ = os.path.dirname(os.path.realpath(__file__)) if c.apex_amp_level == 'O1': diff --git a/TTS/speaker_encoder/utils.py b/TTS/speaker_encoder/utils.py new file mode 100644 index 00000000..95c222f2 --- /dev/null +++ b/TTS/speaker_encoder/utils.py @@ -0,0 +1,61 @@ +from TTS.utils.generic_utils import check_argument + + +def check_config_speaker_encoder(c): + """Check the config.json file of the speaker encoder""" + check_argument('run_name', c, restricted=True, val_type=str) + check_argument('run_description', c, val_type=str) + + # audio processing parameters + check_argument('audio', c, restricted=True, val_type=dict) + check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) + check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) + check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) + check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') + check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') + check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) + check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) + check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) + check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) + check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) + + # training parameters + check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str) + check_argument('grad_clip', c, restricted=True, val_type=float) + check_argument('epochs', c, restricted=True, val_type=int, min_val=1) + check_argument('lr', c, restricted=True, val_type=float, min_val=0) + check_argument('lr_decay', c, restricted=True, val_type=bool) + check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) + check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) + check_argument('num_speakers_in_batch', c, restricted=True, val_type=int) + check_argument('num_loader_workers', c, restricted=True, val_type=int) + check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0) + + # checkpoint and output parameters + check_argument('steps_plot_stats', c, restricted=True, val_type=int) + check_argument('checkpoint', c, restricted=True, val_type=bool) + check_argument('save_step', c, restricted=True, val_type=int) + check_argument('print_step', c, restricted=True, val_type=int) + check_argument('output_path', c, restricted=True, val_type=str) + + # model parameters + check_argument('model', c, restricted=True, val_type=dict) + check_argument('input_dim', c['model'], restricted=True, val_type=int) + check_argument('proj_dim', c['model'], restricted=True, val_type=int) + check_argument('lstm_dim', c['model'], restricted=True, val_type=int) + check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int) + check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool) + + # in-memory storage parameters + check_argument('storage', c, restricted=True, val_type=dict) + check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0) + check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100) + check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0) + + # datasets - checking only the first entry + check_argument('datasets', c, restricted=True, val_type=list) + for dataset_entry in c['datasets']: + check_argument('name', dataset_entry, restricted=True, val_type=str) + check_argument('path', dataset_entry, restricted=True, val_type=str) + check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list]) + check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index af32a769..f9d21644 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -100,7 +100,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): return model -def check_config(c): +def check_config_tts(c): check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str) check_argument('run_name', c, restricted=True, val_type=str) check_argument('run_description', c, val_type=str) @@ -140,12 +140,6 @@ def check_config(c): check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) check_argument('trim_db', c['audio'], restricted=True, val_type=int) - # storage parameters (only for speaker encoder) - if 'storage' in c.keys(): - check_argument('sample_from_storage_p', c['storage'], restricted=False, val_type=float, min_val=0.0, max_val=1.0) - check_argument('storage_size', c['storage'], restricted=False, val_type=int, min_val=1, max_val=100) - check_argument('additive_noise', c['storage'], restricted=False, val_type=float, min_val=0.0, max_val=1.0) - # training parameters check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) diff --git a/requirements.txt b/requirements.txt index f0f2c057..85ffbe55 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ nose==1.3.7 cardboardlint==1.3.0 pylint==2.5.3 gdown +umap \ No newline at end of file