From a2859b7ddcb8c7b6bebc801111a68aabfa9c261f Mon Sep 17 00:00:00 2001 From: erogol Date: Thu, 10 Dec 2020 13:45:30 +0100 Subject: [PATCH] update config args checks --- TTS/tts/utils/generic_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 48b00fa5..928c9dfc 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -177,7 +177,7 @@ def check_config_tts(c): check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) check_argument('r', c, restricted=True, val_type=int, min_val=1) check_argument('gradual_training', c, restricted=False, val_type=list) - check_argument('apex_amp_level', c, restricted=False, val_type=str) + check_argument('mixed_precision', c, restricted=False, val_type=bool) # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100) # loss parameters @@ -224,9 +224,10 @@ def check_config_tts(c): check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool) check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int) - # stopnet - check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool) - check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool) + if c['model'].lower() in ['tacotron', 'tacotron2']: + # stopnet + check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool) + check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool) # GlowTTS parameters check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str) @@ -257,8 +258,8 @@ def check_config_tts(c): check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool) check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str) - check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool) if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']: + check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool) check_argument('gst', c, restricted=is_tacotron(c), val_type=dict) check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict]) check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)