Refactor `trainer.py` for v2

Eren Gölge 2021-09-30 14:16:34 +00:00
parent 7f388f26e3
commit 8ada870a57
5 changed files with 388 additions and 291 deletions

View File

@ -4,16 +4,14 @@ import importlib
import multiprocessing
import os
import platform
import re
import sys
import time
import traceback
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
from urllib.parse import urlparse
from inspect import signature
from typing import Callable, Dict, List, Tuple, Union
import fsspec
import torch
import torch.distributed as dist
from coqpit import Coqpit
@ -21,11 +19,7 @@ from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from TTS.config import load_config, register_config
from TTS.tts.datasets import load_meta_data
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.audio import AudioProcessor
from TTS.stt.datasets.tokenizer import Tokenizer
from TTS.utils.callbacks import TrainerCallback
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (
@ -39,9 +33,13 @@ from TTS.utils.generic_utils import (
)
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.utils.trainer_utils import (
get_last_checkpoint,
get_optimizer,
get_scheduler,
is_apex_available,
setup_torch_training_env,
)
multiprocessing.set_start_method("fork")
@ -80,6 +78,9 @@ class TrainingArgs(Coqpit):
"help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
},
)
skip_train_epoch: bool = field(
default=False, metadata={"help": "Run only evaluation iteration. Useful for debugging."}
)
config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
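The new `skip_train_epoch` flag is checked in `_fit()` further below; a minimal usage sketch, assuming the module path `TTS.trainer` for this file:
from TTS.trainer import TrainingArgs
# Skip training and run only the evaluation/test loops, e.g. to debug eval code.
args = TrainingArgs(skip_train_epoch=True)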
@ -98,7 +99,14 @@ class Trainer:
c_logger: ConsoleLogger = None,
dashboard_logger: Union[TensorboardLogger, WandbLogger] = None,
model: nn.Module = None,
get_model: Callable = None,
get_data_samples: Callable = None,
train_samples: List = None,
eval_samples: List = None,
tokenizer: Tokenizer = None,
cudnn_benchmark: bool = False,
training_assets: Dict = {},
parse_command_line_args: bool = True,
) -> None:
"""Simple yet powerful 🐸💬 TTS trainer for PyTorch. It can train all the available `tts` and `vocoder` models
or easily be customized.
@ -127,24 +135,44 @@ class Trainer:
model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
initializes a model from the provided config. Defaults to None.
get_model (Callable):
A function that returns a model. It is used to initialize the model when `model` is not provided.
It either takes the config as the only argument or does not take any argument.
Defaults to None
get_data_samples (Callable):
A function that returns a list of training and evaluation samples. Used if `train_samples` and
`eval_samples` are None. Defaults to None.
train_samples (List):
A list of training samples used by the model's `get_data_loader` to init the `dataset` and the
`data_loader`. Defaults to None.
eval_samples (List):
A list of evaluation samples used by the model's `get_data_loader` to init the `dataset` and the
`data_loader`. Defaults to None.
cudnn_benchmark (bool): enable/disable PyTorch cudnn benchmarking. It is better to disable it if the model input
length changes from batch to batch during training.
training_assets (Dict):
A dictionary of assets used during training and passed to the model's ```train_log(), eval_log(), get_data_loader()```.
It can include an `AudioProcessor` and/or a `Tokenizer`. Defaults to {}.
parse_command_line_args (bool):
If true, parse command-line arguments and update `TrainingArgs` and model `config` values. Set it
to false if you parse the arguments yourself. Defaults to True.
Examples:
Running trainer on a model.
Running trainer with a HifiGAN model.
>>> args = TrainingArgs(...)
>>> config = HifiganConfig(...)
>>> model = GANModel(config)
>>> trainer = Trainer(args, config, output_path, model=model)
>>> trainer.fit()
Running trainer on a config.
>>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,)
>>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
>>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
>>> ap = AudioProcessor(**config.audio)
>>> assets = {"audio_processor": ap}
>>> trainer = Trainer(args, config, output_path, model=model, training_assets=assets)
>>> trainer.fit()
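Running trainer with callables that build the model and the data samples. This is a minimal
sketch; `MyModel`, `my_train_samples` and `my_eval_samples` are placeholders, not 🐸TTS APIs.
>>> def get_model(config): return MyModel(config)
>>> def load_my_samples(): return my_train_samples, my_eval_samples
>>> trainer = Trainer(args, config, output_path, get_model=get_model, get_data_samples=load_my_samples)
>>> trainer.fit()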
TODO:
@ -154,20 +182,33 @@ class Trainer:
- Profiler integration.
- Overfitting to a batch.
- TPU training
- NOTE: Consider moving `training_assets` to the model implementation.
"""
if parse_command_line_args:
# parse command-line arguments for TrainingArgs()
args, coqpit_overrides = self.parse_argv(args)
if config is None:
# parse config from console arguments
config, output_path, _, c_logger, dashboard_logger = process_args(args)
# get ready for training and parse command-line arguments for the model config
config = self.init_training(args, coqpit_overrides, config)
# define the experiment path and create the folder
output_path = get_experiment_folder_path(config.output_path, config.run_name)
os.makedirs(output_path, exist_ok=True)
# copy training assets to the output folder
copy_model_files(config, output_path, new_fields=None)
# init class members
self.args = args
self.config = config
self.output_path = output_path
self.config.output_log_path = output_path
self.training_assets = training_assets
# setup logging
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file)
time.sleep(1.0) # wait for the logger to be ready
# set and initialize Pytorch runtime
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
@ -196,33 +237,30 @@ class Trainer:
self.use_apex = self._is_apex_available()
self.use_amp_scaler = self.config.mixed_precision and self.use_cuda
# init audio processor
self.ap = AudioProcessor(**self.config.audio.to_dict())
# init tokenizer
self.tokenizer = tokenizer
# load data samples
# TODO: refactor this
if "datasets" in self.config:
# load data for `tts` models
self.data_train, self.data_eval = load_meta_data(self.config.datasets)
elif self.config.feature_path is not None:
# load pre-computed features for `vocoder` models
print(f" > Loading features from: {self.config.feature_path}")
self.data_eval, self.data_train = load_wav_feat_data(
self.config.data_path, self.config.feature_path, self.config.eval_split_size
)
if train_samples is None and get_data_samples is None:
raise ValueError("[!] `train_samples` and `get_data_samples` cannot both be None.")
if train_samples is not None:
self.train_samples = train_samples
self.eval_samples = eval_samples
else:
# load data for `vocoder` models
self.data_eval, self.data_train = load_wav_data(self.config.data_path, self.config.eval_split_size)
self.train_samples, self.eval_samples = self.run_get_data_samples(config, get_data_samples)
# init TTS model
if model is None and get_model is None:
raise ValueError("[!] `model` and `get_model` cannot both be None.")
if model is not None:
self.model = model
else:
self.model = self.get_model(self.config)
self.run_get_model(self.config, get_model)
# TODO: out!
# init multispeaker settings of the model
if hasattr(self.model, "init_multispeaker"):
self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
self.model.init_multispeaker(self.config, self.train_samples + self.eval_samples)
# setup criterion
self.criterion = self.get_criterion(self.model)
@ -247,7 +285,7 @@ class Trainer:
# setup optimizer
self.optimizer = self.get_optimizer(self.model, self.config)
# callback
# CALLBACK
self.callbacks = TrainerCallback(self)
self.callbacks.on_init_start()
@ -280,7 +318,7 @@ class Trainer:
else:
self.scheduler.last_epoch = self.restore_step
# DISTRUBUTED
# DISTRIBUTED
if self.num_gpus > 1:
self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
@ -291,8 +329,54 @@ class Trainer:
self.callbacks.on_init_end()
@staticmethod
def get_model(config: Coqpit) -> nn.Module:
"""Initialize model from config.
def parse_argv(args: Union[Coqpit, List]):
"""Parse command line arguments to init or override `TrainingArgs()`."""
if isinstance(args, Coqpit):
parser = args.init_argparse(arg_prefix="")
else:
train_config = TrainingArgs()
parser = train_config.init_argparse(arg_prefix="")
training_args, coqpit_overrides = parser.parse_known_args()
args.parse_args(training_args)
return args, coqpit_overrides
def init_training(self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None):
"""Initialize training and update model configs from command line arguments.
Args:
args (argparse.Namespace or dict like): Parsed input arguments.
coqpit_overrides (argparse.Namespace or dict like): Parsed config overriding arguments.
config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
Returns:
config (Coqpit): Config with updated parameters.
"""
# set arguments for continuing training
if args.continue_path:
experiment_path = args.continue_path
args.config_path = os.path.join(args.continue_path, "config.json")
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
if not args.best_path:
args.best_path = best_model
# override config values from command-line args
# TODO: Maybe it is better to do it outside
if len(coqpit_overrides) > 0:
config.parse_known_args(coqpit_overrides, relaxed_parser=True)
experiment_path = args.continue_path
# update the config.json fields and copy it to the output folder
if args.rank == 0:
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
copy_model_files(config, experiment_path, new_fields)
return config
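As a rough sketch of the continue-path bookkeeping above (the run folder is hypothetical, and the module path `TTS.trainer` is assumed):
from TTS.trainer import TrainingArgs
# Hypothetical previous run folder containing config.json and *.pth.tar checkpoints.
args = TrainingArgs(continue_path="/tmp/run-September-30-2021")
# init_training() then derives:
#   args.config_path  -> "/tmp/run-September-30-2021/config.json"
#   args.restore_path -> the latest checkpoint found by get_last_checkpoint()
#   args.best_path    -> the latest best model, unless it was set explicitly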
@staticmethod
def run_get_model(config: Coqpit, get_model: Callable) -> nn.Module:
"""Run the `get_model` function and return the model.
Args:
config (Coqpit): Model config.
@ -300,12 +384,23 @@ class Trainer:
Returns:
nn.Module: initialized model.
"""
try:
model = setup_vocoder_model(config)
except ModuleNotFoundError:
model = setup_tts_model(config)
if len(signature(get_model).parameters) == 1:
model = get_model(config)
else:
model = get_model()
return model
@staticmethod
def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> Tuple[List, List]:
if isinstance(get_data_samples, Callable):
if len(signature(get_data_samples).parameters) == 1:
train_samples, eval_samples = get_data_samples(config)
else:
train_samples, eval_samples = get_data_samples()
return train_samples, eval_samples
else:
return None, None
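Both helpers dispatch on the callable's arity, so either factory style below is accepted. A sketch; `MyModel` is a placeholder `nn.Module`, not part of this diff.
from coqpit import Coqpit
from torch import nn

class MyModel(nn.Module):  # placeholder model used only for illustration
    def __init__(self, config: Coqpit = None):
        super().__init__()
        self.config = config

def get_model_with_config(config: Coqpit) -> nn.Module:
    # one parameter -> Trainer calls it as get_model(config)
    return MyModel(config)

def get_model_no_args() -> nn.Module:
    # no parameters -> Trainer calls it as get_model()
    return MyModel()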
def restore_model(
self,
config: Coqpit,
@ -366,11 +461,15 @@ class Trainer:
torch.cuda.empty_cache()
return model, optimizer, scaler, restore_step
#########################
# DATA LOADING FUNCTIONS
#########################
def _get_loader(
self,
model: nn.Module,
config: Coqpit,
ap: AudioProcessor,
assets: Dict,
is_eval: bool,
data_items: List,
verbose: bool,
@ -379,14 +478,14 @@ class Trainer:
if num_gpus > 1:
if hasattr(model.module, "get_data_loader"):
loader = model.module.get_data_loader(
config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
config, assets, is_eval, data_items, verbose, num_gpus, self.args.rank
)
else:
if hasattr(model, "get_data_loader"):
loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
loader = model.get_data_loader(config, assets, is_eval, data_items, verbose, num_gpus)
return loader
def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
def get_train_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader:
"""Initialize and return a training data loader.
Args:
@ -397,10 +496,10 @@ class Trainer:
Returns:
DataLoader: Initialized training data loader.
"""
return self._get_loader(self.model, self.config, ap, False, data_items, verbose, self.num_gpus)
return self._get_loader(self.model, self.config, training_assets, False, data_items, verbose, self.num_gpus)
def get_eval_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
return self._get_loader(self.model, self.config, ap, True, data_items, verbose, self.num_gpus)
def get_eval_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader:
return self._get_loader(self.model, self.config, training_assets, True, data_items, verbose, self.num_gpus)
def format_batch(self, batch: List) -> Dict:
"""Format the dataloader output and return a batch.
@ -420,6 +519,10 @@ class Trainer:
batch[k] = to_cuda(v)
return batch
######################
# TRAIN FUNCTIONS
######################
@staticmethod
def master_params(optimizer: torch.optim.Optimizer):
"""Generator over parameters owned by the optimizer.
@ -567,24 +670,6 @@ class Trainer:
loss_dict["grad_norm"] = grad_norm
return outputs, loss_dict, step_time
@staticmethod
def _detach_loss_dict(loss_dict: Dict) -> Dict:
"""Detach loss values from autograp.
Args:
loss_dict (Dict): losses.
Returns:
Dict: losses detached from autograd.
"""
loss_dict_detached = {}
for key, value in loss_dict.items():
if isinstance(value, (int, float)):
loss_dict_detached[key] = value
else:
loss_dict_detached[key] = value.item()
return loss_dict_detached
def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
"""Perform a training step on a batch of inputs and log the process.
@ -700,15 +785,14 @@ class Trainer:
self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)
# training visualizations
figures, audios = None, None
if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"):
figures, audios = self.model.module.train_log(self.ap, batch, outputs)
self.model.module.train_log(
batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
)
elif hasattr(self.model, "train_log"):
figures, audios = self.model.train_log(self.ap, batch, outputs)
if figures is not None:
self.dashboard_logger.train_figures(self.total_steps_done, figures)
if audios is not None:
self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.model.train_log(
batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
)
self.dashboard_logger.flush()
@ -718,11 +802,13 @@ class Trainer:
def train_epoch(self) -> None:
"""Main entry point for the training loop. Run training on the all training samples."""
# initialize the data loader
self.train_loader = self.get_train_dataloader(
self.ap,
self.data_train,
self.training_assets,
self.train_samples,
verbose=True,
)
# set model to training mode
if self.num_gpus > 1:
self.model.module.train()
else:
@ -734,11 +820,12 @@ class Trainer:
batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
self.c_logger.print_train_start()
loader_start_time = time.time()
# iterate over the training samples
for cur_step, batch in enumerate(self.train_loader):
_, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
loader_start_time = time.time()
epoch_time = time.time() - epoch_start_time
# Plot self.epochs_done Stats
# plot self.epochs_done Stats
if self.args.rank == 0:
epoch_stats = {"epoch_time": epoch_time}
epoch_stats.update(self.keep_avg_train.avg_values)
@ -754,6 +841,10 @@ class Trainer:
else:
self.scheduler.step()
#######################
# EVAL FUNCTIONS
#######################
@staticmethod
def _model_eval_step(
batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
@ -803,7 +894,7 @@ class Trainer:
loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss")
loss_dict.update(loss_dict_new)
loss_dict = self._detach_loss_dict(loss_dict)
loss_dict = self._detach_loss_dict(loss_dict)
# update avg stats
update_eval_values = {}
@ -819,8 +910,8 @@ class Trainer:
"""Main entry point for the evaluation loop. Run evaluation on the all validation samples."""
self.eval_loader = (
self.get_eval_dataloader(
self.ap,
self.data_eval,
self.training_assets,
self.eval_samples,
verbose=True,
)
if self.config.run_eval
@ -840,15 +931,12 @@ class Trainer:
loader_start_time = time.time()
# plot epoch stats, artifacts and figures
if self.args.rank == 0:
figures, audios = None, None
if hasattr(self.model, "module") and hasattr(self.model.module, "eval_log"):
figures, audios = self.model.module.eval_log(self.ap, batch, outputs)
self.model.module.eval_log(
batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
)
elif hasattr(self.model, "eval_log"):
figures, audios = self.model.eval_log(self.ap, batch, outputs)
if figures is not None:
self.dashboard_logger.eval_figures(self.total_steps_done, figures)
if audios is not None:
self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.model.eval_log(batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done)
self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
def test_run(self) -> None:
@ -857,22 +945,22 @@ class Trainer:
if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
if self.eval_loader is None:
self.eval_loader = self.get_eval_dataloader(
self.ap,
self.data_eval,
self.training_assets,
self.eval_samples,
verbose=True,
)
if hasattr(self.eval_loader.dataset, "load_test_samples"):
samples = self.eval_loader.dataset.load_test_samples(1)
if self.num_gpus > 1:
figures, audios = self.model.module.test_run(self.ap, samples, None)
figures, audios = self.model.module.test_run(self.training_assets, samples, None)
else:
figures, audios = self.model.test_run(self.ap, samples, None)
figures, audios = self.model.test_run(self.training_assets, samples, None)
else:
if self.num_gpus > 1:
figures, audios = self.model.module.test_run(self.ap)
figures, audios = self.model.module.test_run(self.training_assets)
else:
figures, audios = self.model.test_run(self.ap)
figures, audios = self.model.test_run(self.training_assets)
self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
self.dashboard_logger.test_figures(self.total_steps_done, figures)
@ -886,6 +974,10 @@ class Trainer:
self.best_loss = ch["model_loss"]
print(f" > Starting with loaded last best loss {self.best_loss}.")
###################################
# FIT FUNCTIONS
###################################
def _fit(self) -> None:
"""🏃 train -> evaluate -> test for the number of epochs."""
self._restore_best_loss()
@ -901,7 +993,8 @@ class Trainer:
self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
self.epochs_done = epoch
self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
self.train_epoch()
if not self.args.skip_train_epoch:
self.train_epoch()
if self.config.run_eval:
self.eval_epoch()
if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
@ -939,24 +1032,6 @@ class Trainer:
traceback.print_exc()
sys.exit(1)
def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
"""Pick the target loss to compare models"""
target_avg_loss = None
# return if target loss defined in the model config
if "target_loss" in self.config and self.config.target_loss:
return keep_avg_target[f"avg_{self.config.target_loss}"]
# take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
if isinstance(self.optimizer, list):
target_avg_loss = 0
for idx in range(len(self.optimizer)):
target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
target_avg_loss /= len(self.optimizer)
else:
target_avg_loss = keep_avg_target["avg_loss"]
return target_avg_loss
def save_best_model(self) -> None:
"""Save the best model. It only saves if the current target loss is smaller then the previous."""
@ -978,35 +1053,9 @@ class Trainer:
keep_after=self.config.keep_after,
)
def _setup_logger_config(self, log_file: str) -> None:
"""Write log strings to a file and print logs to the terminal.
TODO: Causes formatting issues in pdb debugging."""
class Logger(object):
def __init__(self, print_to_terminal=True):
self.print_to_terminal = print_to_terminal
self.terminal = sys.stdout
self.log_file = log_file
def write(self, message):
if self.print_to_terminal:
self.terminal.write(message)
with open(self.log_file, "a", encoding="utf-8") as f:
f.write(message)
def flush(self):
# this flush method is needed for python 3 compatibility.
# this handles the flush command by doing nothing.
# you might want to specify some extra behavior here.
pass
# don't let processes rank > 0 write to the terminal
sys.stdout = Logger(self.args.rank == 0)
@staticmethod
def _is_apex_available() -> bool:
"""Check if Nvidia's APEX is available."""
return importlib.util.find_spec("apex") is not None
#####################
# GET FUNCTIONS
#####################
@staticmethod
def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]:
@ -1084,154 +1133,72 @@ class Trainer:
criterion = model.get_criterion()
return criterion
####################
# HELPER FUNCTIONS
####################
def getarguments():
train_config = TrainingArgs()
parser = train_config.init_argparse(arg_prefix="")
return parser
@staticmethod
def _detach_loss_dict(loss_dict: Dict) -> Dict:
"""Detach loss values from autograp.
Args:
loss_dict (Dict): losses.
def get_last_checkpoint(path: str) -> Tuple[str, str]:
"""Get latest checkpoint or/and best model in path.
Returns:
Dict: losses detached from autograd.
"""
loss_dict_detached = {}
for key, value in loss_dict.items():
if isinstance(value, (int, float)):
loss_dict_detached[key] = value
else:
loss_dict_detached[key] = value.detach()
return loss_dict_detached
It is based on globbing for `*.pth.tar` and the RegEx
`(checkpoint|best_model)_([0-9]+)`.
def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
"""Pick the target loss to compare models"""
target_avg_loss = None
Args:
path: Path to files to be compared.
# return if target loss defined in the model config
if "target_loss" in self.config and self.config.target_loss:
return keep_avg_target[f"avg_{self.config.target_loss}"]
Raises:
ValueError: If no checkpoint or best_model files are found.
Returns:
Path to the last checkpoint
Path to best checkpoint
"""
fs = fsspec.get_mapper(path).fs
file_names = fs.glob(os.path.join(path, "*.pth.tar"))
scheme = urlparse(path).scheme
if scheme: # scheme is not preserved in fs.glob, add it back
file_names = [scheme + "://" + file_name for file_name in file_names]
last_models = {}
last_model_nums = {}
for key in ["checkpoint", "best_model"]:
last_model_num = None
last_model = None
# go through all the checkpoint files and find
# the one with the largest model number suffix.
for file_name in file_names:
match = re.search(f"{key}_([0-9]+)", file_name)
if match is not None:
model_num = int(match.groups()[0])
if last_model_num is None or model_num > last_model_num:
last_model_num = model_num
last_model = file_name
# if there is no checkpoint found above
# find the checkpoint with the latest
# modification date.
key_file_names = [fn for fn in file_names if key in fn]
if last_model is None and len(key_file_names) > 0:
last_model = max(key_file_names, key=os.path.getctime)
last_model_num = load_fsspec(last_model)["step"]
if last_model is not None:
last_models[key] = last_model
last_model_nums[key] = last_model_num
# check what models were found
if not last_models:
raise ValueError(f"No models found in continue path {path}!")
if "checkpoint" not in last_models: # no checkpoint just best model
last_models["checkpoint"] = last_models["best_model"]
elif "best_model" not in last_models: # no best model
# this shouldn't happen, but let's handle it just in case
last_models["best_model"] = last_models["checkpoint"]
# finally check if last best model is more recent than checkpoint
elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
last_models["checkpoint"] = last_models["best_model"]
return last_models["checkpoint"], last_models["best_model"]
def process_args(args, config=None):
"""Process parsed comand line arguments and initialize the config if not provided.
Args:
args (argparse.Namespace or dict like): Parsed input arguments.
config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
Returns:
c (TTS.utils.io.AttrDict): Config parameters.
out_path (str): Path to save models and logging.
audio_path (str): Path to save generated test audios.
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
logging to the console.
dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging
TODO:
- Interactive config definition.
"""
if isinstance(args, tuple):
args, coqpit_overrides = args
if args.continue_path:
# continue a previous training from its output folder
experiment_path = args.continue_path
args.config_path = os.path.join(args.continue_path, "config.json")
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
if not args.best_path:
args.best_path = best_model
# init config if not already defined
if config is None:
if args.config_path:
# init from a file
config = load_config(args.config_path)
# take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
if isinstance(self.optimizer, list):
target_avg_loss = 0
for idx in range(len(self.optimizer)):
target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
target_avg_loss /= len(self.optimizer)
else:
# init from console args
from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel
target_avg_loss = keep_avg_target["avg_loss"]
return target_avg_loss
config_base = BaseTrainingConfig()
config_base.parse_known_args(coqpit_overrides)
config = register_config(config_base.model)()
# override values from command-line args
config.parse_known_args(coqpit_overrides, relaxed_parser=True)
experiment_path = args.continue_path
if not experiment_path:
experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
audio_path = os.path.join(experiment_path, "test_audios")
config.output_log_path = experiment_path
# setup rank 0 process in distributed training
dashboard_logger = None
if args.rank == 0:
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
# if model characters are not set in the config file
# save the default set to the config file for future
# compatibility.
if config.has("characters") and config.characters is None:
used_characters = parse_symbols()
new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields)
dashboard_logger = init_dashboard_logger(config)
c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, dashboard_logger
def _setup_logger_config(self, log_file: str) -> None:
"""Write log strings to a file and print logs to the terminal.
TODO: Causes formatting issues in pdb debugging."""
class Logger(object):
def __init__(self, print_to_terminal=True):
self.print_to_terminal = print_to_terminal
self.terminal = sys.stdout
self.log_file = log_file
def init_arguments():
train_config = TrainingArgs()
parser = train_config.init_argparse(arg_prefix="")
return parser
def write(self, message):
if self.print_to_terminal:
self.terminal.write(message)
with open(self.log_file, "a", encoding="utf-8") as f:
f.write(message)
def flush(self):
# this flush method is needed for python 3 compatibility.
# this handles the flush command by doing nothing.
# you might want to specify some extra behavior here.
pass
def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
"""Initialization of a training run."""
if isinstance(argv, Coqpit):
parser = argv.init_argparse(arg_prefix="")
else:
parser = init_arguments()
args = parser.parse_known_args()
config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger
# don't let processes rank > 0 write to the terminal
sys.stdout = Logger(self.args.rank == 0)
@staticmethod
def _is_apex_available() -> bool:
"""Check if Nvidia's APEX is available."""
return importlib.util.find_spec("apex") is not None

View File

@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.tts.datasets import TTSDataset
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
@ -32,6 +32,30 @@ class BaseTTS(BaseModel):
- 1D tensors `batch x 1`
"""
def _set_model_args(self, config: Coqpit):
"""Setup model args based on the config type.
If the config is for training with a name like "*Config", then the model args are embeded in the
config.model_args
If the config is for the model with a name like "*Args", then we assign the directly.
"""
# don't use isintance not to import recursively
if "Config" in config.__class__.__name__:
if "characters" in config:
_, self.config, num_chars = self.get_characters(config)
self.config.num_chars = num_chars
if hasattr(self.config, "model_args"):
config.model_args.num_chars = num_chars
self.args = self.config.model_args
else:
self.config = config
self.args = config.model_args
elif "Args" in config.__class__.__name__:
self.args = config
else:
raise ValueError("config must be either a *Config or *Args")
@staticmethod
def get_characters(config: Coqpit) -> str:
# TODO: implement CharacterProcessor
@ -169,7 +193,7 @@ class BaseTTS(BaseModel):
def get_data_loader(
self,
config: Coqpit,
ap: AudioProcessor,
assets: Dict,
is_eval: bool,
data_items: List,
verbose: bool,
@ -179,6 +203,8 @@ class BaseTTS(BaseModel):
if is_eval and not config.run_eval:
loader = None
else:
ap = assets["audio_processor"]
# setup multi-speaker attributes
if hasattr(self, "speaker_manager"):
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
@ -280,14 +306,18 @@ class BaseTTS(BaseModel):
)
return loader
def test_run(self, ap) -> Tuple[Dict, Dict]:
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
Args:
assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`.
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
ap = assets["audio_processor"]
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}

View File

@ -1,8 +1,13 @@
import importlib
import os
import re
from typing import Dict, List, Tuple
from urllib.parse import urlparse
import fsspec
import torch
from TTS.utils.io import load_fsspec
from TTS.utils.training import NoamLR
@ -80,3 +85,66 @@ def get_optimizer(
if model is not None:
parameters = model.parameters()
return optimizer(parameters, lr=lr, **optimizer_params)
def get_last_checkpoint(path: str) -> Tuple[str, str]:
"""Get latest checkpoint or/and best model in path.
It is based on globbing for `*.pth.tar` and the RegEx
`(checkpoint|best_model)_([0-9]+)`.
Args:
path: Path to files to be compared.
Raises:
ValueError: If no checkpoint or best_model files are found.
Returns:
Path to the last checkpoint
Path to best checkpoint
"""
fs = fsspec.get_mapper(path).fs
file_names = fs.glob(os.path.join(path, "*.pth.tar"))
scheme = urlparse(path).scheme
if scheme: # scheme is not preserved in fs.glob, add it back
file_names = [scheme + "://" + file_name for file_name in file_names]
last_models = {}
last_model_nums = {}
for key in ["checkpoint", "best_model"]:
last_model_num = None
last_model = None
# go through all the checkpoint files and find
# the one with the largest model number suffix.
for file_name in file_names:
match = re.search(f"{key}_([0-9]+)", file_name)
if match is not None:
model_num = int(match.groups()[0])
if last_model_num is None or model_num > last_model_num:
last_model_num = model_num
last_model = file_name
# if there is no checkpoint found above
# find the checkpoint with the latest
# modification date.
key_file_names = [fn for fn in file_names if key in fn]
if last_model is None and len(key_file_names) > 0:
last_model = max(key_file_names, key=os.path.getctime)
last_model_num = load_fsspec(last_model)["step"]
if last_model is not None:
last_models[key] = last_model
last_model_nums[key] = last_model_num
# check what models were found
if not last_models:
raise ValueError(f"No models found in continue path {path}!")
if "checkpoint" not in last_models: # no checkpoint just best model
last_models["checkpoint"] = last_models["best_model"]
elif "best_model" not in last_models: # no best model
# this shouldn't happen, but let's handle it just in case
last_models["best_model"] = last_models["checkpoint"]
# finally check if last best model is more recent than checkpoint
elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
last_models["checkpoint"] = last_models["best_model"]
return last_models["checkpoint"], last_models["best_model"]

View File

@ -1,3 +1,5 @@
from coqpit import Coqpit
from TTS.model import BaseModel
# pylint: skip-file
@ -16,5 +18,35 @@ class BaseVocoder(BaseModel):
- 1D tensors `batch x 1`
"""
def __init__(self):
super().__init__()
def __init__(self, config):
super().__init__(config)
def _set_model_args(self, config: Coqpit):
"""Setup model args based on the config type.
If the config is for training with a name like "*Config", then the model args are embeded in the
config.model_args
If the config is for the model with a name like "*Args", then we assign the directly.
"""
# don't use isintance not to import recursively
if "Config" in config.__class__.__name__:
if "characters" in config:
_, self.config, num_chars = self.get_characters(config)
self.config.num_chars = num_chars
if hasattr(self.config, "model_args"):
config.model_args.num_chars = num_chars
if "model_args" in config:
self.args = self.config.model_args
# This is for backward compatibility
if "model_params" in config:
self.args = self.config.model_params
else:
self.config = config
if "model_args" in config:
self.args = self.config.model_args
# This is for backward compatibility
if "model_params" in config:
self.args = self.config.model_params
else:
raise ValueError("config must be either a *Config or *Args")