check SS model parameters

This commit is contained in:
erogol 2020-12-28 13:53:10 +01:00
parent 5cae2c5742
commit 30788960a8
1 changed files with 13 additions and 2 deletions

View File

@ -198,6 +198,10 @@ def check_config_tts(c):
check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
if c['model'].lower == "speedy_speech":
check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
# validation parameters
check_argument('run_eval', c, restricted=True, val_type=bool)
@ -209,9 +213,9 @@ def check_config_tts(c):
check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
check_argument('lr', c, restricted=True, val_type=float, min_val=0)
check_argument('wd', c, restricted=True, val_type=float, min_val=0)
check_argument('wd', c, restricted=is_tacotron(c), val_type=float, min_val=0)
check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
check_argument('seq_len_norm', c, restricted=True, val_type=bool)
check_argument('seq_len_norm', c, restricted=is_tacotron(c), val_type=bool)
# tacotron prenet
check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
@ -237,6 +241,13 @@ def check_config_tts(c):
check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
# Model Parameters for non-tacotron models
if c['model'].lower == "speedy_speech":
check_argument('positional_encoding', c, restricted=True, val_type=type)
check_argument('encoder_type', c, restricted=True, val_type=str)
check_argument('encoder_params', c, restricted=True, val_type=dict)
check_argument('decoder_residual_conv_bn_params', c, restricted=True, val_type=dict)
# GlowTTS parameters
check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)