Add `scheduler_after_epoch` to `BaseTrainingConfig`

This commit is contained in:
Eren Gölge 2021-08-07 21:43:07 +00:00
parent e4648ffef1
commit 960a35a121
1 changed files with 24 additions and 1 deletions

View File

@ -79,7 +79,7 @@ class BaseAudioConfig(Coqpit):
preemphasis: float = 0.0 preemphasis: float = 0.0
ref_level_db: int = 20 ref_level_db: int = 20
do_sound_norm: bool = False do_sound_norm: bool = False
log_func = "np.log10" log_func: str = "np.log10"
# silence trimming # silence trimming
do_trim_silence: bool = True do_trim_silence: bool = True
trim_db: int = 45 trim_db: int = 45
@ -182,48 +182,70 @@ class BaseTrainingConfig(Coqpit):
Args: Args:
model (str): model (str):
Name of the model that is used in the training. Name of the model that is used in the training.
run_name (str): run_name (str):
Name of the experiment. This prefixes the output folder name. Name of the experiment. This prefixes the output folder name.
run_description (str): run_description (str):
Short description of the experiment. Short description of the experiment.
epochs (int): epochs (int):
Number training epochs. Defaults to 10000. Number training epochs. Defaults to 10000.
batch_size (int): batch_size (int):
Training batch size. Training batch size.
eval_batch_size (int): eval_batch_size (int):
Validation batch size. Validation batch size.
mixed_precision (bool): mixed_precision (bool):
Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however
it may also cause numerical unstability in some cases. it may also cause numerical unstability in some cases.
scheduler_after_epoch (bool):
If true, run the scheduler step after each epoch else run it after each model step.
run_eval (bool): run_eval (bool):
Enable / Disable evaluation (validation) run. Defaults to True. Enable / Disable evaluation (validation) run. Defaults to True.
test_delay_epochs (int): test_delay_epochs (int):
Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful
results, hence waiting for a couple of epochs might save some time. results, hence waiting for a couple of epochs might save some time.
print_eval (bool): print_eval (bool):
Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at
the end of the evaluation. Default to ```False```. the end of the evaluation. Default to ```False```.
print_step (int): print_step (int):
Number of steps required to print the next training log. Number of steps required to print the next training log.
tb_plot_step (int): tb_plot_step (int):
Number of steps required to log training on Tensorboard. Number of steps required to log training on Tensorboard.
tb_model_param_stats (bool): tb_model_param_stats (bool):
Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging. Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
Defaults to ```False```. Defaults to ```False```.
save_step (int):ipt save_step (int):ipt
Number of steps required to save the next checkpoint. Number of steps required to save the next checkpoint.
checkpoint (bool): checkpoint (bool):
Enable / Disable checkpointing. Enable / Disable checkpointing.
keep_all_best (bool): keep_all_best (bool):
Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults
to ```False```. to ```False```.
keep_after (int): keep_after (int):
Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults
to 10000. to 10000.
num_loader_workers (int): num_loader_workers (int):
Number of workers for training time dataloader. Number of workers for training time dataloader.
num_eval_loader_workers (int): num_eval_loader_workers (int):
Number of workers for evaluation time dataloader. Number of workers for evaluation time dataloader.
output_path (str): output_path (str):
Path for training output folder, either a local file path or other Path for training output folder, either a local file path or other
URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
@ -239,6 +261,7 @@ class BaseTrainingConfig(Coqpit):
batch_size: int = None batch_size: int = None
eval_batch_size: int = None eval_batch_size: int = None
mixed_precision: bool = False mixed_precision: bool = False
scheduler_after_epoch: bool = False
# eval params # eval params
run_eval: bool = True run_eval: bool = True
test_delay_epochs: int = 0 test_delay_epochs: int = 0