config update, check arguments update and enable alternative arguments

This commit is contained in:
erogol 2020-02-20 12:48:45 +01:00
parent f707460886
commit e6504cc9a4
2 changed files with 8 additions and 7 deletions

View File

@ -1,6 +1,6 @@
{
"model": "Tacotron2", // one of the model in models/
"run_name": "ljspeech-stf_params",
"run_name": "ljspeech-stft_params",
"run_description": "tacotron2 constant stft parameters",
// AUDIO PARAMETERS
@ -36,12 +36,11 @@
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
// TRAINING
"batch_size": 2, // Batch size for training. Values lower than 32 might make attention hard to learn. It is overwritten by 'gradual_training'.
"batch_size": 32, // Batch size for training. Values lower than 32 might make attention hard to learn. It is overwritten by 'gradual_training'.
"eval_batch_size":16,
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
"gradual_training": [[0, 7, 64], [2000, 5, 64], [35000, 3, 32], [70000, 2, 32], [140000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceed.
"gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceed.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"grad_accum": 2, // if N > 1, enable gradient accumulation for N iterations. It is useful for low memory GPUs.
// VALIDATION
"run_eval": true,

View File

@ -391,7 +391,9 @@ class KeepAverage():
self.update_value(key, value)
def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None):
def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None):
if alternative in c.keys() and c[alternative] is not None:
return
if restricted:
assert name in c.keys(), f' [!] {name} not defined in config.json'
if name in c.keys():
@ -417,8 +419,8 @@ def check_config(c):
_check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
_check_argument('num_freq', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
_check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
_check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000)
_check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000)
_check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
_check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
_check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
_check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
_check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)