diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py
index 3270d0e0..06765906 100644
--- a/TTS/bin/train_tts.py
+++ b/TTS/bin/train_tts.py
@@ -2,7 +2,7 @@
 import os
 import sys
 import traceback
 
-from TTS.trainer import TrainerTTS
+from TTS.tts.trainer_tts import TrainerTTS
 from TTS.utils.arguments import init_training
 from TTS.utils.generic_utils import remove_experiment_folder
diff --git a/TTS/trainer.py b/TTS/trainer.py
index c1d1c340..5c02fdfb 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -1,39 +1,23 @@
 # -*- coding: utf-8 -*-
 
 import importlib
-import logging
-import os
-import time
-from argparse import Namespace
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, TypeVar
 
 import torch
 from coqpit import Coqpit
 
 # DISTRIBUTED
 from torch import nn
-from torch.nn.parallel import DistributedDataParallel as DDP_th
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
 
-from TTS.tts.datasets import TTSDataset, load_meta_data
-from TTS.tts.layers import setup_loss
-from TTS.tts.models import setup_model
-from TTS.tts.utils.io import save_best_model, save_checkpoint
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
-from TTS.tts.utils.synthesis import synthesis
-from TTS.tts.utils.text.symbols import make_symbols
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
-from TTS.utils.distribute import init_distributed
-from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict, to_cuda
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
-from TTS.utils.training import check_update, setup_torch_training_env
+_DataLoader = TypeVar("_DataLoader")
 
 
 @dataclass
 class TrainingArgs(Coqpit):
+    """Trainer arguments that are parsed externally (e.g. CLI)"""
+
     continue_path: str = field(
         default="",
         metadata={
@@ -58,676 +42,100 @@ class TrainingArgs(Coqpit):
 # pylint: disable=import-outside-toplevel, too-many-public-methods
-class TrainerTTS:
-    use_cuda, num_gpus = setup_torch_training_env(True, False)
-
-    def __init__(
-        self,
-        args: Union[Coqpit, Namespace],
-        config: Coqpit,
-        c_logger: ConsoleLogger = None,
-        tb_logger: TensorboardLogger = None,
-        model: nn.Module = None,
-        output_path: str = None,
-    ) -> None:
-        self.args = args
-        self.config = config
-        self.c_logger = ConsoleLogger() if c_logger is None else c_logger
-        if tb_logger is None:
-            self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
-            self.tb_logger.tb_add_text("model-config", f"{config.to_json()}", 0)
-        else:
-            self.tb_logger = tb_logger
-        self.output_path = output_path
-
-        self.total_steps_done = 0
-        self.epochs_done = 0
-        self.restore_step = 0
-        self.best_loss = float("inf")
-        self.train_loader = None
-        self.eval_loader = None
-        self.output_audio_path = os.path.join(output_path, "test_audios")
-
-        self.keep_avg_train = None
-        self.keep_avg_eval = None
-
-        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
-        self._setup_logger_config(log_file)
-
-        # model, audio processor, datasets, loss
-        # init audio processor
-        self.ap = AudioProcessor(**self.config.audio.to_dict())
-
-        # init character processor
-        self.model_characters = self.get_character_processor(self.config)
-
-        # load dataset samples
-        self.data_train, self.data_eval = load_meta_data(self.config.datasets)
-
-        # default speaker manager
-        self.speaker_manager = self.get_speaker_manager(self.config, args.restore_path, output_path, self.data_train)
-
-        # init TTS model
-        if model is not None:
-            self.model = model
-        else:
-            self.model = self.get_model(
-                len(self.model_characters),
-                self.speaker_manager.num_speakers,
-                self.config,
-                self.speaker_manager.d_vector_dim if self.speaker_manager.d_vectors else None,
-            )
-
-        # setup criterion
-        self.criterion = self.get_criterion(self.config)
-
-        # DISTRUBUTED
-        if self.num_gpus > 1:
-            init_distributed(
-                args.rank,
-                self.num_gpus,
-                args.group_id,
-                self.config.distributed_backend,
-                self.config.distributed_url,
-            )
-
-        if self.use_cuda:
-            self.model.cuda()
-            self.criterion.cuda()
-
-        # scalers for mixed precision training
-        self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None
-
-        # setup optimizer
-        self.optimizer = self.get_optimizer(self.model, self.config)
-
-        if self.args.restore_path:
-            self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
-                self.config, args.restore_path, self.model, self.optimizer, self.scaler
-            )
-
-        # setup scheduler
-        self.scheduler = self.get_scheduler(self.config, self.optimizer)
-
-        # DISTRUBUTED
-        if self.num_gpus > 1:
-            self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
-
-        # count model size
-        num_params = count_parameters(self.model)
-        print("\n > Model has {} parameters".format(num_params))
+class TrainerAbstract(ABC):
 
     @staticmethod
-    def get_model(num_chars: int, num_speakers: int, config: Coqpit, d_vector_dim: int) -> nn.Module:
-        model = setup_model(num_chars, num_speakers, config, d_vector_dim)
-        return model
+    def _is_apex_available():
+        return importlib.util.find_spec("apex") is not None
 
     @staticmethod
+    @abstractmethod
+    def get_model(*args, **kwargs) -> nn.Module:
+        pass
+
+    @staticmethod
+    @abstractmethod
     def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
-        optimizer_name = config.optimizer
-        optimizer_params = config.optimizer_params
-        if optimizer_name.lower() == "radam":
-            module = importlib.import_module("TTS.utils.radam")
-            optimizer = getattr(module, "RAdam")
-        else:
-            optimizer = getattr(torch.optim, optimizer_name)
-        return optimizer(model.parameters(), lr=config.lr, **optimizer_params)
-
-    @staticmethod
-    def get_character_processor(config: Coqpit) -> str:
-        # setup custom characters if set in config file.
-        # TODO: implement CharacterProcessor
-        if config.characters is not None:
-            symbols, phonemes = make_symbols(**config.characters.to_dict())
-        else:
-            from TTS.tts.utils.text.symbols import phonemes, symbols
-        model_characters = phonemes if config.use_phonemes else symbols
-        return model_characters
-
-    @staticmethod
-    def get_speaker_manager(
-        config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None
-    ) -> SpeakerManager:
-        speaker_manager = get_speaker_manager(config, restore_path, data_train, out_path)
-        return speaker_manager
+        pass
 
     @staticmethod
+    @abstractmethod
     def get_scheduler(
         config: Coqpit, optimizer: torch.optim.Optimizer
     ) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
-        lr_scheduler = config.lr_scheduler
-        lr_scheduler_params = config.lr_scheduler_params
-        if lr_scheduler is None:
-            return None
-        if lr_scheduler.lower() == "noamlr":
-            from TTS.utils.training import NoamLR
-
-            scheduler = NoamLR
-        else:
-            scheduler = getattr(torch.optim, lr_scheduler)
-        return scheduler(optimizer, **lr_scheduler_params)
+        pass
 
     @staticmethod
+    @abstractmethod
     def get_criterion(config: Coqpit) -> nn.Module:
-        return setup_loss(config)
+        pass
 
-    def restore_model(
-        self,
-        config: Coqpit,
-        restore_path: str,
-        model: nn.Module,
-        optimizer: torch.optim.Optimizer,
-        scaler: torch.cuda.amp.GradScaler = None,
-    ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
-        print(" > Restoring from %s ..." % os.path.basename(restore_path))
-        checkpoint = torch.load(restore_path)
-        try:
-            print(" > Restoring Model...")
-            model.load_state_dict(checkpoint["model"])
-            print(" > Restoring Optimizer...")
-            optimizer.load_state_dict(checkpoint["optimizer"])
-            if "scaler" in checkpoint and config.mixed_precision:
-                print(" > Restoring AMP Scaler...")
-                scaler.load_state_dict(checkpoint["scaler"])
-        except (KeyError, RuntimeError):
-            print(" > Partial model initialization...")
-            model_dict = model.state_dict()
-            model_dict = set_init_dict(model_dict, checkpoint["model"], config)
-            model.load_state_dict(model_dict)
-            del model_dict
+    @abstractmethod
+    def restore_model(self, *args, **kwargs) -> Tuple:
+        pass
 
-        for group in optimizer.param_groups:
-            group["lr"] = self.config.lr
-        print(
-            " > Model restored from step %d" % checkpoint["step"],
-        )
-        restore_step = checkpoint["step"]
-        return model, optimizer, scaler, restore_step
+    @abstractmethod
+    def get_train_dataloader(self, *args, **kwargs) -> _DataLoader:
+        pass
 
-    def _get_loader(
-        self,
-        r: int,
-        ap: AudioProcessor,
-        is_eval: bool,
-        data_items: List,
-        verbose: bool,
-        speaker_ids: Union[Dict, List],
-        d_vectors: Union[Dict, List],
-    ) -> DataLoader:
-        if is_eval and not self.config.run_eval:
-            loader = None
-        else:
-            dataset = TTSDataset(
-                outputs_per_step=r,
-                text_cleaner=self.config.text_cleaner,
-                compute_linear_spec=self.config.model.lower() == "tacotron",
-                meta_data=data_items,
-                ap=ap,
-                tp=self.config.characters,
-                add_blank=self.config["add_blank"],
-                batch_group_size=0 if is_eval else self.config.batch_group_size * self.config.batch_size,
-                min_seq_len=self.config.min_seq_len,
-                max_seq_len=self.config.max_seq_len,
-                phoneme_cache_path=self.config.phoneme_cache_path,
-                use_phonemes=self.config.use_phonemes,
-                phoneme_language=self.config.phoneme_language,
-                enable_eos_bos=self.config.enable_eos_bos_chars,
-                use_noise_augment=not is_eval,
-                verbose=verbose,
-                speaker_id_mapping=speaker_ids if self.config.use_speaker_embedding else None,
-                d_vector_mapping=d_vectors
-                if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
-                else None,
-            )
-
-            if self.config.use_phonemes and self.config.compute_input_seq_cache:
-                # precompute phonemes to have a better estimate of sequence lengths.
-                dataset.compute_input_seq(self.config.num_loader_workers)
-            dataset.sort_items()
-
-            sampler = DistributedSampler(dataset) if self.num_gpus > 1 else None
-            loader = DataLoader(
-                dataset,
-                batch_size=self.config.eval_batch_size if is_eval else self.config.batch_size,
-                shuffle=False,
-                collate_fn=dataset.collate_fn,
-                drop_last=False,
-                sampler=sampler,
-                num_workers=self.config.num_val_loader_workers if is_eval else self.config.num_loader_workers,
-                pin_memory=False,
-            )
-        return loader
-
-    def get_train_dataloader(
-        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
-    ) -> DataLoader:
-        return self._get_loader(r, ap, False, data_items, verbose, speaker_ids, d_vectors)
-
-    def get_eval_dataloder(
-        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
-    ) -> DataLoader:
-        return self._get_loader(r, ap, True, data_items, verbose, speaker_ids, d_vectors)
+    @abstractmethod
+    def get_eval_dataloader(self, *args, **kwargs) -> _DataLoader:
+        pass
 
+    @abstractmethod
     def format_batch(self, batch: List) -> Dict:
-        # setup input batch
-        text_input = batch[0]
-        text_lengths = batch[1]
-        speaker_names = batch[2]
-        linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
-        mel_input = batch[4]
-        mel_lengths = batch[5]
-        stop_targets = batch[6]
-        item_idx = batch[7]
-        d_vectors = batch[8]
-        speaker_ids = batch[9]
-        attn_mask = batch[10]
-        max_text_length = torch.max(text_lengths.float())
-        max_spec_length = torch.max(mel_lengths.float())
-
-        # compute durations from attention masks
-        durations = None
-        if attn_mask is not None:
-            durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
-            for idx, am in enumerate(attn_mask):
-                # compute raw durations
-                c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
-                # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
-                c_idxs, counts = torch.unique(c_idxs, return_counts=True)
-                dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
-                dur[c_idxs] = counts
-                # smooth the durations and set any 0 duration to 1
-                # by cutting off from the largest duration indeces.
-                extra_frames = dur.sum() - mel_lengths[idx]
-                largest_idxs = torch.argsort(-dur)[:extra_frames]
-                dur[largest_idxs] -= 1
-                assert (
-                    dur.sum() == mel_lengths[idx]
-                ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
-                durations[idx, : text_lengths[idx]] = dur
-
-        # set stop targets view, we predict a single stop token per iteration.
-        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
-        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
-
-        # dispatch batch to GPU
-        if self.use_cuda:
-            text_input = to_cuda(text_input)
-            text_lengths = to_cuda(text_lengths)
-            mel_input = to_cuda(mel_input)
-            mel_lengths = to_cuda(mel_lengths)
-            linear_input = to_cuda(linear_input) if self.config.model.lower() in ["tacotron"] else None
-            stop_targets = to_cuda(stop_targets)
-            attn_mask = to_cuda(attn_mask) if attn_mask is not None else None
-            durations = to_cuda(durations) if attn_mask is not None else None
-            if speaker_ids is not None:
-                speaker_ids = to_cuda(speaker_ids)
-            if d_vectors is not None:
-                d_vectors = to_cuda(d_vectors)
-
-        return {
-            "text_input": text_input,
-            "text_lengths": text_lengths,
-            "speaker_names": speaker_names,
-            "mel_input": mel_input,
-            "mel_lengths": mel_lengths,
-            "linear_input": linear_input,
-            "stop_targets": stop_targets,
-            "attn_mask": attn_mask,
-            "durations": durations,
-            "speaker_ids": speaker_ids,
-            "d_vectors": d_vectors,
-            "max_text_length": max_text_length,
-            "max_spec_length": max_spec_length,
-            "item_idx": item_idx,
-        }
+        pass
 
+    @abstractmethod
     def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
-        if hasattr(self.model, "module"):
-            return self.model.module.train_step(batch, criterion)
-        return self.model.train_step(batch, criterion)
+        pass
 
+    @abstractmethod
     def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
-        self.on_train_step_start()
-        step_start_time = time.time()
-
-        # format data
-        batch = self.format_batch(batch)
-        loader_time = time.time() - loader_start_time
-
-        # zero-out optimizer
-        self.optimizer.zero_grad()
-
-        with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
-            outputs, loss_dict = self._train_step(batch, self.criterion)
-
-        # check nan loss
-        if torch.isnan(loss_dict["loss"]).any():
-            raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
-
-        # optimizer step
-        if self.config.mixed_precision:
-            # model optimizer step in mixed precision mode
-            self.scaler.scale(loss_dict["loss"]).backward()
-            self.scaler.unscale_(self.optimizer)
-            grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
-            self.scaler.step(self.optimizer)
-            self.scaler.update()
-        else:
-            # main model optimizer step
-            loss_dict["loss"].backward()
-            grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
-            self.optimizer.step()
-
-        step_time = time.time() - step_start_time
-
-        # setup lr
-        if self.config.lr_scheduler:
-            self.scheduler.step()
-
-        # detach loss values
-        loss_dict_new = dict()
-        for key, value in loss_dict.items():
-            if isinstance(value, (int, float)):
-                loss_dict_new[key] = value
-            else:
-                loss_dict_new[key] = value.item()
-        loss_dict = loss_dict_new
-
-        # update avg stats
-        update_train_values = dict()
-        for key, value in loss_dict.items():
-            update_train_values["avg_" + key] = value
-        update_train_values["avg_loader_time"] = loader_time
-        update_train_values["avg_step_time"] = step_time
-        self.keep_avg_train.update_values(update_train_values)
-
-        # print training progress
-        current_lr = self.optimizer.param_groups[0]["lr"]
-        if self.total_steps_done % self.config.print_step == 0:
-            log_dict = {
-                "max_spec_length": [batch["max_spec_length"], 1],  # value, precision
-                "max_text_length": [batch["max_text_length"], 1],
-                "step_time": [step_time, 4],
-                "loader_time": [loader_time, 2],
-                "current_lr": current_lr,
-            }
-            self.c_logger.print_train_step(
-                batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
-            )
-
-        if self.args.rank == 0:
-            # Plot Training Iter Stats
-            # reduce TB load
-            if self.total_steps_done % self.config.tb_plot_step == 0:
-                iter_stats = {
-                    "lr": current_lr,
-                    "grad_norm": grad_norm,
-                    "step_time": step_time,
-                }
-                iter_stats.update(loss_dict)
-                self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
-
-            if self.total_steps_done % self.config.save_step == 0:
-                if self.config.checkpoint:
-                    # save model
-                    save_checkpoint(
-                        self.model,
-                        self.optimizer,
-                        self.total_steps_done,
-                        self.epochs_done,
-                        self.config.r,
-                        self.output_path,
-                        model_loss=loss_dict["loss"],
-                        characters=self.model_characters,
-                        scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
-                    )
-                # training visualizations
-                if hasattr(self.model, "module"):
-                    figures, audios = self.model.module.train_log(self.ap, batch, outputs)
-                else:
-                    figures, audios = self.model.train_log(self.ap, batch, outputs)
-                self.tb_logger.tb_train_figures(self.total_steps_done, figures)
-                self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
-        self.total_steps_done += 1
-        self.on_train_step_end()
-        return outputs, loss_dict
+        pass
 
+    @abstractmethod
     def train_epoch(self) -> None:
-        self.model.train()
-        epoch_start_time = time.time()
-        if self.use_cuda:
-            batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
-        else:
-            batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
-        self.c_logger.print_train_start()
-        loader_start_time = time.time()
-        for cur_step, batch in enumerate(self.train_loader):
-            _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
-        epoch_time = time.time() - epoch_start_time
-        # Plot self.epochs_done Stats
-        if self.args.rank == 0:
-            epoch_stats = {"epoch_time": epoch_time}
-            epoch_stats.update(self.keep_avg_train.avg_values)
-            self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
-            if self.config.tb_model_param_stats:
-                self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
+        pass
 
+    @abstractmethod
     def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
-        if hasattr(self.model, "module"):
-            return self.model.module.eval_step(batch, self.criterion)
-        return self.model.eval_step(batch, self.criterion)
+        pass
 
+    @abstractmethod
     def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
-        with torch.no_grad():
-            step_start_time = time.time()
-
-            with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
-                outputs, loss_dict = self._eval_step(batch)
-
-            step_time = time.time() - step_start_time
-
-            # detach loss values
-            loss_dict_new = dict()
-            for key, value in loss_dict.items():
-                if isinstance(value, (int, float)):
-                    loss_dict_new[key] = value
-                else:
-                    loss_dict_new[key] = value.item()
-            loss_dict = loss_dict_new
-
-            # update avg stats
-            update_eval_values = dict()
-            for key, value in loss_dict.items():
-                update_eval_values["avg_" + key] = value
-            update_eval_values["avg_step_time"] = step_time
-            self.keep_avg_eval.update_values(update_eval_values)
-
-            if self.config.print_eval:
-                self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values)
-        return outputs, loss_dict
+        pass
 
+    @abstractmethod
     def eval_epoch(self) -> None:
-        self.model.eval()
-        self.c_logger.print_eval_start()
-        loader_start_time = time.time()
-        batch = None
-        for cur_step, batch in enumerate(self.eval_loader):
-            # format data
-            batch = self.format_batch(batch)
-            loader_time = time.time() - loader_start_time
-            self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
-            outputs, _ = self.eval_step(batch, cur_step)
-        # Plot epoch stats and samples from the last batch.
-        if self.args.rank == 0:
-            if hasattr(self.model, "module"):
-                figures, eval_audios = self.model.module.eval_log(self.ap, batch, outputs)
-            else:
-                figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
-            self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
-            self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
+        pass
 
-    def test_run(
-        self,
-    ) -> None:
-        print(" | > Synthesizing test sentences.")
-        test_audios = {}
-        test_figures = {}
-        test_sentences = self.config.test_sentences
-        aux_inputs = self._get_aux_inputs()
-        for idx, sen in enumerate(test_sentences):
-            wav, alignment, model_outputs, _ = synthesis(
-                self.model,
-                sen,
-                self.config,
-                self.use_cuda,
-                self.ap,
-                speaker_id=aux_inputs["speaker_id"],
-                d_vector=aux_inputs["d_vector"],
-                style_wav=aux_inputs["style_wav"],
-                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
-                use_griffin_lim=True,
-                do_trim_silence=False,
-            ).values()
-
-            file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
-            os.makedirs(file_path, exist_ok=True)
-            file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
-            self.ap.save_wav(wav, file_path)
-            test_audios["{}-audio".format(idx)] = wav
-            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
-            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
-
-        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
-        self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)
-
-    def _get_aux_inputs(self) -> Dict:
-        # setup speaker_id
-        speaker_id = 0 if self.config.use_speaker_embedding else None
-        # setup d_vector
-        d_vector = (
-            self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0])
-            if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding
-            else None
-        )
-        # setup style_mel
-        if self.config.has("gst_style_input"):
-            style_wav = self.config.gst_style_input
-        else:
-            style_wav = None
-        if style_wav is None and "use_gst" in self.config and self.config.use_gst:
-            # inicialize GST with zero dict.
-            style_wav = {}
-            print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
-            for i in range(self.config.gst["gst_num_style_tokens"]):
-                style_wav[str(i)] = 0
-        aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
-        return aux_inputs
+    @abstractmethod
+    def test_run(self) -> None:
+        pass
 
+    @abstractmethod
     def fit(self) -> None:
-        if self.restore_step != 0 or self.args.best_path:
-            print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
-            self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
-            print(f" > Starting with loaded last best loss {self.best_loss}.")
-
-        # define data loaders
-        self.train_loader = self.get_train_dataloader(
-            self.config.r,
-            self.ap,
-            self.data_train,
-            verbose=True,
-            speaker_ids=self.speaker_manager.speaker_ids,
-            d_vectors=self.speaker_manager.d_vectors,
-        )
-        self.eval_loader = (
-            self.get_eval_dataloder(
-                self.config.r,
-                self.ap,
-                self.data_train,
-                verbose=True,
-                speaker_ids=self.speaker_manager.speaker_ids,
-                d_vectors=self.speaker_manager.d_vectors,
-            )
-            if self.config.run_eval
-            else None
-        )
-
-        self.total_steps_done = self.restore_step
-
-        for epoch in range(0, self.config.epochs):
-            self.on_epoch_start()
-            self.keep_avg_train = KeepAverage()
-            self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
-            self.epochs_done = epoch
-            self.c_logger.print_epoch_start(epoch, self.config.epochs)
-            self.train_epoch()
-            if self.config.run_eval:
-                self.eval_epoch()
-            if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
-                self.test_run()
-            self.c_logger.print_epoch_end(
-                epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
-            )
-            self.save_best_model()
-            self.on_epoch_end()
+        pass
 
+    @abstractmethod
     def save_best_model(self) -> None:
-        self.best_loss = save_best_model(
-            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
-            self.best_loss,
-            self.model,
-            self.optimizer,
-            self.total_steps_done,
-            self.epochs_done,
-            self.config.r,
-            self.output_path,
-            self.model_characters,
-            keep_all_best=self.config.keep_all_best,
-            keep_after=self.config.keep_after,
-            scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
-        )
+        pass
 
-    @staticmethod
-    def _setup_logger_config(log_file: str) -> None:
-        logging.basicConfig(
-            level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
-        )
+    @abstractmethod
+    def on_epoch_start(self) -> None:
+        pass
 
-    def on_epoch_start(self) -> None:  # pylint: disable=no-self-use
-        if hasattr(self.model, "on_epoch_start"):
-            self.model.on_epoch_start(self)
-
-        if hasattr(self.criterion, "on_epoch_start"):
-            self.criterion.on_epoch_start(self)
-
-        if hasattr(self.optimizer, "on_epoch_start"):
-            self.optimizer.on_epoch_start(self)
+    @abstractmethod
+    def on_epoch_end(self) -> None:
+        pass
 
-    def on_epoch_end(self) -> None:  # pylint: disable=no-self-use
-        if hasattr(self.model, "on_epoch_end"):
-            self.model.on_epoch_end(self)
-
-        if hasattr(self.criterion, "on_epoch_end"):
-            self.criterion.on_epoch_end(self)
-
-        if hasattr(self.optimizer, "on_epoch_end"):
-            self.optimizer.on_epoch_end(self)
+    @abstractmethod
+    def on_train_step_start(self) -> None:
+        pass
 
-    def on_train_step_start(self) -> None:  # pylint: disable=no-self-use
-        if hasattr(self.model, "on_train_step_start"):
-            self.model.on_train_step_start(self)
-
-        if hasattr(self.criterion, "on_train_step_start"):
"on_train_step_start"): - self.criterion.on_train_step_start(self) - - if hasattr(self.optimizer, "on_train_step_start"): - self.optimizer.on_train_step_start(self) - - def on_train_step_end(self) -> None: # pylint: disable=no-self-use - if hasattr(self.model, "on_train_step_end"): - self.model.on_train_step_end(self) - - if hasattr(self.criterion, "on_train_step_end"): - self.criterion.on_train_step_end(self) - - if hasattr(self.optimizer, "on_train_step_end"): - self.optimizer.on_train_step_end(self) + @abstractmethod + def on_train_step_end(self) -> None: + pass diff --git a/TTS/tts/trainer_tts.py b/TTS/tts/trainer_tts.py new file mode 100644 index 00000000..9d060498 --- /dev/null +++ b/TTS/tts/trainer_tts.py @@ -0,0 +1,709 @@ +# -*- coding: utf-8 -*- + +import importlib +import logging +import os +import time +from argparse import Namespace +from typing import Dict, List, Tuple, Union + +import torch +from coqpit import Coqpit + +# DISTRIBUTED +from torch import nn +from torch.nn.parallel import DistributedDataParallel as DDP_th +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from TTS.trainer import TrainerAbstract +from TTS.tts.datasets import TTSDataset, load_meta_data +from TTS.tts.layers import setup_loss +from TTS.tts.models import setup_model +from TTS.tts.utils.io import save_best_model, save_checkpoint +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.symbols import make_symbols +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.audio import AudioProcessor +from TTS.utils.distribute import init_distributed +from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict, to_cuda +from TTS.utils.logging import ConsoleLogger, TensorboardLogger +from TTS.utils.training import check_update, setup_torch_training_env + + +# pylint: disable=import-outside-toplevel, too-many-public-methods + +class TrainerTTS(TrainerAbstract): + use_cuda, num_gpus = setup_torch_training_env(True, False) + + def __init__( + self, + args: Union[Coqpit, Namespace], + config: Coqpit, + c_logger: ConsoleLogger = None, + tb_logger: TensorboardLogger = None, + model: nn.Module = None, + output_path: str = None, + ) -> None: + self.args = args + self.config = config + self.c_logger = ConsoleLogger() if c_logger is None else c_logger + if tb_logger is None: + self.tb_logger = TensorboardLogger(output_path, model_name=config.model) + self.tb_logger.tb_add_text("model-config", f"
{config.to_json()}", 0) + else: + self.tb_logger = tb_logger + self.output_path = output_path + + self.total_steps_done = 0 + self.epochs_done = 0 + self.restore_step = 0 + self.best_loss = float("inf") + self.train_loader = None + self.eval_loader = None + self.output_audio_path = os.path.join(output_path, "test_audios") + + self.keep_avg_train = None + self.keep_avg_eval = None + + log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") + self._setup_logger_config(log_file) + + # model, audio processor, datasets, loss + # init audio processor + self.ap = AudioProcessor(**self.config.audio.to_dict()) + + # init character processor + self.model_characters = self.get_character_processor(self.config) + + # load dataset samples + self.data_train, self.data_eval = load_meta_data(self.config.datasets) + + # default speaker manager + self.speaker_manager = self.get_speaker_manager(self.config, args.restore_path, output_path, self.data_train) + + # init TTS model + if model is not None: + self.model = model + else: + self.model = self.get_model( + len(self.model_characters), + self.speaker_manager.num_speakers, + self.config, + self.speaker_manager.d_vector_dim if self.speaker_manager.d_vectors else None, + ) + + # setup criterion + self.criterion = self.get_criterion(self.config) + + # DISTRUBUTED + if self.num_gpus > 1: + init_distributed( + args.rank, + self.num_gpus, + args.group_id, + self.config.distributed_backend, + self.config.distributed_url, + ) + + if self.use_cuda: + self.model.cuda() + self.criterion.cuda() + + # scalers for mixed precision training + self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None + + # setup optimizer + self.optimizer = self.get_optimizer(self.model, self.config) + + if self.args.restore_path: + self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model( + self.config, args.restore_path, self.model, self.optimizer, self.scaler + ) + + # setup scheduler + self.scheduler = self.get_scheduler(self.config, self.optimizer) + + # DISTRUBUTED + if self.num_gpus > 1: + self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank) + + # count model size + num_params = count_parameters(self.model) + print("\n > Model has {} parameters".format(num_params)) + + @staticmethod + def get_model(num_chars: int, num_speakers: int, config: Coqpit, d_vector_dim: int) -> nn.Module: + model = setup_model(num_chars, num_speakers, config, d_vector_dim) + return model + + @staticmethod + def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer: + optimizer_name = config.optimizer + optimizer_params = config.optimizer_params + if optimizer_name.lower() == "radam": + module = importlib.import_module("TTS.utils.radam") + optimizer = getattr(module, "RAdam") + else: + optimizer = getattr(torch.optim, optimizer_name) + return optimizer(model.parameters(), lr=config.lr, **optimizer_params) + + @staticmethod + def get_character_processor(config: Coqpit) -> str: + # setup custom characters if set in config file. 
+ # TODO: implement CharacterProcessor + if config.characters is not None: + symbols, phonemes = make_symbols(**config.characters.to_dict()) + else: + from TTS.tts.utils.text.symbols import phonemes, symbols + model_characters = phonemes if config.use_phonemes else symbols + return model_characters + + @staticmethod + def get_speaker_manager( + config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None + ) -> SpeakerManager: + speaker_manager = get_speaker_manager(config, restore_path, data_train, out_path) + return speaker_manager + + @staticmethod + def get_scheduler( + config: Coqpit, optimizer: torch.optim.Optimizer + ) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access + lr_scheduler = config.lr_scheduler + lr_scheduler_params = config.lr_scheduler_params + if lr_scheduler is None: + return None + if lr_scheduler.lower() == "noamlr": + from TTS.utils.training import NoamLR + + scheduler = NoamLR + else: + scheduler = getattr(torch.optim, lr_scheduler) + return scheduler(optimizer, **lr_scheduler_params) + + @staticmethod + def get_criterion(config: Coqpit) -> nn.Module: + return setup_loss(config) + + def restore_model( + self, + config: Coqpit, + restore_path: str, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scaler: torch.cuda.amp.GradScaler = None, + ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]: + print(" > Restoring from %s ..." % os.path.basename(restore_path)) + checkpoint = torch.load(restore_path) + try: + print(" > Restoring Model...") + model.load_state_dict(checkpoint["model"]) + print(" > Restoring Optimizer...") + optimizer.load_state_dict(checkpoint["optimizer"]) + if "scaler" in checkpoint and config.mixed_precision: + print(" > Restoring AMP Scaler...") + scaler.load_state_dict(checkpoint["scaler"]) + except (KeyError, RuntimeError): + print(" > Partial model initialization...") + model_dict = model.state_dict() + model_dict = set_init_dict(model_dict, checkpoint["model"], config) + model.load_state_dict(model_dict) + del model_dict + + for group in optimizer.param_groups: + group["lr"] = self.config.lr + print( + " > Model restored from step %d" % checkpoint["step"], + ) + restore_step = checkpoint["step"] + return model, optimizer, scaler, restore_step + + def _get_loader( + self, + r: int, + ap: AudioProcessor, + is_eval: bool, + data_items: List, + verbose: bool, + speaker_ids: Union[Dict, List], + d_vectors: Union[Dict, List], + ) -> DataLoader: + if is_eval and not self.config.run_eval: + loader = None + else: + dataset = TTSDataset( + outputs_per_step=r, + text_cleaner=self.config.text_cleaner, + compute_linear_spec=self.config.model.lower() == "tacotron", + meta_data=data_items, + ap=ap, + tp=self.config.characters, + add_blank=self.config["add_blank"], + batch_group_size=0 if is_eval else self.config.batch_group_size * self.config.batch_size, + min_seq_len=self.config.min_seq_len, + max_seq_len=self.config.max_seq_len, + phoneme_cache_path=self.config.phoneme_cache_path, + use_phonemes=self.config.use_phonemes, + phoneme_language=self.config.phoneme_language, + enable_eos_bos=self.config.enable_eos_bos_chars, + use_noise_augment=not is_eval, + verbose=verbose, + speaker_id_mapping=speaker_ids if self.config.use_speaker_embedding else None, + d_vector_mapping=d_vectors + if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file + else None, + ) + + if self.config.use_phonemes and self.config.compute_input_seq_cache: + # precompute 
phonemes to have a better estimate of sequence lengths. + dataset.compute_input_seq(self.config.num_loader_workers) + dataset.sort_items() + + sampler = DistributedSampler(dataset) if self.num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=self.config.eval_batch_size if is_eval else self.config.batch_size, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=False, + sampler=sampler, + num_workers=self.config.num_val_loader_workers if is_eval else self.config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_train_dataloader( + self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict + ) -> DataLoader: + return self._get_loader(r, ap, False, data_items, verbose, speaker_ids, d_vectors) + + def get_eval_dataloder( + self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict + ) -> DataLoader: + return self._get_loader(r, ap, True, data_items, verbose, speaker_ids, d_vectors) + + def format_batch(self, batch: List) -> Dict: + # setup input batch + text_input = batch[0] + text_lengths = batch[1] + speaker_names = batch[2] + linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None + mel_input = batch[4] + mel_lengths = batch[5] + stop_targets = batch[6] + item_idx = batch[7] + d_vectors = batch[8] + speaker_ids = batch[9] + attn_mask = batch[10] + max_text_length = torch.max(text_lengths.float()) + max_spec_length = torch.max(mel_lengths.float()) + + # compute durations from attention masks + durations = None + if attn_mask is not None: + durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) + for idx, am in enumerate(attn_mask): + # compute raw durations + c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] + # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) + c_idxs, counts = torch.unique(c_idxs, return_counts=True) + dur = torch.ones([text_lengths[idx]]).to(counts.dtype) + dur[c_idxs] = counts + # smooth the durations and set any 0 duration to 1 + # by cutting off from the largest duration indeces. + extra_frames = dur.sum() - mel_lengths[idx] + largest_idxs = torch.argsort(-dur)[:extra_frames] + dur[largest_idxs] -= 1 + assert ( + dur.sum() == mel_lengths[idx] + ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" + durations[idx, : text_lengths[idx]] = dur + + # set stop targets view, we predict a single stop token per iteration. 
+        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
+        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
+
+        # dispatch batch to GPU
+        if self.use_cuda:
+            text_input = to_cuda(text_input)
+            text_lengths = to_cuda(text_lengths)
+            mel_input = to_cuda(mel_input)
+            mel_lengths = to_cuda(mel_lengths)
+            linear_input = to_cuda(linear_input) if self.config.model.lower() in ["tacotron"] else None
+            stop_targets = to_cuda(stop_targets)
+            attn_mask = to_cuda(attn_mask) if attn_mask is not None else None
+            durations = to_cuda(durations) if attn_mask is not None else None
+            if speaker_ids is not None:
+                speaker_ids = to_cuda(speaker_ids)
+            if d_vectors is not None:
+                d_vectors = to_cuda(d_vectors)
+
+        return {
+            "text_input": text_input,
+            "text_lengths": text_lengths,
+            "speaker_names": speaker_names,
+            "mel_input": mel_input,
+            "mel_lengths": mel_lengths,
+            "linear_input": linear_input,
+            "stop_targets": stop_targets,
+            "attn_mask": attn_mask,
+            "durations": durations,
+            "speaker_ids": speaker_ids,
+            "d_vectors": d_vectors,
+            "max_text_length": max_text_length,
+            "max_spec_length": max_spec_length,
+            "item_idx": item_idx,
+        }
+
+    def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
+        if hasattr(self.model, "module"):
+            return self.model.module.train_step(batch, criterion)
+        return self.model.train_step(batch, criterion)
+
+    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
+        self.on_train_step_start()
+        step_start_time = time.time()
+
+        # format data
+        batch = self.format_batch(batch)
+        loader_time = time.time() - loader_start_time
+
+        # zero-out optimizer
+        self.optimizer.zero_grad()
+
+        with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
+            outputs, loss_dict = self._train_step(batch, self.criterion)
+
+        # check nan loss
+        if torch.isnan(loss_dict["loss"]).any():
+            raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
+
+        # optimizer step
+        if self.config.mixed_precision:
+            # model optimizer step in mixed precision mode
+            self.scaler.scale(loss_dict["loss"]).backward()
+            self.scaler.unscale_(self.optimizer)
+            grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+        else:
+            # main model optimizer step
+            loss_dict["loss"].backward()
+            grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
+            self.optimizer.step()
+
+        step_time = time.time() - step_start_time
+
+        # setup lr
+        if self.config.lr_scheduler:
+            self.scheduler.step()
+
+        # detach loss values
+        loss_dict_new = dict()
+        for key, value in loss_dict.items():
+            if isinstance(value, (int, float)):
+                loss_dict_new[key] = value
+            else:
+                loss_dict_new[key] = value.item()
+        loss_dict = loss_dict_new
+
+        # update avg stats
+        update_train_values = dict()
+        for key, value in loss_dict.items():
+            update_train_values["avg_" + key] = value
+        update_train_values["avg_loader_time"] = loader_time
+        update_train_values["avg_step_time"] = step_time
+        self.keep_avg_train.update_values(update_train_values)
+
+        # print training progress
+        current_lr = self.optimizer.param_groups[0]["lr"]
+        if self.total_steps_done % self.config.print_step == 0:
+            log_dict = {
+                "max_spec_length": [batch["max_spec_length"], 1],  # value, precision
+                "max_text_length": [batch["max_text_length"], 1],
+                "step_time": [step_time, 4],
+                "loader_time": [loader_time, 2],
+                "current_lr": current_lr,
+            }
+            self.c_logger.print_train_step(
+                batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
+            )
+
+        if self.args.rank == 0:
+            # Plot Training Iter Stats
+            # reduce TB load
+            if self.total_steps_done % self.config.tb_plot_step == 0:
+                iter_stats = {
+                    "lr": current_lr,
+                    "grad_norm": grad_norm,
+                    "step_time": step_time,
+                }
+                iter_stats.update(loss_dict)
+                self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
+
+            if self.total_steps_done % self.config.save_step == 0:
+                if self.config.checkpoint:
+                    # save model
+                    save_checkpoint(
+                        self.model,
+                        self.optimizer,
+                        self.total_steps_done,
+                        self.epochs_done,
+                        self.config.r,
+                        self.output_path,
+                        model_loss=loss_dict["loss"],
+                        characters=self.model_characters,
+                        scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
+                    )
+                # training visualizations
+                if hasattr(self.model, "module"):
+                    figures, audios = self.model.module.train_log(self.ap, batch, outputs)
+                else:
+                    figures, audios = self.model.train_log(self.ap, batch, outputs)
+                self.tb_logger.tb_train_figures(self.total_steps_done, figures)
+                self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
+        self.total_steps_done += 1
+        self.on_train_step_end()
+        return outputs, loss_dict
+
+    def train_epoch(self) -> None:
+        self.model.train()
+        epoch_start_time = time.time()
+        if self.use_cuda:
+            batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
+        else:
+            batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
+        self.c_logger.print_train_start()
+        loader_start_time = time.time()
+        for cur_step, batch in enumerate(self.train_loader):
+            _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
+        epoch_time = time.time() - epoch_start_time
+        # Plot self.epochs_done Stats
+        if self.args.rank == 0:
+            epoch_stats = {"epoch_time": epoch_time}
+            epoch_stats.update(self.keep_avg_train.avg_values)
+            self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
+            if self.config.tb_model_param_stats:
+                self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
+
+    def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
+        if hasattr(self.model, "module"):
+            return self.model.module.eval_step(batch, self.criterion)
+        return self.model.eval_step(batch, self.criterion)
+
+    def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
+        with torch.no_grad():
+            step_start_time = time.time()
+
+            with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
+                outputs, loss_dict = self._eval_step(batch)
+
+            step_time = time.time() - step_start_time
+
+            # detach loss values
+            loss_dict_new = dict()
+            for key, value in loss_dict.items():
+                if isinstance(value, (int, float)):
+                    loss_dict_new[key] = value
+                else:
+                    loss_dict_new[key] = value.item()
+            loss_dict = loss_dict_new
+
+            # update avg stats
+            update_eval_values = dict()
+            for key, value in loss_dict.items():
+                update_eval_values["avg_" + key] = value
+            update_eval_values["avg_step_time"] = step_time
+            self.keep_avg_eval.update_values(update_eval_values)
+
+            if self.config.print_eval:
+                self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values)
+        return outputs, loss_dict
+
+    def eval_epoch(self) -> None:
+        self.model.eval()
+        self.c_logger.print_eval_start()
+        loader_start_time = time.time()
+        batch = None
+        for cur_step, batch in enumerate(self.eval_loader):
+            # format data
+            batch = self.format_batch(batch)
+            loader_time = time.time() - loader_start_time
+            self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
+            outputs, _ = self.eval_step(batch, cur_step)
+        # Plot epoch stats and samples from the last batch.
+        if self.args.rank == 0:
+            if hasattr(self.model, "module"):
+                figures, eval_audios = self.model.module.eval_log(self.ap, batch, outputs)
+            else:
+                figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
+            self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
+            self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
+
+    def test_run(
+        self,
+    ) -> None:
+        print(" | > Synthesizing test sentences.")
+        test_audios = {}
+        test_figures = {}
+        test_sentences = self.config.test_sentences
+        aux_inputs = self._get_aux_inputs()
+        for idx, sen in enumerate(test_sentences):
+            wav, alignment, model_outputs, _ = synthesis(
+                self.model,
+                sen,
+                self.config,
+                self.use_cuda,
+                self.ap,
+                speaker_id=aux_inputs["speaker_id"],
+                d_vector=aux_inputs["d_vector"],
+                style_wav=aux_inputs["style_wav"],
+                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
+                use_griffin_lim=True,
+                do_trim_silence=False,
+            ).values()
+
+            file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
+            os.makedirs(file_path, exist_ok=True)
+            file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
+            self.ap.save_wav(wav, file_path)
+            test_audios["{}-audio".format(idx)] = wav
+            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
+            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
+
+        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
+        self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)
+
+    def _get_aux_inputs(self) -> Dict:
+        # setup speaker_id
+        speaker_id = 0 if self.config.use_speaker_embedding else None
+        # setup d_vector
+        d_vector = (
+            self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0])
+            if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding
+            else None
+        )
+        # setup style_mel
+        if self.config.has("gst_style_input"):
+            style_wav = self.config.gst_style_input
+        else:
+            style_wav = None
+        if style_wav is None and "use_gst" in self.config and self.config.use_gst:
+            # initialize GST with a zero dict.
+            style_wav = {}
+            print("WARNING: You didn't provide a GST style wav, so a zero tensor is used!")
+            for i in range(self.config.gst["gst_num_style_tokens"]):
+                style_wav[str(i)] = 0
+        aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
+        return aux_inputs
+
+    def fit(self) -> None:
+        if self.restore_step != 0 or self.args.best_path:
+            print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
+            self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
+            print(f" > Starting with loaded last best loss {self.best_loss}.")
+
+        # define data loaders
+        self.train_loader = self.get_train_dataloader(
+            self.config.r,
+            self.ap,
+            self.data_train,
+            verbose=True,
+            speaker_ids=self.speaker_manager.speaker_ids,
+            d_vectors=self.speaker_manager.d_vectors,
+        )
+        self.eval_loader = (
+            self.get_eval_dataloader(
+                self.config.r,
+                self.ap,
+                self.data_eval,
+                verbose=True,
+                speaker_ids=self.speaker_manager.speaker_ids,
+                d_vectors=self.speaker_manager.d_vectors,
+            )
+            if self.config.run_eval
+            else None
+        )
+
+        self.total_steps_done = self.restore_step
+
+        for epoch in range(0, self.config.epochs):
+            self.on_epoch_start()
+            self.keep_avg_train = KeepAverage()
+            self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
+            self.epochs_done = epoch
+            self.c_logger.print_epoch_start(epoch, self.config.epochs)
+            self.train_epoch()
+            if self.config.run_eval:
+                self.eval_epoch()
+            if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
+                self.test_run()
+            self.c_logger.print_epoch_end(
+                epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
+            )
+            self.save_best_model()
+            self.on_epoch_end()
+
+    def save_best_model(self) -> None:
+        self.best_loss = save_best_model(
+            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
+            self.best_loss,
+            self.model,
+            self.optimizer,
+            self.total_steps_done,
+            self.epochs_done,
+            self.config.r,
+            self.output_path,
+            self.model_characters,
+            keep_all_best=self.config.keep_all_best,
+            keep_after=self.config.keep_after,
+            scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
+        )
+
+    @staticmethod
+    def _setup_logger_config(log_file: str) -> None:
+        logging.basicConfig(
+            level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
+        )
+
+    def on_epoch_start(self) -> None:  # pylint: disable=no-self-use
+        if hasattr(self.model, "on_epoch_start"):
+            self.model.on_epoch_start(self)
+
+        if hasattr(self.criterion, "on_epoch_start"):
+            self.criterion.on_epoch_start(self)
+
+        if hasattr(self.optimizer, "on_epoch_start"):
+            self.optimizer.on_epoch_start(self)
+
+    def on_epoch_end(self) -> None:  # pylint: disable=no-self-use
+        if hasattr(self.model, "on_epoch_end"):
+            self.model.on_epoch_end(self)
+
+        if hasattr(self.criterion, "on_epoch_end"):
+            self.criterion.on_epoch_end(self)
+
+        if hasattr(self.optimizer, "on_epoch_end"):
+            self.optimizer.on_epoch_end(self)
+
+    def on_train_step_start(self) -> None:  # pylint: disable=no-self-use
+        if hasattr(self.model, "on_train_step_start"):
+            self.model.on_train_step_start(self)
+
+        if hasattr(self.criterion, "on_train_step_start"):
+            self.criterion.on_train_step_start(self)
+
+        if hasattr(self.optimizer, "on_train_step_start"):
+            self.optimizer.on_train_step_start(self)
+
+    def on_train_step_end(self) -> None:  # pylint: disable=no-self-use
+        if hasattr(self.model, "on_train_step_end"):
+            self.model.on_train_step_end(self)
+
+        if hasattr(self.criterion, "on_train_step_end"):
+            self.criterion.on_train_step_end(self)
+
+        if hasattr(self.optimizer, "on_train_step_end"):
+            self.optimizer.on_train_step_end(self)
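Taken as a whole, the patch is a pure move-and-split: TrainerAbstract (TTS/trainer.py) owns the interface, TrainerTTS (TTS/tts/trainer_tts.py) owns the implementation, and the entry point only changes its import. A hedged sketch of the call path from TTS/bin/train_tts.py; the exact tuple returned by init_training is assumed from context, since this patch does not show it:

    import sys

    from TTS.tts.trainer_tts import TrainerTTS
    from TTS.utils.arguments import init_training

    # Assumed unpacking; check TTS/bin/train_tts.py for the authoritative order.
    args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)

    trainer = TrainerTTS(args, config, c_logger=c_logger, tb_logger=tb_logger, output_path=output_path)
    trainer.fit()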
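One detail worth calling out in TrainerTTS.get_optimizer: apart from the special-cased RAdam, the optimizer class is resolved by name from torch.optim via getattr, so any stock optimizer can be selected from the config. A standalone illustration (toy model and hyperparameters, not taken from this patch):

    import torch
    from torch import nn

    model = nn.Linear(4, 2)
    optimizer_name = "Adam"                      # e.g. config.optimizer
    optimizer_params = {"betas": (0.9, 0.999)}   # e.g. config.optimizer_params

    # Same lookup as get_optimizer: resolve the class by name, then instantiate it.
    optimizer_cls = getattr(torch.optim, optimizer_name)
    optimizer = optimizer_cls(model.parameters(), lr=1e-3, **optimizer_params)
    print(type(optimizer).__name__)  # -> Adam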