mirror of https://github.com/coqui-ai/TTS.git
Refactor `trainer.py` for v2
This commit is contained in:
parent
7f388f26e3
commit
8ada870a57
537 TTS/trainer.py
@@ -4,16 +4,14 @@ import importlib
import multiprocessing
import os
import platform
import re
import sys
import time
import traceback
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
from urllib.parse import urlparse
from inspect import signature
from typing import Callable, Dict, List, Tuple, Union

import fsspec
import torch
import torch.distributed as dist
from coqpit import Coqpit
@@ -21,11 +19,7 @@ from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader

from TTS.config import load_config, register_config
from TTS.tts.datasets import load_meta_data
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.audio import AudioProcessor
from TTS.stt.datasets.tokenizer import Tokenizer
from TTS.utils.callbacks import TrainerCallback
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (
@@ -39,9 +33,13 @@ from TTS.utils.generic_utils import (
)
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.utils.trainer_utils import (
    get_last_checkpoint,
    get_optimizer,
    get_scheduler,
    is_apex_available,
    setup_torch_training_env,
)

multiprocessing.set_start_method("fork")
@@ -80,6 +78,9 @@ class TrainingArgs(Coqpit):
            "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
        },
    )
    skip_train_epoch: bool = field(
        default=False, metadata={"help": "Run only evaluation iteration. Useful for debugging."}
    )
    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
    group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
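As a quick aside, here is a minimal sketch of how these Coqpit dataclass fields surface on the command line. The `init_argparse`/`parse_known_args` usage is taken from this same diff; the flag values are made up.

from TTS.trainer import TrainingArgs

parser = TrainingArgs().init_argparse(arg_prefix="")
training_args, overrides = parser.parse_known_args(
    ["--config_path", "config.json", "--rank", "0", "--lr", "0.0001"]
)
# `--config_path` and `--rank` match TrainingArgs fields; unmatched flags such
# as `--lr` come back in `overrides` for the model config to consume.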
@@ -98,7 +99,14 @@ class Trainer:
        c_logger: ConsoleLogger = None,
        dashboard_logger: Union[TensorboardLogger, WandbLogger] = None,
        model: nn.Module = None,
        get_model: Callable = None,
        get_data_samples: Callable = None,
        train_samples: List = None,
        eval_samples: List = None,
        tokenizer: Tokenizer = None,
        cudnn_benchmark: bool = False,
        training_assets: Dict = {},
        parse_command_line_args: bool = True,
    ) -> None:
        """Simple yet powerful 🐸💬 TTS trainer for PyTorch. It can train all the available `tts` and `vocoder` models
        or easily be customized.
@@ -127,24 +135,44 @@ class Trainer:
            model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
                initializes a model from the provided config. Defaults to None.

            get_model (Callable):
                A function that returns a model. It is used to initialize the model when `model` is not provided.
                It either takes the config as the only argument or takes no arguments.
                Defaults to None.

            get_data_samples (Callable):
                A function that returns a list of training and evaluation samples. Used if `train_samples` and
                `eval_samples` are None. Defaults to None.

            train_samples (List):
                A list of training samples used by the model's `get_data_loader` to init the `dataset` and the
                `data_loader`. Defaults to None.

            eval_samples (List):
                A list of evaluation samples used by the model's `get_data_loader` to init the `dataset` and the
                `data_loader`. Defaults to None.

            cudnn_benchmark (bool): enable/disable PyTorch cudnn benchmarking. It is better to disable it if the
                model input length changes from batch to batch during training.

            training_assets (Dict):
                A dictionary of assets used at training and passed to the model's ```train_log(), eval_log(), get_data_loader()```
                during training. It can include an `AudioProcessor` and/or a `Tokenizer`. Defaults to {}.

            parse_command_line_args (bool):
                If true, parse command-line arguments and update `TrainingArgs` and model `config` values. Set it
                to false if you parse the arguments yourself. Defaults to True.

        Examples:

            Running trainer on a model.
            Running trainer with a HifiGAN model.

                >>> args = TrainingArgs(...)
                >>> config = HifiganConfig(...)
                >>> model = GANModel(config)
                >>> trainer = Trainer(args, config, output_path, model=model)
                >>> trainer.fit()

            Running trainer on a config.

                >>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,)
                >>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
                >>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
                >>> ap = AudioProcessor(**config.audio)
                >>> assets = {"audio_processor": ap}
                >>> trainer = Trainer(args, config, output_path, model=model, training_assets=assets)
                >>> trainer.fit()

        TODO:
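For orientation, a hedged sketch of the two ways the refactored `Trainer` can receive its model, following the docstring above. `GANModel`, `HifiganConfig`, `args`, `config`, `output_path`, and `assets` are the docstring's own example names; everything else is illustrative.

# 1) hand over a ready model instance:
trainer = Trainer(args, config, output_path, model=GANModel(config), training_assets=assets)
# 2) hand over a factory; Trainer inspects its signature and calls it with the
#    config if it takes exactly one argument, otherwise with no arguments:
trainer = Trainer(args, config, output_path, get_model=GANModel, training_assets=assets)
trainer.fit()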
@@ -154,20 +182,33 @@ class Trainer:
            - Profiler integration.
            - Overfitting to a batch.
            - TPU training
            - NOTE: Consider moving `training_assets` to the model implementation.
        """
        if parse_command_line_args:
            # parse command-line arguments for TrainingArgs()
            args, coqpit_overrides = self.parse_argv(args)

        if config is None:
            # parse config from console arguments
            config, output_path, _, c_logger, dashboard_logger = process_args(args)
            # get ready for training and parse command-line arguments for the model config
            config = self.init_training(args, coqpit_overrides, config)

        # define the experiment path and create the folder
        output_path = get_experiment_folder_path(config.output_path, config.run_name)
        os.makedirs(output_path, exist_ok=True)

        # copy training assets to the output folder
        copy_model_files(config, output_path, new_fields=None)

        # init class members
        self.args = args
        self.config = config
        self.output_path = output_path
        self.config.output_log_path = output_path
        self.training_assets = training_assets

        # setup logging
        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
        self._setup_logger_config(log_file)
        time.sleep(1.0)  # wait for the logger to be ready

        # set and initialize PyTorch runtime
        self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
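A tiny, runnable illustration of the per-process log naming set up above: each distributed rank writes its own file under the experiment folder (the path here is made up).

import os

output_path = "/tmp/coqui_run"  # illustrative experiment folder
for rank in range(2):
    print(os.path.join(output_path, f"trainer_{rank}_log.txt"))
# /tmp/coqui_run/trainer_0_log.txt
# /tmp/coqui_run/trainer_1_log.txt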
@@ -196,33 +237,30 @@ class Trainer:
        self.use_apex = self._is_apex_available()
        self.use_amp_scaler = self.config.mixed_precision and self.use_cuda

        # init audio processor
        self.ap = AudioProcessor(**self.config.audio.to_dict())
        # init tokenizer
        self.tokenizer = tokenizer

        # load data samples
        # TODO: refactor this
        if "datasets" in self.config:
            # load data for `tts` models
            self.data_train, self.data_eval = load_meta_data(self.config.datasets)
        elif self.config.feature_path is not None:
            # load pre-computed features for `vocoder` models
            print(f" > Loading features from: {self.config.feature_path}")
            self.data_eval, self.data_train = load_wav_feat_data(
                self.config.data_path, self.config.feature_path, self.config.eval_split_size
            )
        if train_samples is None and get_data_samples is None:
            raise ValueError("[!] `train_samples` and `get_data_samples` cannot both be None.")
        if train_samples is not None:
            self.train_samples = train_samples
            self.eval_samples = eval_samples
        else:
            # load data for `vocoder` models
            self.data_eval, self.data_train = load_wav_data(self.config.data_path, self.config.eval_split_size)
            self.train_samples, self.eval_samples = self.run_get_data_samples(config, get_data_samples)

        # init TTS model
        if model is None and get_model is None:
            raise ValueError("[!] `model` and `get_model` cannot both be None.")
        if model is not None:
            self.model = model
        else:
            self.model = self.get_model(self.config)
            self.model = self.run_get_model(self.config, get_model)

        # TODO: out!
        # init multispeaker settings of the model
        if hasattr(self.model, "init_multispeaker"):
            self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
            self.model.init_multispeaker(self.config, self.train_samples + self.eval_samples)

        # setup criterion
        self.criterion = self.get_criterion(self.model)
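The data-sample precedence implemented above, reduced to a standalone sketch (the function and variable names here are illustrative, not from the commit): explicit samples win, otherwise the callable is consulted, and passing neither is an error.

def resolve_samples(train_samples, eval_samples, get_data_samples, config):
    if train_samples is None and get_data_samples is None:
        raise ValueError("[!] `train_samples` and `get_data_samples` cannot both be None.")
    if train_samples is not None:
        return train_samples, eval_samples
    return get_data_samples(config)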
@@ -247,7 +285,7 @@ class Trainer:
        # setup optimizer
        self.optimizer = self.get_optimizer(self.model, self.config)

        # callback
        # CALLBACK
        self.callbacks = TrainerCallback(self)
        self.callbacks.on_init_start()
@@ -280,7 +318,7 @@ class Trainer:
        else:
            self.scheduler.last_epoch = self.restore_step

        # DISTRUBUTED
        # DISTRIBUTED
        if self.num_gpus > 1:
            self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
@@ -291,8 +329,54 @@ class Trainer:
        self.callbacks.on_init_end()

    @staticmethod
    def get_model(config: Coqpit) -> nn.Module:
        """Initialize model from config.
    def parse_argv(args: Union[Coqpit, List]):
        """Parse command line arguments to init or override `TrainingArgs()`."""
        if isinstance(args, Coqpit):
            parser = args.init_argparse(arg_prefix="")
        else:
            train_config = TrainingArgs()
            parser = train_config.init_argparse(arg_prefix="")
        training_args, coqpit_overrides = parser.parse_known_args()
        args.parse_args(training_args)
        return args, coqpit_overrides

    def init_training(self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None):
        """Initialize training and update model configs from command line arguments.

        Args:
            args (argparse.Namespace or dict like): Parsed input arguments.
            config_overrides (argparse.Namespace or dict like): Parsed config overriding arguments.
            config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.

        Returns:
            c (TTS.utils.io.AttrDict): Config parameters.
        """
        # set arguments for continuing training
        if args.continue_path:
            experiment_path = args.continue_path
            args.config_path = os.path.join(args.continue_path, "config.json")
            args.restore_path, best_model = get_last_checkpoint(args.continue_path)
            if not args.best_path:
                args.best_path = best_model

        # override config values from command-line args
        # TODO: Maybe it is better to do it outside
        if len(coqpit_overrides) > 0:
            config.parse_known_args(coqpit_overrides, relaxed_parser=True)
        experiment_path = args.continue_path

        # update the config.json fields and copy it to the output folder
        if args.rank == 0:
            new_fields = {}
            if args.restore_path:
                new_fields["restore_path"] = args.restore_path
            new_fields["github_branch"] = get_git_branch()
            copy_model_files(config, experiment_path, new_fields)
        return config

    @staticmethod
    def run_get_model(config: Coqpit, get_model: Callable) -> nn.Module:
        """Run the `get_model` function and return the model.

        Args:
            config (Coqpit): Model config.
@@ -300,12 +384,23 @@ class Trainer:
        Returns:
            nn.Module: initialized model.
        """
        try:
            model = setup_vocoder_model(config)
        except ModuleNotFoundError:
            model = setup_tts_model(config)
        if len(signature(get_model).parameters) == 1:
            model = get_model(config)
        else:
            model = get_model()
        return model

    @staticmethod
    def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> nn.Module:
        if isinstance(get_data_samples, Callable):
            if len(signature(get_data_samples).parameters) == 1:
                train_samples, eval_samples = get_data_samples(config)
            else:
                train_samples, eval_samples = get_data_samples()
            return train_samples, eval_samples
        else:
            return None, None

    def restore_model(
        self,
        config: Coqpit,
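A runnable, dependency-free demo of the `inspect.signature` dispatch that `run_get_model` and `run_get_data_samples` rely on (plain Python; the function names below are illustrative):

from inspect import signature

def call_flexibly(fn, config):
    """Call `fn` with `config` only if it declares exactly one parameter."""
    if len(signature(fn).parameters) == 1:
        return fn(config)
    return fn()

def get_model(config):
    return f"model built from {config}"

def get_model_noargs():
    return "model built without a config"

print(call_flexibly(get_model, {"num_chars": 42}))  # model built from {'num_chars': 42}
print(call_flexibly(get_model_noargs, {}))          # model built without a config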
@@ -366,11 +461,15 @@ class Trainer:
        torch.cuda.empty_cache()
        return model, optimizer, scaler, restore_step

    #########################
    # DATA LOADING FUNCTIONS
    #########################

    def _get_loader(
        self,
        model: nn.Module,
        config: Coqpit,
        ap: AudioProcessor,
        assets: Dict,
        is_eval: bool,
        data_items: List,
        verbose: bool,
@@ -379,14 +478,14 @@ class Trainer:
        if num_gpus > 1:
            if hasattr(model.module, "get_data_loader"):
                loader = model.module.get_data_loader(
                    config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
                    config, assets, is_eval, data_items, verbose, num_gpus, self.args.rank
                )
        else:
            if hasattr(model, "get_data_loader"):
                loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
                loader = model.get_data_loader(config, assets, is_eval, data_items, verbose, num_gpus)
        return loader

    def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
    def get_train_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader:
        """Initialize and return a training data loader.

        Args:
@@ -397,10 +496,10 @@ class Trainer:
        Returns:
            DataLoader: Initialized training data loader.
        """
        return self._get_loader(self.model, self.config, ap, False, data_items, verbose, self.num_gpus)
        return self._get_loader(self.model, self.config, training_assets, False, data_items, verbose, self.num_gpus)

    def get_eval_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
        return self._get_loader(self.model, self.config, ap, True, data_items, verbose, self.num_gpus)
    def get_eval_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader:
        return self._get_loader(self.model, self.config, training_assets, True, data_items, verbose, self.num_gpus)

    def format_batch(self, batch: List) -> Dict:
        """Format the dataloader output and return a batch.
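In short, the `training_assets` dict replaces the bare `AudioProcessor` argument throughout the loader API. A hedged sketch of the convention ("audio_processor" is the only key this diff relies on; `trainer` and `config` are assumed to already exist):

ap = AudioProcessor(**config.audio.to_dict())
training_assets = {"audio_processor": ap}
train_loader = trainer.get_train_dataloader(training_assets, trainer.train_samples, verbose=True)
eval_loader = trainer.get_eval_dataloader(training_assets, trainer.eval_samples, verbose=True)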
@@ -420,6 +519,10 @@ class Trainer:
                batch[k] = to_cuda(v)
        return batch

    ######################
    # TRAIN FUNCTIONS
    ######################

    @staticmethod
    def master_params(optimizer: torch.optim.Optimizer):
        """Generator over parameters owned by the optimizer.
@@ -567,24 +670,6 @@ class Trainer:
        loss_dict["grad_norm"] = grad_norm
        return outputs, loss_dict, step_time

    @staticmethod
    def _detach_loss_dict(loss_dict: Dict) -> Dict:
        """Detach loss values from autograd.

        Args:
            loss_dict (Dict): losses.

        Returns:
            Dict: losses detached from autograd.
        """
        loss_dict_detached = {}
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_detached[key] = value
            else:
                loss_dict_detached[key] = value.item()
        return loss_dict_detached

    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
        """Perform a training step on a batch of inputs and log the process.
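A runnable illustration of the loss-dict detaching above: plain numbers pass through, tensors lose their autograd graph (the version removed here called `.item()`, the one added later in this diff calls `.detach()`).

import torch

loss = torch.tensor(0.5, requires_grad=True) * 2
loss_dict = {"loss": loss, "grad_norm": 1.0}
detached = {k: v if isinstance(v, (int, float)) else v.detach() for k, v in loss_dict.items()}
assert detached["grad_norm"] == 1.0             # plain floats pass through
assert detached["loss"].requires_grad is False  # tensors are cut from the graph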
@@ -700,15 +785,14 @@ class Trainer:
                    self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)

            # training visualizations
            figures, audios = None, None
            if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"):
                figures, audios = self.model.module.train_log(self.ap, batch, outputs)
                self.model.module.train_log(
                    batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
                )
            elif hasattr(self.model, "train_log"):
                figures, audios = self.model.train_log(self.ap, batch, outputs)
            if figures is not None:
                self.dashboard_logger.train_figures(self.total_steps_done, figures)
            if audios is not None:
                self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)
                self.model.train_log(
                    batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
                )

            self.dashboard_logger.flush()
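The logging responsibility moves into the model with this change. A sketch of the hook a model would now implement (`MyModel` and `_create_logs` are hypothetical; the logger methods come from the removed lines above):

class MyModel(nn.Module):
    def train_log(self, batch, outputs, logger, assets, steps):
        ap = assets["audio_processor"]
        figures, audios = self._create_logs(batch, outputs, ap)  # hypothetical helper
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, ap.sample_rate)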
@@ -718,11 +802,13 @@ class Trainer:

    def train_epoch(self) -> None:
        """Main entry point for the training loop. Run training on all the training samples."""
        # initialize the data loader
        self.train_loader = self.get_train_dataloader(
            self.ap,
            self.data_train,
            self.training_assets,
            self.train_samples,
            verbose=True,
        )
        # set model to training mode
        if self.num_gpus > 1:
            self.model.module.train()
        else:
@@ -734,11 +820,12 @@ class Trainer:
        batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
        self.c_logger.print_train_start()
        loader_start_time = time.time()
        # iterate over the training samples
        for cur_step, batch in enumerate(self.train_loader):
            _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
            loader_start_time = time.time()
        epoch_time = time.time() - epoch_start_time
        # Plot self.epochs_done Stats
        # plot self.epochs_done Stats
        if self.args.rank == 0:
            epoch_stats = {"epoch_time": epoch_time}
            epoch_stats.update(self.keep_avg_train.avg_values)
@@ -754,6 +841,10 @@ class Trainer:
        else:
            self.scheduler.step()

    #######################
    # EVAL FUNCTIONS
    #######################

    @staticmethod
    def _model_eval_step(
        batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
@@ -819,8 +910,8 @@ class Trainer:
        """Main entry point for the evaluation loop. Run evaluation on all the validation samples."""
        self.eval_loader = (
            self.get_eval_dataloader(
                self.ap,
                self.data_eval,
                self.training_assets,
                self.eval_samples,
                verbose=True,
            )
            if self.config.run_eval
@@ -840,15 +931,12 @@ class Trainer:
            loader_start_time = time.time()
        # plot epoch stats, artifacts and figures
        if self.args.rank == 0:
            figures, audios = None, None
            if hasattr(self.model, "module") and hasattr(self.model.module, "eval_log"):
                figures, audios = self.model.module.eval_log(self.ap, batch, outputs)
                self.model.module.eval_log(
                    batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
                )
            elif hasattr(self.model, "eval_log"):
                figures, audios = self.model.eval_log(self.ap, batch, outputs)
            if figures is not None:
                self.dashboard_logger.eval_figures(self.total_steps_done, figures)
            if audios is not None:
                self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
                self.model.eval_log(batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done)
            self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)

    def test_run(self) -> None:
@@ -857,22 +945,22 @@ class Trainer:
        if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
            if self.eval_loader is None:
                self.eval_loader = self.get_eval_dataloader(
                    self.ap,
                    self.data_eval,
                    self.training_assets,
                    self.eval_samples,
                    verbose=True,
                )

            if hasattr(self.eval_loader.dataset, "load_test_samples"):
                samples = self.eval_loader.dataset.load_test_samples(1)
                if self.num_gpus > 1:
                    figures, audios = self.model.module.test_run(self.ap, samples, None)
                    figures, audios = self.model.module.test_run(self.training_assets, samples, None)
                else:
                    figures, audios = self.model.test_run(self.ap, samples, None)
                    figures, audios = self.model.test_run(self.training_assets, samples, None)
            else:
                if self.num_gpus > 1:
                    figures, audios = self.model.module.test_run(self.ap)
                    figures, audios = self.model.module.test_run(self.training_assets)
                else:
                    figures, audios = self.model.test_run(self.ap)
                    figures, audios = self.model.test_run(self.training_assets)
            self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
            self.dashboard_logger.test_figures(self.total_steps_done, figures)
@@ -886,6 +974,10 @@ class Trainer:
            self.best_loss = ch["model_loss"]
            print(f" > Starting with loaded last best loss {self.best_loss}.")

    ###################################
    # FIT FUNCTIONS
    ###################################

    def _fit(self) -> None:
        """🏃 train -> evaluate -> test for the number of epochs."""
        self._restore_best_loss()
@@ -901,6 +993,7 @@ class Trainer:
            self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
            self.epochs_done = epoch
            self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
            if not self.args.skip_train_epoch:
                self.train_epoch()
            if self.config.run_eval:
                self.eval_epoch()
@@ -939,24 +1032,6 @@ class Trainer:
            traceback.print_exc()
            sys.exit(1)

    def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
        """Pick the target loss to compare models"""
        target_avg_loss = None

        # return if target loss defined in the model config
        if "target_loss" in self.config and self.config.target_loss:
            return keep_avg_target[f"avg_{self.config.target_loss}"]

        # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
        if isinstance(self.optimizer, list):
            target_avg_loss = 0
            for idx in range(len(self.optimizer)):
                target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
            target_avg_loss /= len(self.optimizer)
        else:
            target_avg_loss = keep_avg_target["avg_loss"]
        return target_avg_loss

    def save_best_model(self) -> None:
        """Save the best model. It only saves if the current target loss is smaller than the previous one."""
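A runnable arithmetic check for `_pick_target_avg_loss` in the multi-optimizer case: the target is the mean of `avg_loss_{idx}` over all optimizers.

keep_avg_target = {"avg_loss_0": 0.30, "avg_loss_1": 0.50}
n_optimizers = 2
target = sum(keep_avg_target[f"avg_loss_{i}"] for i in range(n_optimizers)) / n_optimizers
assert abs(target - 0.40) < 1e-9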
@@ -978,35 +1053,9 @@ class Trainer:
            keep_after=self.config.keep_after,
        )

    def _setup_logger_config(self, log_file: str) -> None:
        """Write log strings to a file and print logs to the terminal.
        TODO: Causes formatting issues in pdb debugging."""

        class Logger(object):
            def __init__(self, print_to_terminal=True):
                self.print_to_terminal = print_to_terminal
                self.terminal = sys.stdout
                self.log_file = log_file

            def write(self, message):
                if self.print_to_terminal:
                    self.terminal.write(message)
                with open(self.log_file, "a", encoding="utf-8") as f:
                    f.write(message)

            def flush(self):
                # this flush method is needed for python 3 compatibility.
                # this handles the flush command by doing nothing.
                # you might want to specify some extra behavior here.
                pass

        # don't let processes with rank > 0 write to the terminal
        sys.stdout = Logger(self.args.rank == 0)

    @staticmethod
    def _is_apex_available() -> bool:
        """Check if Nvidia's APEX is available."""
        return importlib.util.find_spec("apex") is not None

    #####################
    # GET FUNCTIONS
    #####################

    @staticmethod
    def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]:
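A runnable check mirroring `_is_apex_available`: `find_spec` returns None when a package is absent, so the truthiness test works without importing it.

import importlib.util

print(importlib.util.find_spec("apex") is not None)  # False unless apex is installed
print(importlib.util.find_spec("os") is not None)    # True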
@@ -1084,154 +1133,72 @@ class Trainer:
        criterion = model.get_criterion()
        return criterion

    ####################
    # HELPER FUNCTIONS
    ####################


def getarguments():
    train_config = TrainingArgs()
    parser = train_config.init_argparse(arg_prefix="")
    return parser


def get_last_checkpoint(path: str) -> Tuple[str, str]:
    """Get the latest checkpoint and/or best model in path.

    It is based on globbing for `*.pth.tar` and the RegEx
    `(checkpoint|best_model)_([0-9]+)`.

    Args:
        path: Path to files to be compared.

    Raises:
        ValueError: If no checkpoint or best_model files are found.

    Returns:
        Path to the last checkpoint
        Path to best checkpoint
    """
    fs = fsspec.get_mapper(path).fs
    file_names = fs.glob(os.path.join(path, "*.pth.tar"))
    scheme = urlparse(path).scheme
    if scheme:  # scheme is not preserved in fs.glob, add it back
        file_names = [scheme + "://" + file_name for file_name in file_names]
    last_models = {}
    last_model_nums = {}
    for key in ["checkpoint", "best_model"]:
        last_model_num = None
        last_model = None
        # pass over all the checkpoint files and find
        # the one with the largest model number suffix.
        for file_name in file_names:
            match = re.search(f"{key}_([0-9]+)", file_name)
            if match is not None:
                model_num = int(match.groups()[0])
                if last_model_num is None or model_num > last_model_num:
                    last_model_num = model_num
                    last_model = file_name

        # if there is no checkpoint found above
        # find the checkpoint with the latest
        # modification date.
        key_file_names = [fn for fn in file_names if key in fn]
        if last_model is None and len(key_file_names) > 0:
            last_model = max(key_file_names, key=os.path.getctime)
            last_model_num = load_fsspec(last_model)["step"]

        if last_model is not None:
            last_models[key] = last_model
            last_model_nums[key] = last_model_num

    # check what models were found
    if not last_models:
        raise ValueError(f"No models found in continue path {path}!")
    if "checkpoint" not in last_models:  # no checkpoint, just the best model
        last_models["checkpoint"] = last_models["best_model"]
    elif "best_model" not in last_models:  # no best model
        # this shouldn't happen, but let's handle it just in case
        last_models["best_model"] = last_models["checkpoint"]
    # finally, check if the last best model is more recent than the checkpoint
    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
        last_models["checkpoint"] = last_models["best_model"]

    return last_models["checkpoint"], last_models["best_model"]


def process_args(args, config=None):
    """Process parsed command line arguments and initialize the config if not provided.

    Args:
        args (argparse.Namespace or dict like): Parsed input arguments.
        config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.

    Returns:
        c (TTS.utils.io.AttrDict): Config parameters.
        out_path (str): Path to save models and logging.
        audio_path (str): Path to save generated test audios.
        c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
            logging to the console.

        dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard logging

    TODO:
        - Interactive config definition.
    """
    if isinstance(args, tuple):
        args, coqpit_overrides = args
    if args.continue_path:
        # continue a previous training from its output folder
        experiment_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
        if not args.best_path:
            args.best_path = best_model
    # init config if not already defined
    if config is None:
        if args.config_path:
            # init from a file
            config = load_config(args.config_path)
        else:
            # init from console args
            from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel

            config_base = BaseTrainingConfig()
            config_base.parse_known_args(coqpit_overrides)
            config = register_config(config_base.model)()
    # override values from command-line args
    config.parse_known_args(coqpit_overrides, relaxed_parser=True)
    experiment_path = args.continue_path
    if not experiment_path:
        experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
    audio_path = os.path.join(experiment_path, "test_audios")
    config.output_log_path = experiment_path
    # setup rank 0 process in distributed training
    dashboard_logger = None
    if args.rank == 0:
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        # if model characters are not set in the config file
        # save the default set to the config file for future
        # compatibility.
        if config.has("characters") and config.characters is None:
            used_characters = parse_symbols()
            new_fields["characters"] = used_characters
        copy_model_files(config, experiment_path, new_fields)
        dashboard_logger = init_dashboard_logger(config)
    c_logger = ConsoleLogger()
    return config, experiment_path, audio_path, c_logger, dashboard_logger


def init_arguments():
    train_config = TrainingArgs()
    parser = train_config.init_argparse(arg_prefix="")
    return parser


def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
    """Initialization of a training run."""
    if isinstance(argv, Coqpit):
        parser = argv.init_argparse(arg_prefix="")
    else:
        parser = init_arguments()
    args = parser.parse_known_args()
    config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger

    @staticmethod
    def _detach_loss_dict(loss_dict: Dict) -> Dict:
        """Detach loss values from autograd.

        Args:
            loss_dict (Dict): losses.

        Returns:
            Dict: losses detached from autograd.
        """
        loss_dict_detached = {}
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_detached[key] = value
            else:
                loss_dict_detached[key] = value.detach()
        return loss_dict_detached

    def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
        """Pick the target loss to compare models"""
        target_avg_loss = None

        # return if target loss defined in the model config
        if "target_loss" in self.config and self.config.target_loss:
            return keep_avg_target[f"avg_{self.config.target_loss}"]

        # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
        if isinstance(self.optimizer, list):
            target_avg_loss = 0
            for idx in range(len(self.optimizer)):
                target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
            target_avg_loss /= len(self.optimizer)
        else:
            target_avg_loss = keep_avg_target["avg_loss"]
        return target_avg_loss

    def _setup_logger_config(self, log_file: str) -> None:
        """Write log strings to a file and print logs to the terminal.
        TODO: Causes formatting issues in pdb debugging."""

        class Logger(object):
            def __init__(self, print_to_terminal=True):
                self.print_to_terminal = print_to_terminal
                self.terminal = sys.stdout
                self.log_file = log_file

            def write(self, message):
                if self.print_to_terminal:
                    self.terminal.write(message)
                with open(self.log_file, "a", encoding="utf-8") as f:
                    f.write(message)

            def flush(self):
                # this flush method is needed for python 3 compatibility.
                # this handles the flush command by doing nothing.
                # you might want to specify some extra behavior here.
                pass

        # don't let processes with rank > 0 write to the terminal
        sys.stdout = Logger(self.args.rank == 0)

    @staticmethod
    def _is_apex_available() -> bool:
        """Check if Nvidia's APEX is available."""
        return importlib.util.find_spec("apex") is not None
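A runnable miniature of the `Logger` tee set up by `_setup_logger_config`: anything written to stdout is mirrored to a file (the rank-0 gating is omitted here).

import sys

class Tee:
    """Mirror stdout writes to a file, like the Logger class above."""
    def __init__(self, path):
        self.terminal = sys.stdout
        self.path = path
    def write(self, message):
        self.terminal.write(message)
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(message)
    def flush(self):
        self.terminal.flush()

sys.stdout = Tee("trainer_0_log.txt")
print("hello")                    # shows on screen and lands in the file
sys.stdout = sys.stdout.terminal  # restore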
TTS/tts/models/base_tts.py
@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.model import BaseModel
from TTS.tts.datasets import TTSDataset
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
@@ -32,6 +32,30 @@ class BaseTTS(BaseModel):
        - 1D tensors `batch x 1`
    """

    def _set_model_args(self, config: Coqpit):
        """Set up model args based on the config type.

        If the config is for training, with a name like "*Config", then the model args are embedded in
        `config.model_args`.

        If the config is for the model, with a name like "*Args", then we assign it directly.
        """
        # don't use isinstance to avoid a recursive import
        if "Config" in config.__class__.__name__:
            if "characters" in config:
                _, self.config, num_chars = self.get_characters(config)
                self.config.num_chars = num_chars
                if hasattr(self.config, "model_args"):
                    config.model_args.num_chars = num_chars
                    self.args = self.config.model_args
            else:
                self.config = config
                self.args = config.model_args
        elif "Args" in config.__class__.__name__:
            self.args = config
        else:
            raise ValueError("config must be either a *Config or *Args")

    @staticmethod
    def get_characters(config: Coqpit) -> str:
        # TODO: implement CharacterProcessor
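A runnable illustration of the "*Config" vs "*Args" naming convention the new `_set_model_args` dispatches on (the dummy classes below are illustrative, not the real configs):

class TacotronArgs: ...
class TacotronConfig:
    model_args = TacotronArgs()

for cfg in (TacotronConfig(), TacotronArgs()):
    name = cfg.__class__.__name__
    if "Config" in name:
        args = cfg.model_args  # training config: args are nested inside it
    elif "Args" in name:
        args = cfg             # model config: it *is* the args
    print(name, "->", type(args).__name__)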
@@ -169,7 +193,7 @@ class BaseTTS(BaseModel):
    def get_data_loader(
        self,
        config: Coqpit,
        ap: AudioProcessor,
        assets: Dict,
        is_eval: bool,
        data_items: List,
        verbose: bool,
@@ -179,6 +203,8 @@ class BaseTTS(BaseModel):
        if is_eval and not config.run_eval:
            loader = None
        else:
            ap = assets["audio_processor"]

            # setup multi-speaker attributes
            if hasattr(self, "speaker_manager"):
                speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
@@ -280,14 +306,18 @@ class BaseTTS(BaseModel):
        )
        return loader

    def test_run(self, ap) -> Tuple[Dict, Dict]:
    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
        """Generic test run for `tts` models used by `Trainer`.

        You can override this for a different behaviour.

        Args:
            assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`.

        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        ap = assets["audio_processor"]
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
TTS/utils/trainer_utils.py
@@ -1,8 +1,13 @@
import importlib
import os
import re
from typing import Dict, List, Tuple
from urllib.parse import urlparse

import fsspec
import torch

from TTS.utils.io import load_fsspec
from TTS.utils.training import NoamLR
@@ -80,3 +85,66 @@ def get_optimizer(
    if model is not None:
        parameters = model.parameters()
    return optimizer(parameters, lr=lr, **optimizer_params)


def get_last_checkpoint(path: str) -> Tuple[str, str]:
    """Get the latest checkpoint and/or best model in path.

    It is based on globbing for `*.pth.tar` and the RegEx
    `(checkpoint|best_model)_([0-9]+)`.

    Args:
        path: Path to files to be compared.

    Raises:
        ValueError: If no checkpoint or best_model files are found.

    Returns:
        Path to the last checkpoint
        Path to best checkpoint
    """
    fs = fsspec.get_mapper(path).fs
    file_names = fs.glob(os.path.join(path, "*.pth.tar"))
    scheme = urlparse(path).scheme
    if scheme:  # scheme is not preserved in fs.glob, add it back
        file_names = [scheme + "://" + file_name for file_name in file_names]
    last_models = {}
    last_model_nums = {}
    for key in ["checkpoint", "best_model"]:
        last_model_num = None
        last_model = None
        # pass over all the checkpoint files and find
        # the one with the largest model number suffix.
        for file_name in file_names:
            match = re.search(f"{key}_([0-9]+)", file_name)
            if match is not None:
                model_num = int(match.groups()[0])
                if last_model_num is None or model_num > last_model_num:
                    last_model_num = model_num
                    last_model = file_name

        # if there is no checkpoint found above
        # find the checkpoint with the latest
        # modification date.
        key_file_names = [fn for fn in file_names if key in fn]
        if last_model is None and len(key_file_names) > 0:
            last_model = max(key_file_names, key=os.path.getctime)
            last_model_num = load_fsspec(last_model)["step"]

        if last_model is not None:
            last_models[key] = last_model
            last_model_nums[key] = last_model_num

    # check what models were found
    if not last_models:
        raise ValueError(f"No models found in continue path {path}!")
    if "checkpoint" not in last_models:  # no checkpoint, just the best model
        last_models["checkpoint"] = last_models["best_model"]
    elif "best_model" not in last_models:  # no best model
        # this shouldn't happen, but let's handle it just in case
        last_models["best_model"] = last_models["checkpoint"]
    # finally, check if the last best model is more recent than the checkpoint
    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
        last_models["checkpoint"] = last_models["best_model"]

    return last_models["checkpoint"], last_models["best_model"]
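A runnable demo of the filename convention `get_last_checkpoint` globs for: the RegEx `(checkpoint|best_model)_([0-9]+)` picks the largest step suffix per kind (requires Python 3.8+ for the walrus operator).

import re

file_names = ["run/checkpoint_1000.pth.tar", "run/checkpoint_2500.pth.tar", "run/best_model_1500.pth.tar"]
steps = {}
for key in ["checkpoint", "best_model"]:
    nums = [int(m.group(1)) for f in file_names if (m := re.search(f"{key}_([0-9]+)", f))]
    steps[key] = max(nums)
print(steps)  # {'checkpoint': 2500, 'best_model': 1500}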
TTS/vocoder/models/base_vocoder.py
@@ -1,3 +1,5 @@
from coqpit import Coqpit

from TTS.model import BaseModel

# pylint: skip-file
@@ -16,5 +18,35 @@ class BaseVocoder(BaseModel):
        - 1D tensors `batch x 1`
    """

    def __init__(self):
        super().__init__()
    def __init__(self, config):
        super().__init__(config)

    def _set_model_args(self, config: Coqpit):
        """Set up model args based on the config type.

        If the config is for training, with a name like "*Config", then the model args are embedded in
        `config.model_args`.

        If the config is for the model, with a name like "*Args", then we assign it directly.
        """
        # don't use isinstance to avoid a recursive import
        if "Config" in config.__class__.__name__:
            if "characters" in config:
                _, self.config, num_chars = self.get_characters(config)
                self.config.num_chars = num_chars
                if hasattr(self.config, "model_args"):
                    config.model_args.num_chars = num_chars
                if "model_args" in config:
                    self.args = self.config.model_args
                # This is for backward compatibility
                if "model_params" in config:
                    self.args = self.config.model_params
            else:
                self.config = config
                if "model_args" in config:
                    self.args = self.config.model_args
                # This is for backward compatibility
                if "model_params" in config:
                    self.args = self.config.model_params
        elif "Args" in config.__class__.__name__:
            self.args = config
        else:
            raise ValueError("config must be either a *Config or *Args")