mirror of https://github.com/coqui-ai/TTS.git
Docstrings for `Trainer`
This commit is contained in:
parent
65958eaa41
commit
ae6405bb76
|
@ -4,6 +4,7 @@ import glob
|
|||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
|
@ -40,13 +41,21 @@ from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_availa
|
|||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
||||
|
||||
if platform.system() != "Windows":
|
||||
# https://github.com/pytorch/pytorch/issues/973
|
||||
import resource
|
||||
|
||||
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
|
||||
|
||||
if is_apex_available():
|
||||
from apex import amp
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArgs(Coqpit):
|
||||
"""Trainer arguments"""
|
||||
"""Trainer arguments to be defined externally. It helps integrating the `Trainer` with the higher level APIs and
|
||||
set the values for distributed training."""
|
||||
|
||||
continue_path: str = field(
|
||||
default="",
|
||||
|
@ -360,7 +369,7 @@ class Trainer:
|
|||
return self._get_loader(self.model, self.config, ap, True, data_items, verbose, self.num_gpus)
|
||||
|
||||
def format_batch(self, batch: List) -> Dict:
|
||||
"""Format dataloader ouput and return a batch.
|
||||
"""Format the dataloader output and return a batch.
|
||||
|
||||
Args:
|
||||
batch (List): Batch returned by the dataloader.
|
||||
|
@ -633,7 +642,7 @@ class Trainer:
|
|||
return outputs, loss_dict
|
||||
|
||||
def train_epoch(self) -> None:
|
||||
"""Main entry point for training. Run training on the whole training samples."""
|
||||
"""Main entry point for the training loop. Run training on the all training samples."""
|
||||
self.train_loader = self.get_train_dataloader(
|
||||
self.ap,
|
||||
self.data_train,
|
||||
|
@ -682,6 +691,15 @@ class Trainer:
|
|||
return model.eval_step(*input_args)
|
||||
|
||||
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
|
||||
"""Perform a evaluation step on a batch of inputs and log the process.
|
||||
|
||||
Args:
|
||||
batch (Dict): Input batch.
|
||||
step (int): Current step number in this epoch.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: Model outputs and losses.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
outputs_per_optimizer = None
|
||||
loss_dict = {}
|
||||
|
@ -708,6 +726,7 @@ class Trainer:
|
|||
return outputs, loss_dict
|
||||
|
||||
def eval_epoch(self) -> None:
|
||||
"""Main entry point for the evaluation loop. Run evaluation on the all validation samples."""
|
||||
self.eval_loader = (
|
||||
self.get_eval_dataloader(
|
||||
self.ap,
|
||||
|
@ -743,7 +762,7 @@ class Trainer:
|
|||
|
||||
def test_run(self) -> None:
|
||||
"""Run test and log the results. Test run must be defined by the model.
|
||||
Model must return figures and audios to be logged by the Tensorboard logger."""
|
||||
Model must return figures and audios to be logged by the Tensorboard."""
|
||||
if hasattr(self.model, "test_run"):
|
||||
if hasattr(self.eval_loader.load_test_samples):
|
||||
samples = self.eval_loader.load_test_samples(1)
|
||||
|
@ -820,11 +839,22 @@ class Trainer:
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_apex_available():
|
||||
def _is_apex_available() -> bool:
|
||||
"""Check if Nvidia's APEX is available."""
|
||||
return importlib.util.find_spec("apex") is not None
|
||||
|
||||
@staticmethod
|
||||
def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]:
|
||||
"""Receive the optimizer from the model if model implements `get_optimizer()` else
|
||||
check the optimizer parameters in the config and try initiating the optimizer.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Training model.
|
||||
config (Coqpit): Training configuration.
|
||||
|
||||
Returns:
|
||||
Union[torch.optim.Optimizer, List]: A optimizer or a list of optimizers. GAN models define a list.
|
||||
"""
|
||||
if hasattr(model, "get_optimizer"):
|
||||
optimizer = model.get_optimizer()
|
||||
if optimizer is None:
|
||||
|
@ -835,6 +865,16 @@ class Trainer:
|
|||
|
||||
@staticmethod
|
||||
def get_lr(model: nn.Module, config: Coqpit) -> Union[float, List[float]]:
|
||||
"""Set the initial learning rate by the model if model implements `get_lr()` else try setting the learning rate
|
||||
fromthe config.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Training model.
|
||||
config (Coqpit): Training configuration.
|
||||
|
||||
Returns:
|
||||
Union[float, List[float]]: A single learning rate or a list of learning rates, one for each optimzier.
|
||||
"""
|
||||
lr = None
|
||||
if hasattr(model, "get_lr"):
|
||||
lr = model.get_lr()
|
||||
|
@ -846,6 +886,16 @@ class Trainer:
|
|||
def get_scheduler(
|
||||
model: nn.Module, config: Coqpit, optimizer: Union[torch.optim.Optimizer, List]
|
||||
) -> Union[torch.optim.lr_scheduler._LRScheduler, List]: # pylint: disable=protected-access
|
||||
"""Receive the scheduler from the model if model implements `get_scheduler()` else
|
||||
check the config and try initiating the scheduler.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Training model.
|
||||
config (Coqpit): Training configuration.
|
||||
|
||||
Returns:
|
||||
Union[torch.optim.Optimizer, List]: A scheduler or a list of schedulers, one for each optimizer.
|
||||
"""
|
||||
scheduler = None
|
||||
if hasattr(model, "get_scheduler"):
|
||||
scheduler = model.get_scheduler(optimizer)
|
||||
|
@ -857,6 +907,14 @@ class Trainer:
|
|||
|
||||
@staticmethod
|
||||
def get_criterion(model: nn.Module) -> nn.Module:
|
||||
"""Receive the criterion from the model. Model must implement `get_criterion()`.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Training model.
|
||||
|
||||
Returns:
|
||||
nn.Module: Criterion layer.
|
||||
"""
|
||||
criterion = None
|
||||
criterion = model.get_criterion()
|
||||
return criterion
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# coding: utf-8
|
||||
# adapted from https://github.com/r9y9/tacotron_pytorch
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# adapted from https://github.com/keithito/tacotron
|
||||
|
||||
import re
|
||||
import unicodedata
|
||||
|
|
Loading…
Reference in New Issue