mirror of https://github.com/coqui-ai/TTS.git
Docstrings for `Trainer`
This commit is contained in:
parent
65958eaa41
commit
ae6405bb76
|
@ -4,6 +4,7 @@ import glob
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
@ -40,13 +41,21 @@ from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_availa
|
||||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||||
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
||||||
|
|
||||||
|
if platform.system() != "Windows":
|
||||||
|
# https://github.com/pytorch/pytorch/issues/973
|
||||||
|
import resource
|
||||||
|
|
||||||
|
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||||
|
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
|
||||||
|
|
||||||
if is_apex_available():
|
if is_apex_available():
|
||||||
from apex import amp
|
from apex import amp
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainingArgs(Coqpit):
|
class TrainingArgs(Coqpit):
|
||||||
"""Trainer arguments"""
|
"""Trainer arguments to be defined externally. It helps integrating the `Trainer` with the higher level APIs and
|
||||||
|
set the values for distributed training."""
|
||||||
|
|
||||||
continue_path: str = field(
|
continue_path: str = field(
|
||||||
default="",
|
default="",
|
||||||
|
@ -360,7 +369,7 @@ class Trainer:
|
||||||
return self._get_loader(self.model, self.config, ap, True, data_items, verbose, self.num_gpus)
|
return self._get_loader(self.model, self.config, ap, True, data_items, verbose, self.num_gpus)
|
||||||
|
|
||||||
def format_batch(self, batch: List) -> Dict:
|
def format_batch(self, batch: List) -> Dict:
|
||||||
"""Format dataloader ouput and return a batch.
|
"""Format the dataloader output and return a batch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch (List): Batch returned by the dataloader.
|
batch (List): Batch returned by the dataloader.
|
||||||
|
@ -633,7 +642,7 @@ class Trainer:
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def train_epoch(self) -> None:
|
def train_epoch(self) -> None:
|
||||||
"""Main entry point for training. Run training on the whole training samples."""
|
"""Main entry point for the training loop. Run training on the all training samples."""
|
||||||
self.train_loader = self.get_train_dataloader(
|
self.train_loader = self.get_train_dataloader(
|
||||||
self.ap,
|
self.ap,
|
||||||
self.data_train,
|
self.data_train,
|
||||||
|
@ -682,6 +691,15 @@ class Trainer:
|
||||||
return model.eval_step(*input_args)
|
return model.eval_step(*input_args)
|
||||||
|
|
||||||
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
|
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
|
||||||
|
"""Perform a evaluation step on a batch of inputs and log the process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (Dict): Input batch.
|
||||||
|
step (int): Current step number in this epoch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Dict, Dict]: Model outputs and losses.
|
||||||
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs_per_optimizer = None
|
outputs_per_optimizer = None
|
||||||
loss_dict = {}
|
loss_dict = {}
|
||||||
|
@ -708,6 +726,7 @@ class Trainer:
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def eval_epoch(self) -> None:
|
def eval_epoch(self) -> None:
|
||||||
|
"""Main entry point for the evaluation loop. Run evaluation on the all validation samples."""
|
||||||
self.eval_loader = (
|
self.eval_loader = (
|
||||||
self.get_eval_dataloader(
|
self.get_eval_dataloader(
|
||||||
self.ap,
|
self.ap,
|
||||||
|
@ -743,7 +762,7 @@ class Trainer:
|
||||||
|
|
||||||
def test_run(self) -> None:
|
def test_run(self) -> None:
|
||||||
"""Run test and log the results. Test run must be defined by the model.
|
"""Run test and log the results. Test run must be defined by the model.
|
||||||
Model must return figures and audios to be logged by the Tensorboard logger."""
|
Model must return figures and audios to be logged by the Tensorboard."""
|
||||||
if hasattr(self.model, "test_run"):
|
if hasattr(self.model, "test_run"):
|
||||||
if hasattr(self.eval_loader.load_test_samples):
|
if hasattr(self.eval_loader.load_test_samples):
|
||||||
samples = self.eval_loader.load_test_samples(1)
|
samples = self.eval_loader.load_test_samples(1)
|
||||||
|
@ -820,11 +839,22 @@ class Trainer:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_apex_available():
|
def _is_apex_available() -> bool:
|
||||||
|
"""Check if Nvidia's APEX is available."""
|
||||||
return importlib.util.find_spec("apex") is not None
|
return importlib.util.find_spec("apex") is not None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]:
|
def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]:
|
||||||
|
"""Receive the optimizer from the model if model implements `get_optimizer()` else
|
||||||
|
check the optimizer parameters in the config and try initiating the optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): Training model.
|
||||||
|
config (Coqpit): Training configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[torch.optim.Optimizer, List]: A optimizer or a list of optimizers. GAN models define a list.
|
||||||
|
"""
|
||||||
if hasattr(model, "get_optimizer"):
|
if hasattr(model, "get_optimizer"):
|
||||||
optimizer = model.get_optimizer()
|
optimizer = model.get_optimizer()
|
||||||
if optimizer is None:
|
if optimizer is None:
|
||||||
|
@ -835,6 +865,16 @@ class Trainer:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_lr(model: nn.Module, config: Coqpit) -> Union[float, List[float]]:
|
def get_lr(model: nn.Module, config: Coqpit) -> Union[float, List[float]]:
|
||||||
|
"""Set the initial learning rate by the model if model implements `get_lr()` else try setting the learning rate
|
||||||
|
fromthe config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): Training model.
|
||||||
|
config (Coqpit): Training configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[float, List[float]]: A single learning rate or a list of learning rates, one for each optimzier.
|
||||||
|
"""
|
||||||
lr = None
|
lr = None
|
||||||
if hasattr(model, "get_lr"):
|
if hasattr(model, "get_lr"):
|
||||||
lr = model.get_lr()
|
lr = model.get_lr()
|
||||||
|
@ -846,6 +886,16 @@ class Trainer:
|
||||||
def get_scheduler(
|
def get_scheduler(
|
||||||
model: nn.Module, config: Coqpit, optimizer: Union[torch.optim.Optimizer, List]
|
model: nn.Module, config: Coqpit, optimizer: Union[torch.optim.Optimizer, List]
|
||||||
) -> Union[torch.optim.lr_scheduler._LRScheduler, List]: # pylint: disable=protected-access
|
) -> Union[torch.optim.lr_scheduler._LRScheduler, List]: # pylint: disable=protected-access
|
||||||
|
"""Receive the scheduler from the model if model implements `get_scheduler()` else
|
||||||
|
check the config and try initiating the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): Training model.
|
||||||
|
config (Coqpit): Training configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[torch.optim.Optimizer, List]: A scheduler or a list of schedulers, one for each optimizer.
|
||||||
|
"""
|
||||||
scheduler = None
|
scheduler = None
|
||||||
if hasattr(model, "get_scheduler"):
|
if hasattr(model, "get_scheduler"):
|
||||||
scheduler = model.get_scheduler(optimizer)
|
scheduler = model.get_scheduler(optimizer)
|
||||||
|
@ -857,6 +907,14 @@ class Trainer:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_criterion(model: nn.Module) -> nn.Module:
|
def get_criterion(model: nn.Module) -> nn.Module:
|
||||||
|
"""Receive the criterion from the model. Model must implement `get_criterion()`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): Training model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: Criterion layer.
|
||||||
|
"""
|
||||||
criterion = None
|
criterion = None
|
||||||
criterion = model.get_criterion()
|
criterion = model.get_criterion()
|
||||||
return criterion
|
return criterion
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
|
# adapted from https://github.com/r9y9/tacotron_pytorch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
# adapted from https://github.com/keithito/tacotron
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
|
Loading…
Reference in New Issue