From 30788960a85247eb56d21062692c961a74d1fa1b Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 28 Dec 2020 13:53:10 +0100 Subject: [PATCH] check SS model parameters --- TTS/tts/utils/generic_utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 0f442694..7758ce08 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -198,6 +198,10 @@ def check_config_tts(c): check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0) check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0) check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0) + if c['model'].lower == "speedy_speech": + check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0) + check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0) + check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0) # validation parameters check_argument('run_eval', c, restricted=True, val_type=bool) @@ -209,9 +213,9 @@ def check_config_tts(c): check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0) check_argument('epochs', c, restricted=True, val_type=int, min_val=1) check_argument('lr', c, restricted=True, val_type=float, min_val=0) - check_argument('wd', c, restricted=True, val_type=float, min_val=0) + check_argument('wd', c, restricted=is_tacotron(c), val_type=float, min_val=0) check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) - check_argument('seq_len_norm', c, restricted=True, val_type=bool) + check_argument('seq_len_norm', c, restricted=is_tacotron(c), val_type=bool) # tacotron prenet check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1) @@ -237,6 +241,13 @@ def check_config_tts(c): check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool) check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool) + # Model Parameters for non-tacotron models + if c['model'].lower == "speedy_speech": + check_argument('positional_encoding', c, restricted=True, val_type=type) + check_argument('encoder_type', c, restricted=True, val_type=str) + check_argument('encoder_params', c, restricted=True, val_type=dict) + check_argument('decoder_residual_conv_bn_params', c, restricted=True, val_type=dict) + # GlowTTS parameters check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)