mirror of https://github.com/coqui-ai/TTS.git

Refactor `trainer.py` for v2

parent 7f388f26e3
commit 8ada870a57

TTS/trainer.py | 537 changed lines
TTS/trainer.py

@@ -4,16 +4,14 @@ import importlib
 import multiprocessing
 import os
 import platform
-import re
 import sys
 import time
 import traceback
 from argparse import Namespace
 from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Union
-from urllib.parse import urlparse
+from inspect import signature
+from typing import Callable, Dict, List, Tuple, Union
 
-import fsspec
 import torch
 import torch.distributed as dist
 from coqpit import Coqpit
@@ -21,11 +19,7 @@ from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
 
-from TTS.config import load_config, register_config
-from TTS.tts.datasets import load_meta_data
-from TTS.tts.models import setup_model as setup_tts_model
-from TTS.tts.utils.text.symbols import parse_symbols
-from TTS.utils.audio import AudioProcessor
+from TTS.stt.datasets.tokenizer import Tokenizer
 from TTS.utils.callbacks import TrainerCallback
 from TTS.utils.distribute import init_distributed
 from TTS.utils.generic_utils import (
@@ -39,9 +33,13 @@ from TTS.utils.generic_utils import (
 )
 from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
 from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger
-from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
-from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
-from TTS.vocoder.models import setup_model as setup_vocoder_model
+from TTS.utils.trainer_utils import (
+    get_last_checkpoint,
+    get_optimizer,
+    get_scheduler,
+    is_apex_available,
+    setup_torch_training_env,
+)
 
 multiprocessing.set_start_method("fork")
 
@@ -80,6 +78,9 @@ class TrainingArgs(Coqpit):
             "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
         },
     )
+    skip_train_epoch: bool = field(
+        default=False, metadata={"help": "Run only evaluation iteration. Useful for debugging."}
+    )
    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
    group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
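Note: as a Coqpit field, the new flag can be set in code or, through the argparse bridge added in `parse_argv` below, from the command line. A minimal sketch (exact CLI flag syntax depends on Coqpit):

    >>> args = TrainingArgs(skip_train_epoch=True)  # CLI equivalent: --skip_train_epoch true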
@@ -98,7 +99,14 @@ class Trainer:
         c_logger: ConsoleLogger = None,
         dashboard_logger: Union[TensorboardLogger, WandbLogger] = None,
         model: nn.Module = None,
+        get_model: Callable = None,
+        get_data_samples: Callable = None,
+        train_samples: List = None,
+        eval_samples: List = None,
+        tokenizer: Tokenizer = None,
         cudnn_benchmark: bool = False,
+        training_assets: Dict = {},
+        parse_command_line_args: bool = True,
     ) -> None:
         """Simple yet powerful 🐸💬 TTS trainer for PyTorch. It can train all the available `tts` and `vocoder` models
         or easily be customized.
@@ -127,24 +135,44 @@ class Trainer:
             model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer`
                 initializes a model from the provided config. Defaults to None.
 
+            get_model (Callable):
+                A function that returns a model. It is used to initialize the model when `model` is not provided.
+                It either takes the config as the only argument or does not take any argument.
+                Defaults to None.
+
+            get_data_samples (Callable):
+                A function that returns a list of training and evaluation samples. Used if `train_samples` and
+                `eval_samples` are None. Defaults to None.
+
+            train_samples (List):
+                A list of training samples used by the model's `get_data_loader` to init the `dataset` and the
+                `data_loader`. Defaults to None.
+
+            eval_samples (List):
+                A list of evaluation samples used by the model's `get_data_loader` to init the `dataset` and the
+                `data_loader`. Defaults to None.
+
             cudnn_benchmark (bool): enable/disable PyTorch cudnn benchmarking. It is better to disable if the model input
                 length is changing batch to batch along the training.
 
+            training_assets (Dict):
+                A dictionary of assets to be used at training and passed to the model's ```train_log(), eval_log(), get_data_loader()```
+                during training. It can include an `AudioProcessor` and/or a `Tokenizer`. Defaults to {}.
+
+            parse_command_line_args (bool):
+                If true, parse command-line arguments and update `TrainingArgs` and model `config` values. Set it
+                to false if you parse the arguments yourself. Defaults to True.
+
         Examples:
 
-            Running trainer on a model.
+            Running trainer with HifiGAN model.
 
             >>> args = TrainingArgs(...)
             >>> config = HifiganConfig(...)
             >>> model = GANModel(config)
-            >>> trainer = Trainer(args, config, output_path, model=model)
-            >>> trainer.fit()
-
-            Running trainer on a config.
-
-            >>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,)
-            >>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-            >>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
+            >>> ap = AudioProcessor(**config.audio)
+            >>> assets = {"audio_processor": ap}
+            >>> trainer = Trainer(args, config, output_path, model=model, training_assets=assets)
             >>> trainer.fit()
 
         TODO:
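Note: besides the `model=...` path shown in the docstring above, the refactored constructor can also defer model and data creation to callables. A minimal sketch reusing the docstring's `GANModel` and config names; the sample lists are hypothetical placeholders:

    >>> def get_model(config):
    ...     return GANModel(config)
    >>> def get_data_samples(config):
    ...     return my_train_samples, my_eval_samples  # hypothetical sample lists
    >>> trainer = Trainer(args, config, output_path, get_model=get_model, get_data_samples=get_data_samples)
    >>> trainer.fit()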
@@ -154,20 +182,33 @@ class Trainer:
             - Profiler integration.
             - Overfitting to a batch.
             - TPU training
+            - NOTE: Consider moving `training_assets` to the model implementation.
         """
+        if parse_command_line_args:
+            # parse command-line arguments for TrainingArgs()
+            args, coqpit_overrides = self.parse_argv(args)
 
-        if config is None:
-            # parse config from console arguments
-            config, output_path, _, c_logger, dashboard_logger = process_args(args)
+            # get ready for training and parse command-line arguments for the model config
+            config = self.init_training(args, coqpit_overrides, config)
+
+        # define the experiment path and create the folder
+        output_path = get_experiment_folder_path(config.output_path, config.run_name)
+        os.makedirs(output_path, exist_ok=True)
+
+        # copy training assets to the output folder
+        copy_model_files(config, output_path, new_fields=None)
 
+        # init class members
         self.args = args
         self.config = config
         self.output_path = output_path
         self.config.output_log_path = output_path
+        self.training_assets = training_assets
 
         # setup logging
         log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
         self._setup_logger_config(log_file)
+        time.sleep(1.0)  # wait for the logger to be ready
 
         # set and initialize Pytorch runtime
         self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
@@ -196,33 +237,30 @@ class Trainer:
         self.use_apex = self._is_apex_available()
         self.use_amp_scaler = self.config.mixed_precision and self.use_cuda
 
-        # init audio processor
-        self.ap = AudioProcessor(**self.config.audio.to_dict())
+        # init tokenizer
+        self.tokenizer = tokenizer
 
         # load data samples
-        # TODO: refactor this
-        if "datasets" in self.config:
-            # load data for `tts` models
-            self.data_train, self.data_eval = load_meta_data(self.config.datasets)
-        elif self.config.feature_path is not None:
-            # load pre-computed features for `vocoder` models
-            print(f" > Loading features from: {self.config.feature_path}")
-            self.data_eval, self.data_train = load_wav_feat_data(
-                self.config.data_path, self.config.feature_path, self.config.eval_split_size
-            )
+        if train_samples is None and get_data_samples is None:
+            raise ValueError("[!] `train_samples` and `get_data_samples` cannot both be None.")
+        if train_samples is not None:
+            self.train_samples = train_samples
+            self.eval_samples = eval_samples
         else:
-            # load data for `vocoder` models
-            self.data_eval, self.data_train = load_wav_data(self.config.data_path, self.config.eval_split_size)
+            self.train_samples, self.eval_samples = self.run_get_data_samples(config, get_data_samples)
 
         # init TTS model
+        if model is None and get_model is None:
+            raise ValueError("[!] `model` and `get_model` cannot both be None.")
         if model is not None:
             self.model = model
         else:
-            self.model = self.get_model(self.config)
+            self.model = self.run_get_model(self.config, get_model)
 
+        # TODO: out!
         # init multispeaker settings of the model
         if hasattr(self.model, "init_multispeaker"):
-            self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
+            self.model.init_multispeaker(self.config, self.train_samples + self.eval_samples)
 
         # setup criterion
         self.criterion = self.get_criterion(self.model)
@@ -247,7 +285,7 @@ class Trainer:
         # setup optimizer
         self.optimizer = self.get_optimizer(self.model, self.config)
 
-        # callback
+        # CALLBACK
         self.callbacks = TrainerCallback(self)
         self.callbacks.on_init_start()
 
@@ -280,7 +318,7 @@ class Trainer:
             else:
                 self.scheduler.last_epoch = self.restore_step
 
-        # DISTRUBUTED
+        # DISTRIBUTED
         if self.num_gpus > 1:
             self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
 
@@ -291,8 +329,54 @@ class Trainer:
         self.callbacks.on_init_end()
 
     @staticmethod
-    def get_model(config: Coqpit) -> nn.Module:
-        """Initialize model from config.
+    def parse_argv(args: Union[Coqpit, List]):
+        """Parse command line arguments to init or override `TrainingArgs()`."""
+        if isinstance(args, Coqpit):
+            parser = args.init_argparse(arg_prefix="")
+        else:
+            train_config = TrainingArgs()
+            parser = train_config.init_argparse(arg_prefix="")
+        training_args, coqpit_overrides = parser.parse_known_args()
+        args.parse_args(training_args)
+        return args, coqpit_overrides
+
+    def init_training(self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None):
+        """Initialize training and update model configs from command line arguments.
+
+        Args:
+            args (argparse.Namespace or dict like): Parsed input arguments.
+            config_overrides (argparse.Namespace or dict like): Parsed config overriding arguments.
+            config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
+
+        Returns:
+            c (TTS.utils.io.AttrDict): Config parameters.
+        """
+        # set arguments for continuing training
+        if args.continue_path:
+            experiment_path = args.continue_path
+            args.config_path = os.path.join(args.continue_path, "config.json")
+            args.restore_path, best_model = get_last_checkpoint(args.continue_path)
+            if not args.best_path:
+                args.best_path = best_model
+
+        # override config values from command-line args
+        # TODO: Maybe it is better to do it outside
+        if len(coqpit_overrides) > 0:
+            config.parse_known_args(coqpit_overrides, relaxed_parser=True)
+        experiment_path = args.continue_path
+
+        # update the config.json fields and copy it to the output folder
+        if args.rank == 0:
+            new_fields = {}
+            if args.restore_path:
+                new_fields["restore_path"] = args.restore_path
+            new_fields["github_branch"] = get_git_branch()
+            copy_model_files(config, experiment_path, new_fields)
+        return config
+
+    @staticmethod
+    def run_get_model(config: Coqpit, get_model: Callable) -> nn.Module:
+        """Run the `get_model` function and return the model.
 
         Args:
             config (Coqpit): Model config.
@@ -300,12 +384,23 @@ class Trainer:
         Returns:
             nn.Module: initialized model.
         """
-        try:
-            model = setup_vocoder_model(config)
-        except ModuleNotFoundError:
-            model = setup_tts_model(config)
+        if len(signature(get_model).parameters) == 1:
+            model = get_model(config)
+        else:
+            model = get_model()
         return model
 
+    @staticmethod
+    def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> Tuple[List, List]:
+        if isinstance(get_data_samples, Callable):
+            if len(signature(get_data_samples).parameters) == 1:
+                train_samples, eval_samples = get_data_samples(config)
+            else:
+                train_samples, eval_samples = get_data_samples()
+            return train_samples, eval_samples
+        else:
+            return None, None
+
     def restore_model(
         self,
         config: Coqpit,
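Note: both `run_get_model` and `run_get_data_samples` pick the call form from the callable's arity via `inspect.signature`. The same pattern in isolation (plain Python, not TTS code):

    >>> from inspect import signature
    >>> def call_maybe_with_config(fn, config):
    ...     # pass the config only when the callable declares exactly one parameter
    ...     return fn(config) if len(signature(fn).parameters) == 1 else fn()
    >>> call_maybe_with_config(lambda config: config["lr"], {"lr": 0.001})
    0.001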
@@ -366,11 +461,15 @@ class Trainer:
         torch.cuda.empty_cache()
         return model, optimizer, scaler, restore_step
 
+    #########################
+    # DATA LOADING FUNCTIONS
+    #########################
+
     def _get_loader(
         self,
         model: nn.Module,
         config: Coqpit,
-        ap: AudioProcessor,
+        assets: Dict,
         is_eval: bool,
         data_items: List,
         verbose: bool,
@@ -379,14 +478,14 @@ class Trainer:
         if num_gpus > 1:
             if hasattr(model.module, "get_data_loader"):
                 loader = model.module.get_data_loader(
-                    config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
+                    config, assets, is_eval, data_items, verbose, num_gpus, self.args.rank
                 )
         else:
             if hasattr(model, "get_data_loader"):
-                loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
+                loader = model.get_data_loader(config, assets, is_eval, data_items, verbose, num_gpus)
         return loader
 
-    def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
+    def get_train_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader:
         """Initialize and return a training data loader.
 
         Args:
@@ -397,10 +496,10 @@ class Trainer:
         Returns:
             DataLoader: Initialized training data loader.
         """
-        return self._get_loader(self.model, self.config, ap, False, data_items, verbose, self.num_gpus)
+        return self._get_loader(self.model, self.config, training_assets, False, data_items, verbose, self.num_gpus)
 
-    def get_eval_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
-        return self._get_loader(self.model, self.config, ap, True, data_items, verbose, self.num_gpus)
+    def get_eval_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader:
+        return self._get_loader(self.model, self.config, training_assets, True, data_items, verbose, self.num_gpus)
 
     def format_batch(self, batch: List) -> Dict:
         """Format the dataloader output and return a batch.
@@ -420,6 +519,10 @@ class Trainer:
                 batch[k] = to_cuda(v)
         return batch
 
+    ######################
+    # TRAIN FUNCTIONS
+    ######################
+
     @staticmethod
     def master_params(optimizer: torch.optim.Optimizer):
         """Generator over parameters owned by the optimizer.
@@ -567,24 +670,6 @@ class Trainer:
             loss_dict["grad_norm"] = grad_norm
         return outputs, loss_dict, step_time
 
-    @staticmethod
-    def _detach_loss_dict(loss_dict: Dict) -> Dict:
-        """Detach loss values from autograd.
-
-        Args:
-            loss_dict (Dict): losses.
-
-        Returns:
-            Dict: losses detached from autograd.
-        """
-        loss_dict_detached = {}
-        for key, value in loss_dict.items():
-            if isinstance(value, (int, float)):
-                loss_dict_detached[key] = value
-            else:
-                loss_dict_detached[key] = value.item()
-        return loss_dict_detached
-
     def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
         """Perform a training step on a batch of inputs and log the process.
 
@@ -700,15 +785,14 @@ class Trainer:
                 self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)
 
             # training visualizations
-            figures, audios = None, None
             if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"):
-                figures, audios = self.model.module.train_log(self.ap, batch, outputs)
+                self.model.module.train_log(
+                    batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
+                )
             elif hasattr(self.model, "train_log"):
-                figures, audios = self.model.train_log(self.ap, batch, outputs)
-            if figures is not None:
-                self.dashboard_logger.train_figures(self.total_steps_done, figures)
-            if audios is not None:
-                self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)
+                self.model.train_log(
+                    batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
+                )
 
             self.dashboard_logger.flush()
 
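Note: after this change the trainer no longer collects `(figures, audios)` return values; the model logs directly through the dashboard logger. A sketch of the model-side method this call expects (`_create_logs` is a hypothetical helper; the logger methods are the ones the old trainer code called):

    def train_log(self, batch, outputs, logger, assets, steps):
        ap = assets["audio_processor"]
        figures, audios = self._create_logs(batch, outputs, ap)  # hypothetical helper
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, ap.sample_rate)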
@@ -718,11 +802,13 @@ class Trainer:
 
     def train_epoch(self) -> None:
         """Main entry point for the training loop. Run training on the all training samples."""
+        # initialize the data loader
         self.train_loader = self.get_train_dataloader(
-            self.ap,
-            self.data_train,
+            self.training_assets,
+            self.train_samples,
             verbose=True,
         )
+        # set model to training mode
         if self.num_gpus > 1:
             self.model.module.train()
         else:
@@ -734,11 +820,12 @@ class Trainer:
         batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
         self.c_logger.print_train_start()
         loader_start_time = time.time()
+        # iterate over the training samples
         for cur_step, batch in enumerate(self.train_loader):
             _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
             loader_start_time = time.time()
         epoch_time = time.time() - epoch_start_time
-        # Plot self.epochs_done Stats
+        # plot self.epochs_done Stats
         if self.args.rank == 0:
             epoch_stats = {"epoch_time": epoch_time}
             epoch_stats.update(self.keep_avg_train.avg_values)
@@ -754,6 +841,10 @@ class Trainer:
             else:
                 self.scheduler.step()
 
+    #######################
+    # EVAL FUNCTIONS
+    #######################
+
     @staticmethod
     def _model_eval_step(
         batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
@@ -819,8 +910,8 @@ class Trainer:
         """Main entry point for the evaluation loop. Run evaluation on the all validation samples."""
         self.eval_loader = (
             self.get_eval_dataloader(
-                self.ap,
-                self.data_eval,
+                self.training_assets,
+                self.eval_samples,
                 verbose=True,
             )
             if self.config.run_eval
@@ -840,15 +931,12 @@ class Trainer:
             loader_start_time = time.time()
         # plot epoch stats, artifacts and figures
         if self.args.rank == 0:
-            figures, audios = None, None
             if hasattr(self.model, "module") and hasattr(self.model.module, "eval_log"):
-                figures, audios = self.model.module.eval_log(self.ap, batch, outputs)
+                self.model.module.eval_log(
+                    batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done
+                )
             elif hasattr(self.model, "eval_log"):
-                figures, audios = self.model.eval_log(self.ap, batch, outputs)
-            if figures is not None:
-                self.dashboard_logger.eval_figures(self.total_steps_done, figures)
-            if audios is not None:
-                self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
+                self.model.eval_log(batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done)
             self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
 
     def test_run(self) -> None:
|
@ -857,22 +945,22 @@ class Trainer:
|
||||||
if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
|
if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
|
||||||
if self.eval_loader is None:
|
if self.eval_loader is None:
|
||||||
self.eval_loader = self.get_eval_dataloader(
|
self.eval_loader = self.get_eval_dataloader(
|
||||||
self.ap,
|
self.training_assets,
|
||||||
self.data_eval,
|
self.eval_samples,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(self.eval_loader.dataset, "load_test_samples"):
|
if hasattr(self.eval_loader.dataset, "load_test_samples"):
|
||||||
samples = self.eval_loader.dataset.load_test_samples(1)
|
samples = self.eval_loader.dataset.load_test_samples(1)
|
||||||
if self.num_gpus > 1:
|
if self.num_gpus > 1:
|
||||||
figures, audios = self.model.module.test_run(self.ap, samples, None)
|
figures, audios = self.model.module.test_run(self.training_assets, samples, None)
|
||||||
else:
|
else:
|
||||||
figures, audios = self.model.test_run(self.ap, samples, None)
|
figures, audios = self.model.test_run(self.training_assets, samples, None)
|
||||||
else:
|
else:
|
||||||
if self.num_gpus > 1:
|
if self.num_gpus > 1:
|
||||||
figures, audios = self.model.module.test_run(self.ap)
|
figures, audios = self.model.module.test_run(self.training_assets)
|
||||||
else:
|
else:
|
||||||
figures, audios = self.model.test_run(self.ap)
|
figures, audios = self.model.test_run(self.training_assets)
|
||||||
self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
|
self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
|
||||||
self.dashboard_logger.test_figures(self.total_steps_done, figures)
|
self.dashboard_logger.test_figures(self.total_steps_done, figures)
|
||||||
|
|
||||||
|
@@ -886,6 +974,10 @@ class Trainer:
                 self.best_loss = ch["model_loss"]
             print(f" > Starting with loaded last best loss {self.best_loss}.")
 
+    ###################################
+    # FIT FUNCTIONS
+    ###################################
+
     def _fit(self) -> None:
         """🏃 train -> evaluate -> test for the number of epochs."""
         self._restore_best_loss()
@@ -901,6 +993,7 @@ class Trainer:
             self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
             self.epochs_done = epoch
             self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
-            self.train_epoch()
+            if not self.args.skip_train_epoch:
+                self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
@@ -939,24 +1032,6 @@ class Trainer:
             traceback.print_exc()
             sys.exit(1)
 
-    def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
-        """Pick the target loss to compare models"""
-        target_avg_loss = None
-
-        # return if target loss defined in the model config
-        if "target_loss" in self.config and self.config.target_loss:
-            return keep_avg_target[f"avg_{self.config.target_loss}"]
-
-        # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
-        if isinstance(self.optimizer, list):
-            target_avg_loss = 0
-            for idx in range(len(self.optimizer)):
-                target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
-            target_avg_loss /= len(self.optimizer)
-        else:
-            target_avg_loss = keep_avg_target["avg_loss"]
-        return target_avg_loss
-
     def save_best_model(self) -> None:
         """Save the best model. It only saves if the current target loss is smaller than the previous one."""
 
@@ -978,35 +1053,9 @@ class Trainer:
             keep_after=self.config.keep_after,
         )
 
-    def _setup_logger_config(self, log_file: str) -> None:
-        """Write log strings to a file and print logs to the terminal.
-        TODO: Causes formatting issues in pdb debugging."""
+    #####################
+    # GET FUNCTIONS
+    #####################
 
-        class Logger(object):
-            def __init__(self, print_to_terminal=True):
-                self.print_to_terminal = print_to_terminal
-                self.terminal = sys.stdout
-                self.log_file = log_file
-
-            def write(self, message):
-                if self.print_to_terminal:
-                    self.terminal.write(message)
-                with open(self.log_file, "a", encoding="utf-8") as f:
-                    f.write(message)
-
-            def flush(self):
-                # this flush method is needed for python 3 compatibility.
-                # this handles the flush command by doing nothing.
-                # you might want to specify some extra behavior here.
-                pass
-
-        # don't let processes rank > 0 write to the terminal
-        sys.stdout = Logger(self.args.rank == 0)
-
-    @staticmethod
-    def _is_apex_available() -> bool:
-        """Check if Nvidia's APEX is available."""
-        return importlib.util.find_spec("apex") is not None
-
     @staticmethod
     def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]:
@@ -1084,154 +1133,72 @@ class Trainer:
         criterion = model.get_criterion()
         return criterion
 
-
-def getarguments():
-    train_config = TrainingArgs()
-    parser = train_config.init_argparse(arg_prefix="")
-    return parser
-
-
-def get_last_checkpoint(path: str) -> Tuple[str, str]:
-    """Get latest checkpoint or/and best model in path.
-
-    It is based on globbing for `*.pth.tar` and the RegEx
-    `(checkpoint|best_model)_([0-9]+)`.
-
-    Args:
-        path: Path to files to be compared.
-
-    Raises:
-        ValueError: If no checkpoint or best_model files are found.
-
-    Returns:
-        Path to the last checkpoint
-        Path to best checkpoint
-    """
-    fs = fsspec.get_mapper(path).fs
-    file_names = fs.glob(os.path.join(path, "*.pth.tar"))
-    scheme = urlparse(path).scheme
-    if scheme:  # scheme is not preserved in fs.glob, add it back
-        file_names = [scheme + "://" + file_name for file_name in file_names]
-    last_models = {}
-    last_model_nums = {}
-    for key in ["checkpoint", "best_model"]:
-        last_model_num = None
-        last_model = None
-        # pass all the checkpoint files and find
-        # the one with the largest model number suffix.
-        for file_name in file_names:
-            match = re.search(f"{key}_([0-9]+)", file_name)
-            if match is not None:
-                model_num = int(match.groups()[0])
-                if last_model_num is None or model_num > last_model_num:
-                    last_model_num = model_num
-                    last_model = file_name
-
-        # if there is no checkpoint found above
-        # find the checkpoint with the latest
-        # modification date.
-        key_file_names = [fn for fn in file_names if key in fn]
-        if last_model is None and len(key_file_names) > 0:
-            last_model = max(key_file_names, key=os.path.getctime)
-            last_model_num = load_fsspec(last_model)["step"]
-
-        if last_model is not None:
-            last_models[key] = last_model
-            last_model_nums[key] = last_model_num
-
-    # check what models were found
-    if not last_models:
-        raise ValueError(f"No models found in continue path {path}!")
-    if "checkpoint" not in last_models:  # no checkpoint just best model
-        last_models["checkpoint"] = last_models["best_model"]
-    elif "best_model" not in last_models:  # no best model
-        # this shouldn't happen, but let's handle it just in case
-        last_models["best_model"] = last_models["checkpoint"]
-    # finally check if last best model is more recent than checkpoint
-    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
-        last_models["checkpoint"] = last_models["best_model"]
-
-    return last_models["checkpoint"], last_models["best_model"]
-
-
-def process_args(args, config=None):
-    """Process parsed command line arguments and initialize the config if not provided.
-
-    Args:
-        args (argparse.Namespace or dict like): Parsed input arguments.
-        config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
-
-    Returns:
-        c (TTS.utils.io.AttrDict): Config parameters.
-        out_path (str): Path to save models and logging.
-        audio_path (str): Path to save generated test audios.
-        c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
-            logging to the console.
-        dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging
-
-    TODO:
-        - Interactive config definition.
-    """
-    if isinstance(args, tuple):
-        args, coqpit_overrides = args
-    if args.continue_path:
-        # continue a previous training from its output folder
-        experiment_path = args.continue_path
-        args.config_path = os.path.join(args.continue_path, "config.json")
-        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
-        if not args.best_path:
-            args.best_path = best_model
-    # init config if not already defined
-    if config is None:
-        if args.config_path:
-            # init from a file
-            config = load_config(args.config_path)
-        else:
-            # init from console args
-            from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel
-
-            config_base = BaseTrainingConfig()
-            config_base.parse_known_args(coqpit_overrides)
-            config = register_config(config_base.model)()
-    # override values from command-line args
-    config.parse_known_args(coqpit_overrides, relaxed_parser=True)
-    experiment_path = args.continue_path
-    if not experiment_path:
-        experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
-    audio_path = os.path.join(experiment_path, "test_audios")
-    config.output_log_path = experiment_path
-    # setup rank 0 process in distributed training
-    dashboard_logger = None
-    if args.rank == 0:
-        new_fields = {}
-        if args.restore_path:
-            new_fields["restore_path"] = args.restore_path
-        new_fields["github_branch"] = get_git_branch()
-        # if model characters are not set in the config file
-        # save the default set to the config file for future
-        # compatibility.
-        if config.has("characters") and config.characters is None:
-            used_characters = parse_symbols()
-            new_fields["characters"] = used_characters
-        copy_model_files(config, experiment_path, new_fields)
-        dashboard_logger = init_dashboard_logger(config)
-    c_logger = ConsoleLogger()
-    return config, experiment_path, audio_path, c_logger, dashboard_logger
-
-
-def init_arguments():
-    train_config = TrainingArgs()
-    parser = train_config.init_argparse(arg_prefix="")
-    return parser
-
-
-def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
-    """Initialization of a training run."""
-    if isinstance(argv, Coqpit):
-        parser = argv.init_argparse(arg_prefix="")
-    else:
-        parser = init_arguments()
-    args = parser.parse_known_args()
-    config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
-    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger
+    ####################
+    # HELPER FUNCTIONS
+    ####################
+
+    @staticmethod
+    def _detach_loss_dict(loss_dict: Dict) -> Dict:
+        """Detach loss values from autograd.
+
+        Args:
+            loss_dict (Dict): losses.
+
+        Returns:
+            Dict: losses detached from autograd.
+        """
+        loss_dict_detached = {}
+        for key, value in loss_dict.items():
+            if isinstance(value, (int, float)):
+                loss_dict_detached[key] = value
+            else:
+                loss_dict_detached[key] = value.detach()
+        return loss_dict_detached
+
+    def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
+        """Pick the target loss to compare models"""
+        target_avg_loss = None
+
+        # return if target loss defined in the model config
+        if "target_loss" in self.config and self.config.target_loss:
+            return keep_avg_target[f"avg_{self.config.target_loss}"]
+
+        # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
+        if isinstance(self.optimizer, list):
+            target_avg_loss = 0
+            for idx in range(len(self.optimizer)):
+                target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
+            target_avg_loss /= len(self.optimizer)
+        else:
+            target_avg_loss = keep_avg_target["avg_loss"]
+        return target_avg_loss
+
+    def _setup_logger_config(self, log_file: str) -> None:
+        """Write log strings to a file and print logs to the terminal.
+        TODO: Causes formatting issues in pdb debugging."""
+
+        class Logger(object):
+            def __init__(self, print_to_terminal=True):
+                self.print_to_terminal = print_to_terminal
+                self.terminal = sys.stdout
+                self.log_file = log_file
+
+            def write(self, message):
+                if self.print_to_terminal:
+                    self.terminal.write(message)
+                with open(self.log_file, "a", encoding="utf-8") as f:
+                    f.write(message)
+
+            def flush(self):
+                # this flush method is needed for python 3 compatibility.
+                # this handles the flush command by doing nothing.
+                # you might want to specify some extra behavior here.
+                pass
+
+        # don't let processes rank > 0 write to the terminal
+        sys.stdout = Logger(self.args.rank == 0)
+
+    @staticmethod
+    def _is_apex_available() -> bool:
+        """Check if Nvidia's APEX is available."""
+        return importlib.util.find_spec("apex") is not None
TTS/tts/models/base_tts.py

@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from TTS.model import BaseModel
-from TTS.tts.datasets import TTSDataset
+from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text import make_symbols
@@ -32,6 +32,30 @@ class BaseTTS(BaseModel):
         - 1D tensors `batch x 1`
     """
 
+    def _set_model_args(self, config: Coqpit):
+        """Setup model args based on the config type.
+
+        If the config is for training with a name like "*Config", then the model args are embedded in
+        `config.model_args`.
+
+        If the config is for the model with a name like "*Args", then we assign it directly.
+        """
+        # don't use isinstance() here to avoid a recursive import
+        if "Config" in config.__class__.__name__:
+            if "characters" in config:
+                _, self.config, num_chars = self.get_characters(config)
+                self.config.num_chars = num_chars
+                if hasattr(self.config, "model_args"):
+                    config.model_args.num_chars = num_chars
+                self.args = self.config.model_args
+            else:
+                self.config = config
+                self.args = config.model_args
+        elif "Args" in config.__class__.__name__:
+            self.args = config
+        else:
+            raise ValueError("config must be either a *Config or *Args")
+
     @staticmethod
     def get_characters(config: Coqpit) -> str:
         # TODO: implement CharacterProcessor
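Note: `_set_model_args` makes a model constructible from either object. Illustrative calls (the class names are hypothetical):

    >>> model._set_model_args(MyModelConfig(...))  # "*Config": args come from config.model_args
    >>> model._set_model_args(MyModelArgs(...))    # "*Args": assigned directly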
@@ -169,7 +193,7 @@ class BaseTTS(BaseModel):
     def get_data_loader(
         self,
         config: Coqpit,
-        ap: AudioProcessor,
+        assets: Dict,
         is_eval: bool,
         data_items: List,
         verbose: bool,
@@ -179,6 +203,8 @@ class BaseTTS(BaseModel):
         if is_eval and not config.run_eval:
             loader = None
         else:
+            ap = assets["audio_processor"]
+
             # setup multi-speaker attributes
             if hasattr(self, "speaker_manager"):
                 speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
@@ -280,14 +306,18 @@ class BaseTTS(BaseModel):
             )
         return loader
 
-    def test_run(self, ap) -> Tuple[Dict, Dict]:
+    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.
 
         You can override this for a different behaviour.
 
+        Args:
+            assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`.
+
         Returns:
             Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
         """
+        ap = assets["audio_processor"]
         print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
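Note: the `assets` dict here is the `training_assets` given to the `Trainer` constructor, so a `tts` training run is expected to be wired up as in the docstring example at the top of this diff:

    >>> ap = AudioProcessor(**config.audio)
    >>> trainer = Trainer(args, config, output_path, model=model, training_assets={"audio_processor": ap})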
TTS/utils/trainer_utils.py

@@ -1,8 +1,13 @@
 import importlib
+import os
+import re
 from typing import Dict, List, Tuple
+from urllib.parse import urlparse
 
+import fsspec
 import torch
 
+from TTS.utils.io import load_fsspec
 from TTS.utils.training import NoamLR
 
 
@@ -80,3 +85,66 @@ def get_optimizer(
     if model is not None:
         parameters = model.parameters()
     return optimizer(parameters, lr=lr, **optimizer_params)
+
+
+def get_last_checkpoint(path: str) -> Tuple[str, str]:
+    """Get the latest checkpoint and/or best model in path.
+
+    It is based on globbing for `*.pth.tar` and the RegEx
+    `(checkpoint|best_model)_([0-9]+)`.
+
+    Args:
+        path: Path to files to be compared.
+
+    Raises:
+        ValueError: If no checkpoint or best_model files are found.
+
+    Returns:
+        Path to the last checkpoint
+        Path to best checkpoint
+    """
+    fs = fsspec.get_mapper(path).fs
+    file_names = fs.glob(os.path.join(path, "*.pth.tar"))
+    scheme = urlparse(path).scheme
+    if scheme:  # scheme is not preserved in fs.glob, add it back
+        file_names = [scheme + "://" + file_name for file_name in file_names]
+    last_models = {}
+    last_model_nums = {}
+    for key in ["checkpoint", "best_model"]:
+        last_model_num = None
+        last_model = None
+        # pass all the checkpoint files and find
+        # the one with the largest model number suffix.
+        for file_name in file_names:
+            match = re.search(f"{key}_([0-9]+)", file_name)
+            if match is not None:
+                model_num = int(match.groups()[0])
+                if last_model_num is None or model_num > last_model_num:
+                    last_model_num = model_num
+                    last_model = file_name
+
+        # if there is no checkpoint found above
+        # find the checkpoint with the latest
+        # modification date.
+        key_file_names = [fn for fn in file_names if key in fn]
+        if last_model is None and len(key_file_names) > 0:
+            last_model = max(key_file_names, key=os.path.getctime)
+            last_model_num = load_fsspec(last_model)["step"]
+
+        if last_model is not None:
+            last_models[key] = last_model
+            last_model_nums[key] = last_model_num
+
+    # check what models were found
+    if not last_models:
+        raise ValueError(f"No models found in continue path {path}!")
+    if "checkpoint" not in last_models:  # no checkpoint just best model
+        last_models["checkpoint"] = last_models["best_model"]
+    elif "best_model" not in last_models:  # no best model
+        # this shouldn't happen, but let's handle it just in case
+        last_models["best_model"] = last_models["checkpoint"]
+    # finally check if last best model is more recent than checkpoint
+    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
+        last_models["checkpoint"] = last_models["best_model"]
+
+    return last_models["checkpoint"], last_models["best_model"]
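Note: usage sketch for the relocated helper (the run directory path is illustrative):

    >>> from TTS.utils.trainer_utils import get_last_checkpoint
    >>> last_ckpt, best_ckpt = get_last_checkpoint("/path/to/run")  # globs `*.pth.tar` under the given path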
TTS/vocoder/models/base_vocoder.py

@@ -1,3 +1,5 @@
+from coqpit import Coqpit
+
 from TTS.model import BaseModel
 
 # pylint: skip-file
@@ -16,5 +18,35 @@ class BaseVocoder(BaseModel):
         - 1D tensors `batch x 1`
     """
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, config):
+        super().__init__(config)
+
+    def _set_model_args(self, config: Coqpit):
+        """Setup model args based on the config type.
+
+        If the config is for training with a name like "*Config", then the model args are embedded in
+        `config.model_args`.
+
+        If the config is for the model with a name like "*Args", then we assign it directly.
+        """
+        # don't use isinstance() here to avoid a recursive import
+        if "Config" in config.__class__.__name__:
+            if "characters" in config:
+                _, self.config, num_chars = self.get_characters(config)
+                self.config.num_chars = num_chars
+                if hasattr(self.config, "model_args"):
+                    config.model_args.num_chars = num_chars
+                if "model_args" in config:
+                    self.args = self.config.model_args
+                # This is for backward compatibility
+                if "model_params" in config:
+                    self.args = self.config.model_params
+            else:
+                self.config = config
+                if "model_args" in config:
+                    self.args = self.config.model_args
+                # This is for backward compatibility
+                if "model_params" in config:
+                    self.args = self.config.model_params
+        else:
+            raise ValueError("config must be either a *Config or *Args")