Docstrings for `Trainer`

Eren Gölge 2021-06-28 16:53:57 +02:00
parent 65958eaa41
commit ae6405bb76
3 changed files with 66 additions and 5 deletions

View File

@@ -4,6 +4,7 @@ import glob
import importlib
import logging
import os
import platform
import re
import sys
import time
@@ -40,13 +41,21 @@ from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model
if platform.system() != "Windows":
# Raise the soft limit on open file descriptors to avoid "too many open files"
# errors with multi-worker dataloaders. https://github.com/pytorch/pytorch/issues/973
import resource
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
if is_apex_available():
from apex import amp
@dataclass
class TrainingArgs(Coqpit):
"""Trainer arguments"""
"""Trainer arguments to be defined externally. It helps integrating the `Trainer` with the higher level APIs and
set the values for distributed training."""
continue_path: str = field(
default="",
@@ -360,7 +369,7 @@ class Trainer:
return self._get_loader(self.model, self.config, ap, True, data_items, verbose, self.num_gpus)
def format_batch(self, batch: List) -> Dict:
"""Format dataloader ouput and return a batch.
"""Format the dataloader output and return a batch.
Args:
batch (List): Batch returned by the dataloader.
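As an illustration of what this formatting step typically produces, a sketch under an assumed positional batch layout (the real mapping is model-specific and not part of this commit):

# Assumed positional layout of `batch`; keys are illustrative only.
def format_batch(self, batch):
    text, text_lengths, mel, mel_lengths = batch[:4]
    return {
        "text_input": text,
        "text_lengths": text_lengths,
        "mel_input": mel,
        "mel_lengths": mel_lengths,
    }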
@@ -633,7 +642,7 @@ class Trainer:
return outputs, loss_dict
def train_epoch(self) -> None:
"""Main entry point for training. Run training on the whole training samples."""
"""Main entry point for the training loop. Run training on the all training samples."""
self.train_loader = self.get_train_dataloader(
self.ap,
self.data_train,
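For orientation, the outer loop is expected to call these entry points once per epoch; the sketch below assumes `config.epochs` and `config.run_eval` fields and is not code from this diff:

# Assumed outer loop; method names mirror the Trainer methods documented here.
for epoch in range(config.epochs):
    trainer.train_epoch()
    if config.run_eval:
        trainer.eval_epoch()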
@@ -682,6 +691,15 @@ class Trainer:
return model.eval_step(*input_args)
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
"""Perform a evaluation step on a batch of inputs and log the process.
Args:
batch (Dict): Input batch.
step (int): Current step number in this epoch.
Returns:
Tuple[Dict, Dict]: Model outputs and losses.
"""
with torch.no_grad():
outputs_per_optimizer = None
loss_dict = {}
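Since `_model_eval_step` above delegates to `model.eval_step(*input_args)`, a model-side implementation could look like this sketch; the batch keys, output keys, and the criterion argument are assumptions:

from typing import Dict, Tuple

def eval_step(self, batch: Dict, criterion) -> Tuple[Dict, Dict]:
    # The Trainer already wraps evaluation in torch.no_grad(), so no
    # gradient bookkeeping is needed here.
    outputs = self.forward(batch["text_input"], batch["mel_input"])
    loss_dict = {"loss": criterion(outputs["model_outputs"], batch["mel_input"])}
    return outputs, loss_dict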
@@ -708,6 +726,7 @@ class Trainer:
return outputs, loss_dict
def eval_epoch(self) -> None:
"""Main entry point for the evaluation loop. Run evaluation on the all validation samples."""
self.eval_loader = (
self.get_eval_dataloader(
self.ap,
@@ -743,7 +762,7 @@ class Trainer:
def test_run(self) -> None:
"""Run test and log the results. Test run must be defined by the model.
-Model must return figures and audios to be logged by the Tensorboard logger."""
+Model must return figures and audios to be logged by Tensorboard."""
if hasattr(self.model, "test_run"):
if hasattr(self.eval_loader, "load_test_samples"):
samples = self.eval_loader.load_test_samples(1)
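A model-side `test_run` then supplies the figures and audios mentioned above; in the sketch below, `self.inference` and the output keys are hypothetical:

def test_run(self, samples):
    # Return dicts of figures and audios keyed by name, as required for
    # Tensorboard logging; keys and shapes here are assumptions.
    figures, audios = {}, {}
    for idx, sample in enumerate(samples):
        outputs = self.inference(sample["text_input"])  # hypothetical API
        figures[f"test_{idx}/spectrogram"] = outputs["model_outputs"]
        audios[f"test_{idx}/audio"] = outputs["waveform"]
    return figures, audios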
@@ -820,11 +839,22 @@ class Trainer:
)
@staticmethod
-def _is_apex_available():
+def _is_apex_available() -> bool:
+    """Check if Nvidia's APEX is available."""
return importlib.util.find_spec("apex") is not None
@staticmethod
def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]:
"""Receive the optimizer from the model if model implements `get_optimizer()` else
check the optimizer parameters in the config and try initiating the optimizer.
Args:
model (nn.Module): Training model.
config (Coqpit): Training configuration.
Returns:
Union[torch.optim.Optimizer, List]: An optimizer or a list of optimizers. GAN models define a list.
"""
if hasattr(model, "get_optimizer"):
optimizer = model.get_optimizer()
if optimizer is None:
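On the model side, the override might look like the sketch below; the model classes, optimizers, and hyperparameters are illustrative only, and the list form matches the GAN case in the docstring:

import torch
from torch import nn

class SingleOptModel(nn.Module):  # hypothetical model
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(80, 80)

    def get_optimizer(self):
        # Illustrative optimizer and learning rate.
        return torch.optim.Adam(self.parameters(), lr=1e-3)

class GANModel(nn.Module):  # hypothetical; GAN models return a list
    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(80, 80)
        self.discriminator = nn.Linear(80, 1)

    def get_optimizer(self):
        # One optimizer per network; the list order is assumed to match
        # the order expected by get_lr() and get_scheduler().
        return [
            torch.optim.AdamW(self.generator.parameters(), lr=1e-4),
            torch.optim.AdamW(self.discriminator.parameters(), lr=2e-4),
        ]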
@@ -835,6 +865,16 @@ class Trainer:
@staticmethod
def get_lr(model: nn.Module, config: Coqpit) -> Union[float, List[float]]:
"""Set the initial learning rate by the model if model implements `get_lr()` else try setting the learning rate
fromthe config.
Args:
model (nn.Module): Training model.
config (Coqpit): Training configuration.
Returns:
Union[float, List[float]]: A single learning rate or a list of learning rates, one for each optimizer.
"""
lr = None
if hasattr(model, "get_lr"):
lr = model.get_lr()
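The matching `get_lr` override returns either a single rate or one per optimizer, in the same order; the values below are illustrative only:

def get_lr(self):
    # One learning rate per optimizer returned by get_optimizer().
    return [1e-4, 2e-4]  # illustrative values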
@@ -846,6 +886,16 @@ class Trainer:
def get_scheduler(
model: nn.Module, config: Coqpit, optimizer: Union[torch.optim.Optimizer, List]
) -> Union[torch.optim.lr_scheduler._LRScheduler, List]: # pylint: disable=protected-access
"""Receive the scheduler from the model if model implements `get_scheduler()` else
check the config and try initiating the scheduler.
Args:
model (nn.Module): Training model.
config (Coqpit): Training configuration.
Returns:
Union[torch.optim.lr_scheduler._LRScheduler, List]: A scheduler or a list of schedulers, one for each optimizer.
"""
scheduler = None
if hasattr(model, "get_scheduler"):
scheduler = model.get_scheduler(optimizer)
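A model override receives the already-built optimizer(s), as shown above; a sketch in which `ExponentialLR` and its `gamma` are illustrative choices:

import torch

def get_scheduler(self, optimizer):
    # One scheduler per optimizer keeps the two lists aligned.
    if isinstance(optimizer, list):
        return [torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999) for opt in optimizer]
    return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)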
@@ -857,6 +907,14 @@ class Trainer:
@staticmethod
def get_criterion(model: nn.Module) -> nn.Module:
"""Receive the criterion from the model. Model must implement `get_criterion()`.
Args:
model (nn.Module): Training model.
Returns:
nn.Module: Criterion layer.
"""
criterion = model.get_criterion()
return criterion
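Since `get_criterion()` is mandatory here, a minimal model-side override just returns a loss module; `L1Loss` is only an example choice:

from torch import nn

def get_criterion(self):
    # Any nn.Module that computes the training loss works here.
    return nn.L1Loss()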

View File

@@ -1,4 +1,6 @@
# coding: utf-8
# adapted from https://github.com/r9y9/tacotron_pytorch
import torch
from torch import nn

View File

@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
# adapted from https://github.com/keithito/tacotron
import re
import unicodedata