From ae6405bb76baece21a69c65cc45325855f223a80 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 28 Jun 2021 16:53:57 +0200
Subject: [PATCH] Docstrings for `Trainer`

---
 TTS/trainer.py                      | 68 ++++++++++++++++++++++++++---
 TTS/tts/layers/tacotron/tacotron.py |  2 +
 TTS/tts/utils/text/__init__.py      |  1 +
 3 files changed, 66 insertions(+), 5 deletions(-)

diff --git a/TTS/trainer.py b/TTS/trainer.py
index e3403bae..0e921335 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -4,6 +4,7 @@ import glob
 import importlib
 import logging
 import os
+import platform
 import re
 import sys
 import time
@@ -40,13 +41,21 @@ from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model as setup_vocoder_model
 
+if platform.system() != "Windows":
+    # https://github.com/pytorch/pytorch/issues/973
+    import resource
+
+    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+    resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
+
 if is_apex_available():
     from apex import amp
 
 
 @dataclass
 class TrainingArgs(Coqpit):
-    """Trainer arguments"""
+    """Trainer arguments that are set externally. They ease integrating the `Trainer` with higher-level APIs
+    and setting the values for distributed training."""
 
     continue_path: str = field(
         default="",
@@ -360,7 +369,7 @@ class Trainer:
         return self._get_loader(self.model, self.config, ap, True, data_items, verbose, self.num_gpus)
 
     def format_batch(self, batch: List) -> Dict:
-        """Format dataloader ouput and return a batch.
+        """Format the dataloader output and return a batch.
 
         Args:
             batch (List): Batch returned by the dataloader.
@@ -633,7 +642,7 @@ class Trainer:
         return outputs, loss_dict
 
     def train_epoch(self) -> None:
-        """Main entry point for training. Run training on the whole training samples."""
+        """Main entry point for the training loop. Run training on all the training samples."""
         self.train_loader = self.get_train_dataloader(
             self.ap,
             self.data_train,
@@ -682,6 +691,15 @@ class Trainer:
         return model.eval_step(*input_args)
 
     def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
+        """Perform an evaluation step on a batch of inputs and log the process.
+
+        Args:
+            batch (Dict): Input batch.
+            step (int): Current step number in this epoch.
+
+        Returns:
+            Tuple[Dict, Dict]: Model outputs and losses.
+        """
         with torch.no_grad():
             outputs_per_optimizer = None
             loss_dict = {}
@@ -708,6 +726,7 @@ class Trainer:
         return outputs, loss_dict
 
     def eval_epoch(self) -> None:
+        """Main entry point for the evaluation loop. Run evaluation on all the validation samples."""
         self.eval_loader = (
             self.get_eval_dataloader(
                 self.ap,
@@ -743,7 +762,7 @@ class Trainer:
 
     def test_run(self) -> None:
         """Run test and log the results. Test run must be defined by the model.
-        Model must return figures and audios to be logged by the Tensorboard logger."""
+        Model must return figures and audios to be logged to TensorBoard."""
         if hasattr(self.model, "test_run"):
             if hasattr(self.eval_loader, "load_test_samples"):
                 samples = self.eval_loader.load_test_samples(1)
@@ -820,11 +839,22 @@ class Trainer:
         )
 
     @staticmethod
-    def _is_apex_available():
+    def _is_apex_available() -> bool:
+        """Check if Nvidia's APEX is available."""
         return importlib.util.find_spec("apex") is not None
 
     @staticmethod
     def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]:
+        """Receive the optimizer from the model if the model implements `get_optimizer()`, else
+        check the optimizer parameters in the config and try initializing the optimizer.
+
+        Args:
+            model (nn.Module): Training model.
+            config (Coqpit): Training configuration.
+
+        Returns:
+            Union[torch.optim.Optimizer, List]: An optimizer or a list of optimizers. GAN models define a list.
+        """
         if hasattr(model, "get_optimizer"):
             optimizer = model.get_optimizer()
             if optimizer is None:
@@ -835,6 +865,16 @@ class Trainer:
 
     @staticmethod
     def get_lr(model: nn.Module, config: Coqpit) -> Union[float, List[float]]:
+        """Set the initial learning rate from the model if the model implements `get_lr()`, else
+        try setting the learning rate from the config.
+
+        Args:
+            model (nn.Module): Training model.
+            config (Coqpit): Training configuration.
+
+        Returns:
+            Union[float, List[float]]: A single learning rate or a list of learning rates, one for each optimizer.
+        """
         lr = None
         if hasattr(model, "get_lr"):
             lr = model.get_lr()
@@ -846,6 +886,16 @@ class Trainer:
     def get_scheduler(
         model: nn.Module, config: Coqpit, optimizer: Union[torch.optim.Optimizer, List]
     ) -> Union[torch.optim.lr_scheduler._LRScheduler, List]:  # pylint: disable=protected-access
+        """Receive the scheduler from the model if the model implements `get_scheduler()`, else
+        check the config and try initializing the scheduler.
+
+        Args:
+            model (nn.Module): Training model.
+            config (Coqpit): Training configuration.
+
+        Returns:
+            Union[torch.optim.lr_scheduler._LRScheduler, List]: A scheduler or a list of schedulers, one for each optimizer.
+        """
         scheduler = None
         if hasattr(model, "get_scheduler"):
             scheduler = model.get_scheduler(optimizer)
@@ -857,6 +907,14 @@ class Trainer:
 
     @staticmethod
     def get_criterion(model: nn.Module) -> nn.Module:
+        """Receive the criterion from the model. The model must implement `get_criterion()`.
+
+        Args:
+            model (nn.Module): Training model.
+
+        Returns:
+            nn.Module: Criterion layer.
+        """
         criterion = None
         criterion = model.get_criterion()
         return criterion
diff --git a/TTS/tts/layers/tacotron/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py
index a6579171..47b5ea7e 100644
--- a/TTS/tts/layers/tacotron/tacotron.py
+++ b/TTS/tts/layers/tacotron/tacotron.py
@@ -1,4 +1,6 @@
 # coding: utf-8
+# adapted from https://github.com/r9y9/tacotron_pytorch
+
 import torch
 from torch import nn
 
diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py
index 787394b5..fdccf7f1 100644
--- a/TTS/tts/utils/text/__init__.py
+++ b/TTS/tts/utils/text/__init__.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+# adapted from https://github.com/keithito/tacotron
 
 import re
 import unicodedata
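
Note for reviewers: the docstrings added above for `get_optimizer()`, `get_lr()`,
`get_scheduler()`, and `get_criterion()` all describe the same model-first lookup:
the `Trainer` uses the model's hook when it exists and only then falls back to the
config (the criterion has no config fallback). A minimal sketch of a model
implementing all four hooks is below; the class name, the placeholder layer, and
the concrete optimizer/scheduler/criterion choices are illustrative assumptions,
not part of this patch.

    import torch
    from torch import nn

    class DummyTTSModel(nn.Module):
        """Illustrative model defining the hooks the `Trainer` looks up."""

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(80, 80)  # placeholder layer, not a real TTS model

        def forward(self, x):
            return self.layer(x)

        def get_lr(self):
            # A single float; return a list for multiple optimizers (e.g. GANs).
            return 1e-3

        def get_optimizer(self):
            # Returning None here would make the Trainer fall back to the config.
            return torch.optim.Adam(self.parameters(), lr=self.get_lr())

        def get_scheduler(self, optimizer):
            # Called with the optimizer(s) built above.
            return torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)

        def get_criterion(self):
            # Mandatory: `Trainer.get_criterion()` calls this unconditionally.
            return nn.L1Loss()

With such a model, the static helpers resolve everything from the hooks, e.g.
`Trainer.get_optimizer(model, config)` returns the Adam instance without reading
the optimizer parameters from the config.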