Refactor `trainer.py` for v2

Eren Gölge 2021-09-30 14:16:34 +00:00
parent 7f388f26e3
commit 8ada870a57
5 changed files with 388 additions and 291 deletions

TTS/trainer.py

@@ -4,16 +4,14 @@ import importlib
 import multiprocessing
 import os
 import platform
-import re
 import sys
 import time
 import traceback
 from argparse import Namespace
 from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Union
-from urllib.parse import urlparse
+from inspect import signature
+from typing import Callable, Dict, List, Tuple, Union

-import fsspec
 import torch
 import torch.distributed as dist
 from coqpit import Coqpit
@@ -21,11 +19,7 @@ from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader

-from TTS.config import load_config, register_config
-from TTS.tts.datasets import load_meta_data
-from TTS.tts.models import setup_model as setup_tts_model
-from TTS.tts.utils.text.symbols import parse_symbols
-from TTS.utils.audio import AudioProcessor
+from TTS.stt.datasets.tokenizer import Tokenizer
 from TTS.utils.callbacks import TrainerCallback
 from TTS.utils.distribute import init_distributed
 from TTS.utils.generic_utils import (
@@ -39,9 +33,13 @@ from TTS.utils.generic_utils import (
 )
 from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
 from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger
-from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
-from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
-from TTS.vocoder.models import setup_model as setup_vocoder_model
+from TTS.utils.trainer_utils import (
+    get_last_checkpoint,
+    get_optimizer,
+    get_scheduler,
+    is_apex_available,
+    setup_torch_training_env,
+)

 multiprocessing.set_start_method("fork")
@@ -80,6 +78,9 @@ class TrainingArgs(Coqpit):
             "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
         },
     )
+    skip_train_epoch: bool = field(
+        default=False, metadata={"help": "Run only evaluation iteration. Useful for debugging."}
+    )
     config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
     rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
     group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
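Note: the new flag is an ordinary Coqpit field, so it can be set in code or through the auto-generated command-line parser; the exact shell syntax depends on Coqpit's bool mapping. A sketch (the script name is hypothetical):

    # in code
    args = TrainingArgs(skip_train_epoch=True)

    # from the shell, assuming Coqpit's default argparse mapping
    # python train_script.py --skip_train_epoch true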
@@ -98,7 +99,14 @@ class Trainer:
         c_logger: ConsoleLogger = None,
         dashboard_logger: Union[TensorboardLogger, WandbLogger] = None,
         model: nn.Module = None,
+        get_model: Callable = None,
+        get_data_samples: Callable = None,
+        train_samples: List = None,
+        eval_samples: List = None,
+        tokenizer: Tokenizer = None,
         cudnn_benchmark: bool = False,
+        training_assets: Dict = {},
+        parse_command_line_args: bool = True,
     ) -> None:
         """Simple yet powerful 🐸💬 TTS trainer for PyTorch. It can train all the available `tts` and `vocoder` models
         or easily be customized.
@@ -127,24 +135,44 @@ class Trainer:
             model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
                 initializes a model from the provided config. Defaults to None.

+            get_model (Callable):
+                A function that returns a model. It is used to initialize the model when `model` is not provided.
+                It either takes the config as the only argument or does not take any argument.
+                Defaults to None.
+
+            get_data_samples (Callable):
+                A function that returns a list of training and evaluation samples. Used if `train_samples` and
+                `eval_samples` are None. Defaults to None.
+
+            train_samples (List):
+                A list of training samples used by the model's `get_data_loader` to init the `dataset` and the
+                `data_loader`. Defaults to None.
+
+            eval_samples (List):
+                A list of evaluation samples used by the model's `get_data_loader` to init the `dataset` and the
+                `data_loader`. Defaults to None.
+
             cudnn_benchmark (bool): enable/disable PyTorch cudnn benchmarking. It is better to disable if the model input
                 length is changing batch to batch along the training.

+            training_assets (Dict):
+                A dictionary of assets to be used at training and passed to the model's ```train_log(), eval_log(), get_data_loader()```
+                during training. It can include `AudioProcessor` or/and `Tokenizer`. Defaults to {}.
+
+            parse_command_line_args (bool):
+                If true, parse command-line arguments and update `TrainingArgs` and model `config` values. Set it
+                to false if you parse the arguments yourself. Defaults to True.
+
         Examples:
-            Running trainer on a model.
+            Running trainer with HifiGAN model.
             >>> args = TrainingArgs(...)
             >>> config = HifiganConfig(...)
             >>> model = GANModel(config)
-            >>> trainer = Trainer(args, config, output_path, model=model)
-            >>> trainer.fit()
-
-            Running trainer on a config.
-            >>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,)
-            >>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-            >>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
+            >>> ap = AudioProcessor(**config.audio)
+            >>> assets = {"audio_processor": ap}
+            >>> trainer = Trainer(args, config, output_path, model=model, training_assets=assets)
             >>> trainer.fit()

         TODO:
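Note: besides the `model=` path shown in the docstring above, the refactored entry point can build everything from callables. A sketch of that usage under assumed names (`MyModel` and `load_samples` are illustrative stand-ins, not part of this diff):

    def get_model(config):
        # one-parameter form: receives the Coqpit config
        return MyModel(config)

    def get_data_samples(config):
        # must return a (train_samples, eval_samples) pair
        return load_samples(config)

    trainer = Trainer(
        TrainingArgs(),
        config,
        output_path,
        get_model=get_model,
        get_data_samples=get_data_samples,
    )
    trainer.fit()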
@@ -154,20 +182,33 @@ class Trainer:
         - Profiler integration.
         - Overfitting to a batch.
         - TPU training
+        - NOTE: Consider moving `training_assets` to the model implementation.
         """
-        if config is None:
-            # parse config from console arguments
-            config, output_path, _, c_logger, dashboard_logger = process_args(args)
+        if parse_command_line_args:
+            # parse command-line arguments for TrainingArgs()
+            args, coqpit_overrides = self.parse_argv(args)
+
+            # get ready for training and parse command-line arguments for the model config
+            config = self.init_training(args, coqpit_overrides, config)
+
+        # define the experiment path and create the folder
+        output_path = get_experiment_folder_path(config.output_path, config.run_name)
+        os.makedirs(output_path, exist_ok=True)
+
+        # copy training assets to the output folder
+        copy_model_files(config, output_path, new_fields=None)

+        # init class members
         self.args = args
         self.config = config
         self.output_path = output_path
         self.config.output_log_path = output_path
+        self.training_assets = training_assets

         # setup logging
         log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
         self._setup_logger_config(log_file)
+        time.sleep(1.0)  # wait for the logger to be ready

         # set and initialize PyTorch runtime
         self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
@@ -196,33 +237,30 @@ class Trainer:
         self.use_apex = self._is_apex_available()
         self.use_amp_scaler = self.config.mixed_precision and self.use_cuda

-        # init audio processor
-        self.ap = AudioProcessor(**self.config.audio.to_dict())
+        # init tokenizer
+        self.tokenizer = tokenizer

         # load data samples
-        # TODO: refactor this
-        if "datasets" in self.config:
-            # load data for `tts` models
-            self.data_train, self.data_eval = load_meta_data(self.config.datasets)
-        elif self.config.feature_path is not None:
-            # load pre-computed features for `vocoder` models
-            print(f" > Loading features from: {self.config.feature_path}")
-            self.data_eval, self.data_train = load_wav_feat_data(
-                self.config.data_path, self.config.feature_path, self.config.eval_split_size
-            )
-        else:
-            # load data for `vocoder` models
-            self.data_eval, self.data_train = load_wav_data(self.config.data_path, self.config.eval_split_size)
+        if train_samples is None and get_data_samples is None:
+            raise ValueError("[!] `train_samples` and `get_data_samples` cannot both be None.")
+        if train_samples is not None:
+            self.train_samples = train_samples
+            self.eval_samples = eval_samples
+        else:
+            self.train_samples, self.eval_samples = self.run_get_data_samples(config, get_data_samples)

         # init TTS model
+        if model is None and get_model is None:
+            raise ValueError("[!] `model` and `get_model` cannot both be None.")
         if model is not None:
             self.model = model
         else:
-            self.model = self.get_model(self.config)
+            self.model = self.run_get_model(self.config, get_model)

+        # TODO: out!
         # init multispeaker settings of the model
         if hasattr(self.model, "init_multispeaker"):
-            self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
+            self.model.init_multispeaker(self.config, self.train_samples + self.eval_samples)

         # setup criterion
         self.criterion = self.get_criterion(self.model)
@@ -247,7 +285,7 @@ class Trainer:
         # setup optimizer
         self.optimizer = self.get_optimizer(self.model, self.config)

-        # callback
+        # CALLBACK
         self.callbacks = TrainerCallback(self)
         self.callbacks.on_init_start()
@@ -280,7 +318,7 @@ class Trainer:
         else:
             self.scheduler.last_epoch = self.restore_step

-        # DISTRUBUTED
+        # DISTRIBUTED
         if self.num_gpus > 1:
             self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
@@ -291,8 +329,54 @@ class Trainer:
         self.callbacks.on_init_end()

     @staticmethod
-    def get_model(config: Coqpit) -> nn.Module:
-        """Initialize model from config.
+    def parse_argv(args: Union[Coqpit, List]):
+        """Parse command line arguments to init or override `TrainingArgs()`."""
+        if isinstance(args, Coqpit):
+            parser = args.init_argparse(arg_prefix="")
+        else:
+            train_config = TrainingArgs()
+            parser = train_config.init_argparse(arg_prefix="")
+        training_args, coqpit_overrides = parser.parse_known_args()
+        args.parse_args(training_args)
+        return args, coqpit_overrides
+
+    def init_training(self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None):
+        """Initialize training and update model configs from command line arguments.
+
+        Args:
+            args (argparse.Namespace or dict like): Parsed input arguments.
+            coqpit_overrides (argparse.Namespace or dict like): Parsed config overriding arguments.
+            config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
+
+        Returns:
+            c (TTS.utils.io.AttrDict): Config parameters.
+        """
+        # set arguments for continuing training
+        if args.continue_path:
+            experiment_path = args.continue_path
+            args.config_path = os.path.join(args.continue_path, "config.json")
+            args.restore_path, best_model = get_last_checkpoint(args.continue_path)
+            if not args.best_path:
+                args.best_path = best_model
+
+        # override config values from command-line args
+        # TODO: Maybe it is better to do it outside
+        if len(coqpit_overrides) > 0:
+            config.parse_known_args(coqpit_overrides, relaxed_parser=True)
+        experiment_path = args.continue_path
+
+        # update the config.json fields and copy it to the output folder
+        if args.rank == 0:
+            new_fields = {}
+            if args.restore_path:
+                new_fields["restore_path"] = args.restore_path
+            new_fields["github_branch"] = get_git_branch()
+            copy_model_files(config, experiment_path, new_fields)
+        return config
+
+    @staticmethod
+    def run_get_model(config: Coqpit, get_model: Callable) -> nn.Module:
+        """Run the `get_model` function and return the model.

         Args:
             config (Coqpit): Model config.
@@ -300,12 +384,23 @@ class Trainer:
         Returns:
             nn.Module: initialized model.
         """
-        try:
-            model = setup_vocoder_model(config)
-        except ModuleNotFoundError:
-            model = setup_tts_model(config)
+        if len(signature(get_model).parameters) == 1:
+            model = get_model(config)
+        else:
+            model = get_model()
         return model

+    @staticmethod
+    def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> Tuple[List, List]:
+        if isinstance(get_data_samples, Callable):
+            if len(signature(get_data_samples).parameters) == 1:
+                train_samples, eval_samples = get_data_samples(config)
+            else:
+                train_samples, eval_samples = get_data_samples()
+            return train_samples, eval_samples
+        else:
+            return None, None
+
     def restore_model(
         self,
         config: Coqpit,
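For reference, the arity dispatch that `run_get_model` and `run_get_data_samples` share, shown in isolation; `inspect.signature(fn).parameters` is the standard way to count a callable's declared parameters:

    from inspect import signature

    def run_factory(factory, config):
        # call `factory(config)` if it declares exactly one parameter, else `factory()`
        if len(signature(factory).parameters) == 1:
            return factory(config)
        return factory()

    run_factory(lambda cfg: cfg["lr"], {"lr": 1e-3})  # -> 0.001
    run_factory(lambda: "defaults", {"lr": 1e-3})     # -> 'defaults'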
@@ -366,11 +461,15 @@ class Trainer:
             torch.cuda.empty_cache()
         return model, optimizer, scaler, restore_step

+    #########################
+    # DATA LOADING FUNCTIONS
+    #########################
+
     def _get_loader(
         self,
         model: nn.Module,
         config: Coqpit,
-        ap: AudioProcessor,
+        assets: Dict,
         is_eval: bool,
         data_items: List,
         verbose: bool,
@@ -379,14 +478,14 @@ class Trainer:
         if num_gpus > 1:
             if hasattr(model.module, "get_data_loader"):
                 loader = model.module.get_data_loader(
-                    config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
+                    config, assets, is_eval, data_items, verbose, num_gpus, self.args.rank
                 )
         else:
             if hasattr(model, "get_data_loader"):
-                loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
+                loader = model.get_data_loader(config, assets, is_eval, data_items, verbose, num_gpus)
         return loader

-    def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
+    def get_train_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader:
         """Initialize and return a training data loader.

         Args:
@@ -397,10 +496,10 @@ class Trainer:
         Returns:
             DataLoader: Initialized training data loader.
         """
-        return self._get_loader(self.model, self.config, ap, False, data_items, verbose, self.num_gpus)
+        return self._get_loader(self.model, self.config, training_assets, False, data_items, verbose, self.num_gpus)

-    def get_eval_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
-        return self._get_loader(self.model, self.config, ap, True, data_items, verbose, self.num_gpus)
+    def get_eval_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader:
+        return self._get_loader(self.model, self.config, training_assets, True, data_items, verbose, self.num_gpus)

     def format_batch(self, batch: List) -> Dict:
         """Format the dataloader output and return a batch.
@@ -420,6 +519,10 @@ class Trainer:
                 batch[k] = to_cuda(v)
         return batch

+    ######################
+    # TRAIN FUNCTIONS
+    ######################
+
     @staticmethod
     def master_params(optimizer: torch.optim.Optimizer):
         """Generator over parameters owned by the optimizer.
@@ -567,24 +670,6 @@ class Trainer:
             loss_dict["grad_norm"] = grad_norm
         return outputs, loss_dict, step_time

-    @staticmethod
-    def _detach_loss_dict(loss_dict: Dict) -> Dict:
-        """Detach loss values from autograd.
-
-        Args:
-            loss_dict (Dict): losses.
-
-        Returns:
-            Dict: losses detached from autograd.
-        """
-        loss_dict_detached = {}
-        for key, value in loss_dict.items():
-            if isinstance(value, (int, float)):
-                loss_dict_detached[key] = value
-            else:
-                loss_dict_detached[key] = value.item()
-        return loss_dict_detached
-
     def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
         """Perform a training step on a batch of inputs and log the process.
@@ -700,15 +785,14 @@ class Trainer:
                     self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)

                 # training visualizations
-                figures, audios = None, None
                 if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"):
-                    figures, audios = self.model.module.train_log(self.ap, batch, outputs)
+                    self.model.module.train_log(
+                        batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
+                    )
                 elif hasattr(self.model, "train_log"):
-                    figures, audios = self.model.train_log(self.ap, batch, outputs)
-                if figures is not None:
-                    self.dashboard_logger.train_figures(self.total_steps_done, figures)
-                if audios is not None:
-                    self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)
+                    self.model.train_log(
+                        batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
+                    )

                 self.dashboard_logger.flush()
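Under the new signature, a model pulls what it needs out of `assets` and writes to the logger itself. A sketch of a conforming `train_log` (the subclass and `_create_logs` helper are hypothetical):

    class MyTTSModel(BaseTTS):  # illustrative subclass
        def train_log(self, batch, outputs, logger, assets, steps):
            ap = assets["audio_processor"]
            figures, audios = self._create_logs(batch, outputs, ap)  # hypothetical helper
            logger.train_figures(steps, figures)
            logger.train_audios(steps, audios, ap.sample_rate)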
@@ -718,11 +802,13 @@ class Trainer:

     def train_epoch(self) -> None:
         """Main entry point for the training loop. Run training on all the training samples."""
+        # initialize the data loader
         self.train_loader = self.get_train_dataloader(
-            self.ap,
-            self.data_train,
+            self.training_assets,
+            self.train_samples,
             verbose=True,
         )
+        # set model to training mode
         if self.num_gpus > 1:
             self.model.module.train()
         else:
@@ -734,11 +820,12 @@ class Trainer:
         batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
         self.c_logger.print_train_start()
         loader_start_time = time.time()
+        # iterate over the training samples
         for cur_step, batch in enumerate(self.train_loader):
             _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
             loader_start_time = time.time()
         epoch_time = time.time() - epoch_start_time
-        # Plot self.epochs_done Stats
+        # plot self.epochs_done Stats
         if self.args.rank == 0:
             epoch_stats = {"epoch_time": epoch_time}
             epoch_stats.update(self.keep_avg_train.avg_values)
@@ -754,6 +841,10 @@ class Trainer:
             else:
                 self.scheduler.step()

+    #######################
+    # EVAL FUNCTIONS
+    #######################
+
     @staticmethod
     def _model_eval_step(
         batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
@@ -819,8 +910,8 @@ class Trainer:
         """Main entry point for the evaluation loop. Run evaluation on all the validation samples."""
         self.eval_loader = (
             self.get_eval_dataloader(
-                self.ap,
-                self.data_eval,
+                self.training_assets,
+                self.eval_samples,
                 verbose=True,
             )
             if self.config.run_eval
@@ -840,15 +931,12 @@ class Trainer:
             loader_start_time = time.time()
         # plot epoch stats, artifacts and figures
         if self.args.rank == 0:
-            figures, audios = None, None
             if hasattr(self.model, "module") and hasattr(self.model.module, "eval_log"):
-                figures, audios = self.model.module.eval_log(self.ap, batch, outputs)
+                self.model.module.eval_log(
+                    batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
+                )
             elif hasattr(self.model, "eval_log"):
-                figures, audios = self.model.eval_log(self.ap, batch, outputs)
-            if figures is not None:
-                self.dashboard_logger.eval_figures(self.total_steps_done, figures)
-            if audios is not None:
-                self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
+                self.model.eval_log(batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done)
             self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)

     def test_run(self) -> None:
@@ -857,22 +945,22 @@ class Trainer:
         if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
             if self.eval_loader is None:
                 self.eval_loader = self.get_eval_dataloader(
-                    self.ap,
-                    self.data_eval,
+                    self.training_assets,
+                    self.eval_samples,
                     verbose=True,
                 )

             if hasattr(self.eval_loader.dataset, "load_test_samples"):
                 samples = self.eval_loader.dataset.load_test_samples(1)
                 if self.num_gpus > 1:
-                    figures, audios = self.model.module.test_run(self.ap, samples, None)
+                    figures, audios = self.model.module.test_run(self.training_assets, samples, None)
                 else:
-                    figures, audios = self.model.test_run(self.ap, samples, None)
+                    figures, audios = self.model.test_run(self.training_assets, samples, None)
             else:
                 if self.num_gpus > 1:
-                    figures, audios = self.model.module.test_run(self.ap)
+                    figures, audios = self.model.module.test_run(self.training_assets)
                 else:
-                    figures, audios = self.model.test_run(self.ap)
+                    figures, audios = self.model.test_run(self.training_assets)
             self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
             self.dashboard_logger.test_figures(self.total_steps_done, figures)
@@ -886,6 +974,10 @@ class Trainer:
                 self.best_loss = ch["model_loss"]
             print(f" > Starting with loaded last best loss {self.best_loss}.")

+    ###################################
+    # FIT FUNCTIONS
+    ###################################
+
     def _fit(self) -> None:
         """🏃 train -> evaluate -> test for the number of epochs."""
         self._restore_best_loss()
@@ -901,6 +993,7 @@ class Trainer:
             self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
             self.epochs_done = epoch
             self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
-            self.train_epoch()
+            if not self.args.skip_train_epoch:
+                self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
@@ -939,24 +1032,6 @@ class Trainer:
             traceback.print_exc()
             sys.exit(1)

-    def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
-        """Pick the target loss to compare models."""
-        target_avg_loss = None
-
-        # return if target loss defined in the model config
-        if "target_loss" in self.config and self.config.target_loss:
-            return keep_avg_target[f"avg_{self.config.target_loss}"]
-
-        # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
-        if isinstance(self.optimizer, list):
-            target_avg_loss = 0
-            for idx in range(len(self.optimizer)):
-                target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
-            target_avg_loss /= len(self.optimizer)
-        else:
-            target_avg_loss = keep_avg_target["avg_loss"]
-        return target_avg_loss
-
     def save_best_model(self) -> None:
         """Save the best model. It only saves if the current target loss is smaller than the previous."""
@@ -978,35 +1053,9 @@ class Trainer:
             keep_after=self.config.keep_after,
         )

-    def _setup_logger_config(self, log_file: str) -> None:
-        """Write log strings to a file and print logs to the terminal.
-        TODO: Causes formatting issues in pdb debugging."""
-
-        class Logger(object):
-            def __init__(self, print_to_terminal=True):
-                self.print_to_terminal = print_to_terminal
-                self.terminal = sys.stdout
-                self.log_file = log_file
-
-            def write(self, message):
-                if self.print_to_terminal:
-                    self.terminal.write(message)
-                with open(self.log_file, "a", encoding="utf-8") as f:
-                    f.write(message)
-
-            def flush(self):
-                # this flush method is needed for python 3 compatibility.
-                # this handles the flush command by doing nothing.
-                # you might want to specify some extra behavior here.
-                pass
-
-        # don't let processes rank > 0 write to the terminal
-        sys.stdout = Logger(self.args.rank == 0)
-
-    @staticmethod
-    def _is_apex_available() -> bool:
-        """Check if Nvidia's APEX is available."""
-        return importlib.util.find_spec("apex") is not None
+    #####################
+    # GET FUNCTIONS
+    #####################

     @staticmethod
     def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]:
@@ -1084,154 +1133,72 @@ class Trainer:
         criterion = model.get_criterion()
         return criterion

-
-def getarguments():
-    train_config = TrainingArgs()
-    parser = train_config.init_argparse(arg_prefix="")
-    return parser
-
-
-def get_last_checkpoint(path: str) -> Tuple[str, str]:
-    """Get latest checkpoint or/and best model in path.
-
-    It is based on globbing for `*.pth.tar` and the RegEx
-    `(checkpoint|best_model)_([0-9]+)`.
-
-    Args:
-        path: Path to files to be compared.
-
-    Raises:
-        ValueError: If no checkpoint or best_model files are found.
-
-    Returns:
-        Path to the last checkpoint
-        Path to best checkpoint
-    """
-    fs = fsspec.get_mapper(path).fs
-    file_names = fs.glob(os.path.join(path, "*.pth.tar"))
-    scheme = urlparse(path).scheme
-    if scheme:  # scheme is not preserved in fs.glob, add it back
-        file_names = [scheme + "://" + file_name for file_name in file_names]
-    last_models = {}
-    last_model_nums = {}
-    for key in ["checkpoint", "best_model"]:
-        last_model_num = None
-        last_model = None
-        # pass all the checkpoint files and find
-        # the one with the largest model number suffix.
-        for file_name in file_names:
-            match = re.search(f"{key}_([0-9]+)", file_name)
-            if match is not None:
-                model_num = int(match.groups()[0])
-                if last_model_num is None or model_num > last_model_num:
-                    last_model_num = model_num
-                    last_model = file_name
-        # if there is no checkpoint found above
-        # find the checkpoint with the latest
-        # modification date.
-        key_file_names = [fn for fn in file_names if key in fn]
-        if last_model is None and len(key_file_names) > 0:
-            last_model = max(key_file_names, key=os.path.getctime)
-            last_model_num = load_fsspec(last_model)["step"]
-        if last_model is not None:
-            last_models[key] = last_model
-            last_model_nums[key] = last_model_num
-    # check what models were found
-    if not last_models:
-        raise ValueError(f"No models found in continue path {path}!")
-    if "checkpoint" not in last_models:  # no checkpoint just best model
-        last_models["checkpoint"] = last_models["best_model"]
-    elif "best_model" not in last_models:  # no best model
-        # this shouldn't happen, but let's handle it just in case
-        last_models["best_model"] = last_models["checkpoint"]
-    # finally check if last best model is more recent than checkpoint
-    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
-        last_models["checkpoint"] = last_models["best_model"]
-    return last_models["checkpoint"], last_models["best_model"]
-
-
-def process_args(args, config=None):
-    """Process parsed command line arguments and initialize the config if not provided.
-
-    Args:
-        args (argparse.Namespace or dict like): Parsed input arguments.
-        config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
-
-    Returns:
-        c (TTS.utils.io.AttrDict): Config parameters.
-        out_path (str): Path to save models and logging.
-        audio_path (str): Path to save generated test audios.
-        c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
-            logging to the console.
-        dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard logging.
-
-    TODO:
-        - Interactive config definition.
-    """
-    if isinstance(args, tuple):
-        args, coqpit_overrides = args
-    if args.continue_path:
-        # continue a previous training from its output folder
-        experiment_path = args.continue_path
-        args.config_path = os.path.join(args.continue_path, "config.json")
-        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
-        if not args.best_path:
-            args.best_path = best_model
-    # init config if not already defined
-    if config is None:
-        if args.config_path:
-            # init from a file
-            config = load_config(args.config_path)
-        else:
-            # init from console args
-            from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel
-
-            config_base = BaseTrainingConfig()
-            config_base.parse_known_args(coqpit_overrides)
-            config = register_config(config_base.model)()
-    # override values from command-line args
-    config.parse_known_args(coqpit_overrides, relaxed_parser=True)
-    experiment_path = args.continue_path
-    if not experiment_path:
-        experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
-    audio_path = os.path.join(experiment_path, "test_audios")
-    config.output_log_path = experiment_path
-    # setup rank 0 process in distributed training
-    dashboard_logger = None
-    if args.rank == 0:
-        new_fields = {}
-        if args.restore_path:
-            new_fields["restore_path"] = args.restore_path
-        new_fields["github_branch"] = get_git_branch()
-        # if model characters are not set in the config file
-        # save the default set to the config file for future
-        # compatibility.
-        if config.has("characters") and config.characters is None:
-            used_characters = parse_symbols()
-            new_fields["characters"] = used_characters
-        copy_model_files(config, experiment_path, new_fields)
-        dashboard_logger = init_dashboard_logger(config)
-    c_logger = ConsoleLogger()
-    return config, experiment_path, audio_path, c_logger, dashboard_logger
-
-
-def init_arguments():
-    train_config = TrainingArgs()
-    parser = train_config.init_argparse(arg_prefix="")
-    return parser
-
-
-def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
-    """Initialization of a training run."""
-    if isinstance(argv, Coqpit):
-        parser = argv.init_argparse(arg_prefix="")
-    else:
-        parser = init_arguments()
-    args = parser.parse_known_args()
-    config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
-    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger
+    ####################
+    # HELPER FUNCTIONS
+    ####################
+
+    @staticmethod
+    def _detach_loss_dict(loss_dict: Dict) -> Dict:
+        """Detach loss values from autograd.
+
+        Args:
+            loss_dict (Dict): losses.
+
+        Returns:
+            Dict: losses detached from autograd.
+        """
+        loss_dict_detached = {}
+        for key, value in loss_dict.items():
+            if isinstance(value, (int, float)):
+                loss_dict_detached[key] = value
+            else:
+                loss_dict_detached[key] = value.detach()
+        return loss_dict_detached
+
+    def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
+        """Pick the target loss to compare models."""
+        target_avg_loss = None
+
+        # return if target loss defined in the model config
+        if "target_loss" in self.config and self.config.target_loss:
+            return keep_avg_target[f"avg_{self.config.target_loss}"]
+
+        # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
+        if isinstance(self.optimizer, list):
+            target_avg_loss = 0
+            for idx in range(len(self.optimizer)):
+                target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
+            target_avg_loss /= len(self.optimizer)
+        else:
+            target_avg_loss = keep_avg_target["avg_loss"]
+        return target_avg_loss
+
+    def _setup_logger_config(self, log_file: str) -> None:
+        """Write log strings to a file and print logs to the terminal.
+        TODO: Causes formatting issues in pdb debugging."""
+
+        class Logger(object):
+            def __init__(self, print_to_terminal=True):
+                self.print_to_terminal = print_to_terminal
+                self.terminal = sys.stdout
+                self.log_file = log_file
+
+            def write(self, message):
+                if self.print_to_terminal:
+                    self.terminal.write(message)
+                with open(self.log_file, "a", encoding="utf-8") as f:
+                    f.write(message)
+
+            def flush(self):
+                # this flush method is needed for python 3 compatibility.
+                # this handles the flush command by doing nothing.
+                # you might want to specify some extra behavior here.
+                pass

+        # don't let processes rank > 0 write to the terminal
+        sys.stdout = Logger(self.args.rank == 0)
+
+    @staticmethod
+    def _is_apex_available() -> bool:
+        """Check if Nvidia's APEX is available."""
+        return importlib.util.find_spec("apex") is not None
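Note the relocated `_detach_loss_dict` also switches from `value.item()` to `value.detach()`: `.item()` converts to a Python float and forces a device-to-host sync, while `.detach()` returns a tensor stripped of autograd history. A minimal illustration:

    import torch

    loss = 2 * torch.tensor(0.5, requires_grad=True)
    loss.item()    # 1.0 -> plain Python float, synchronizes with the device
    loss.detach()  # tensor(1.) -> still a tensor, no gradient history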

TTS/tts/models/base_tts.py

@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

 from TTS.model import BaseModel
-from TTS.tts.datasets import TTSDataset
+from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text import make_symbols
@@ -32,6 +32,30 @@ class BaseTTS(BaseModel):
     - 1D tensors `batch x 1`
     """

+    def _set_model_args(self, config: Coqpit):
+        """Set up model args based on the config type.
+
+        If the config is for training with a name like "*Config", then the model args are embedded in
+        `config.model_args`.
+
+        If the config is for the model with a name like "*Args", then we assign it directly.
+        """
+        # don't use isinstance here, to avoid importing recursively
+        if "Config" in config.__class__.__name__:
+            if "characters" in config:
+                _, self.config, num_chars = self.get_characters(config)
+                self.config.num_chars = num_chars
+                if hasattr(self.config, "model_args"):
+                    config.model_args.num_chars = num_chars
+                    self.args = self.config.model_args
+            else:
+                self.config = config
+                self.args = config.model_args
+        elif "Args" in config.__class__.__name__:
+            self.args = config
+        else:
+            raise ValueError("config must be either a *Config or *Args")
+
     @staticmethod
     def get_characters(config: Coqpit) -> str:
         # TODO: implement CharacterProcessor
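The dispatch above keys on nothing but the class-name convention. A self-contained sketch with illustrative config classes (`FooConfig`/`FooArgs` are not part of this diff):

    from dataclasses import dataclass, field
    from coqpit import Coqpit

    @dataclass
    class FooArgs(Coqpit):
        hidden_dim: int = 256

    @dataclass
    class FooConfig(Coqpit):
        model_args: FooArgs = field(default_factory=FooArgs)

    assert "Config" in FooConfig().__class__.__name__  # args read from config.model_args
    assert "Args" in FooArgs().__class__.__name__      # config used as the args directly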
@@ -169,7 +193,7 @@ class BaseTTS(BaseModel):
     def get_data_loader(
         self,
         config: Coqpit,
-        ap: AudioProcessor,
+        assets: Dict,
         is_eval: bool,
         data_items: List,
         verbose: bool,
@@ -179,6 +203,8 @@ class BaseTTS(BaseModel):
         if is_eval and not config.run_eval:
             loader = None
         else:
+            ap = assets["audio_processor"]
+
             # setup multi-speaker attributes
             if hasattr(self, "speaker_manager"):
                 speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
@@ -280,14 +306,18 @@ class BaseTTS(BaseModel):
         )
         return loader

-    def test_run(self, ap) -> Tuple[Dict, Dict]:
+    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.

         You can override this for a different behaviour.

+        Args:
+            assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`.
+
         Returns:
             Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
         """
+        ap = assets["audio_processor"]
         print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
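Callers now hand `test_run` the assets dict instead of an `AudioProcessor` instance; a usage sketch:

    ap = AudioProcessor(**config.audio)
    figures, audios = model.test_run({"audio_processor": ap})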

TTS/utils/trainer_utils.py

@@ -1,8 +1,13 @@
 import importlib
+import os
+import re
 from typing import Dict, List, Tuple
+from urllib.parse import urlparse

+import fsspec
 import torch

+from TTS.utils.io import load_fsspec
 from TTS.utils.training import NoamLR
@@ -80,3 +85,66 @@ def get_optimizer(
     if model is not None:
         parameters = model.parameters()
     return optimizer(parameters, lr=lr, **optimizer_params)
+
+
+def get_last_checkpoint(path: str) -> Tuple[str, str]:
+    """Get latest checkpoint or/and best model in path.
+
+    It is based on globbing for `*.pth.tar` and the RegEx
+    `(checkpoint|best_model)_([0-9]+)`.
+
+    Args:
+        path: Path to files to be compared.
+
+    Raises:
+        ValueError: If no checkpoint or best_model files are found.
+
+    Returns:
+        Path to the last checkpoint
+        Path to best checkpoint
+    """
+    fs = fsspec.get_mapper(path).fs
+    file_names = fs.glob(os.path.join(path, "*.pth.tar"))
+    scheme = urlparse(path).scheme
+    if scheme:  # scheme is not preserved in fs.glob, add it back
+        file_names = [scheme + "://" + file_name for file_name in file_names]
+    last_models = {}
+    last_model_nums = {}
+    for key in ["checkpoint", "best_model"]:
+        last_model_num = None
+        last_model = None
+        # pass all the checkpoint files and find
+        # the one with the largest model number suffix.
+        for file_name in file_names:
+            match = re.search(f"{key}_([0-9]+)", file_name)
+            if match is not None:
+                model_num = int(match.groups()[0])
+                if last_model_num is None or model_num > last_model_num:
+                    last_model_num = model_num
+                    last_model = file_name
+
+        # if there is no checkpoint found above,
+        # find the checkpoint with the latest
+        # modification date.
+        key_file_names = [fn for fn in file_names if key in fn]
+        if last_model is None and len(key_file_names) > 0:
+            last_model = max(key_file_names, key=os.path.getctime)
+            last_model_num = load_fsspec(last_model)["step"]
+
+        if last_model is not None:
+            last_models[key] = last_model
+            last_model_nums[key] = last_model_num
+
+    # check what models were found
+    if not last_models:
+        raise ValueError(f"No models found in continue path {path}!")
+    if "checkpoint" not in last_models:  # no checkpoint, just the best model
+        last_models["checkpoint"] = last_models["best_model"]
+    elif "best_model" not in last_models:  # no best model
+        # this shouldn't happen, but let's handle it just in case
+        last_models["best_model"] = last_models["checkpoint"]
+    # finally, check if the last best model is more recent than the checkpoint
+    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
+        last_models["checkpoint"] = last_models["best_model"]
+    return last_models["checkpoint"], last_models["best_model"]
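For reference, a quick, runnable check of the RegEx above against typical checkpoint names:

    import re

    for name in ["checkpoint_49000.pth.tar", "best_model_22000.pth.tar", "events.out"]:
        match = re.search("(checkpoint|best_model)_([0-9]+)", name)
        print(name, "->", match.groups() if match else None)
    # checkpoint_49000.pth.tar -> ('checkpoint', '49000')
    # best_model_22000.pth.tar -> ('best_model', '22000')
    # events.out -> None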

TTS/vocoder/models/base_vocoder.py

@@ -1,3 +1,5 @@
+from coqpit import Coqpit
+
 from TTS.model import BaseModel

 # pylint: skip-file
@@ -16,5 +18,35 @@ class BaseVocoder(BaseModel):
     - 1D tensors `batch x 1`
     """

-    def __init__(self):
-        super().__init__()
+    def __init__(self, config):
+        super().__init__(config)
+
+    def _set_model_args(self, config: Coqpit):
+        """Set up model args based on the config type.
+
+        If the config is for training with a name like "*Config", then the model args are embedded in
+        `config.model_args`.
+
+        If the config is for the model with a name like "*Args", then we assign it directly.
+        """
+        # don't use isinstance here, to avoid importing recursively
+        if "Config" in config.__class__.__name__:
+            if "characters" in config:
+                _, self.config, num_chars = self.get_characters(config)
+                self.config.num_chars = num_chars
+                if hasattr(self.config, "model_args"):
+                    config.model_args.num_chars = num_chars
+                if "model_args" in config:
+                    self.args = self.config.model_args
+                # This is for backward compatibility
+                if "model_params" in config:
+                    self.args = self.config.model_params
+            else:
+                self.config = config
+                if "model_args" in config:
+                    self.args = self.config.model_args
+                # This is for backward compatibility
+                if "model_params" in config:
+                    self.args = self.config.model_params
+        elif "Args" in config.__class__.__name__:
+            self.args = config
+        else:
+            raise ValueError("config must be either a *Config or *Args")