# -*- coding: utf-8 -*-
import importlib
import logging
import os
import time
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union

import torch
from coqpit import Coqpit

# DISTRIBUTED
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.tts.datasets import TTSDataset, load_meta_data
from TTS.tts.layers import setup_loss
from TTS.tts.models import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.utils.training import check_update, setup_torch_training_env


@dataclass
class TrainingArgs(Coqpit):
    continue_path: str = field(
        default="",
        metadata={
            "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder."
        },
    )
    restore_path: str = field(
        default="",
        metadata={
            "help": "Path to a model checkpoint. Restore the model with the given checkpoint and start a new training."
        },
    )
    best_path: str = field(
        default="",
        metadata={
            "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in the continue path is used."
        },
    )
    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
    group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})


# pylint: disable=import-outside-toplevel, too-many-public-methods
class TrainerTTS:
    use_cuda, num_gpus = setup_torch_training_env(True, False)

    def __init__(
        self,
        args: Union[Coqpit, Namespace],
        config: Coqpit,
        c_logger: ConsoleLogger,
        tb_logger: TensorboardLogger,
        model: nn.Module = None,
        output_path: str = None,
    ) -> None:
        self.args = args
        self.config = config
        self.c_logger = c_logger
        self.tb_logger = tb_logger
        self.output_path = output_path

        self.total_steps_done = 0
        self.epochs_done = 0
        self.restore_step = 0
        self.best_loss = float("inf")
        self.train_loader = None
        self.eval_loader = None
        self.output_audio_path = os.path.join(output_path, "test_audios")

        self.keep_avg_train = None
        self.keep_avg_eval = None

        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
        self._setup_logger_config(log_file)

        # model, audio processor, datasets, loss
        # init audio processor
        self.ap = AudioProcessor(**self.config.audio.to_dict())

        # init character processor
        self.model_characters = self.get_character_processor(self.config)

        # load dataset samples
        self.data_train, self.data_eval = load_meta_data(self.config.datasets)

        # default speaker manager
        self.speaker_manager = self.get_speaker_manager(
            self.config, args.restore_path, self.config.output_path, self.data_train
        )

        # init TTS model
        if model is not None:
            self.model = model
        else:
            self.model = self.get_model(
                len(self.model_characters),
                self.speaker_manager.num_speakers,
                self.config,
                self.speaker_manager.x_vector_dim if self.speaker_manager.x_vectors else None,
            )

        # setup criterion
        self.criterion = self.get_criterion(self.config)

        if self.use_cuda:
            self.model.cuda()
            self.criterion.cuda()

        # DISTRIBUTED
        if self.num_gpus > 1:
            init_distributed(
                args.rank,
                self.num_gpus,
                args.group_id,
                self.config.distributed["backend"],
                self.config.distributed["url"],
            )

        # scalers for mixed precision training
        self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None

        # setup optimizer
        self.optimizer = self.get_optimizer(self.model, self.config)

        if self.args.restore_path:
            self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
                self.config, args.restore_path, self.model, self.optimizer, self.scaler
            )

        # setup scheduler
        self.scheduler = self.get_scheduler(self.config, self.optimizer)

        # DISTRIBUTED
        if self.num_gpus > 1:
            self.model = DDP_th(self.model, device_ids=[args.rank])

        # count model size
        num_params = count_parameters(self.model)
        print("\n > Model has {} parameters".format(num_params))

    @staticmethod
    def get_model(num_chars: int, num_speakers: int, config: Coqpit, x_vector_dim: int) -> nn.Module:
        model = setup_model(num_chars, num_speakers, config, x_vector_dim)
        return model

    @staticmethod
    def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
        optimizer_name = config.optimizer
        optimizer_params = config.optimizer_params
        if optimizer_name.lower() == "radam":
            module = importlib.import_module("TTS.utils.radam")
            optimizer = getattr(module, "RAdam")
        else:
            optimizer = getattr(torch.optim, optimizer_name)
        return optimizer(model.parameters(), lr=config.lr, **optimizer_params)

    @staticmethod
    def get_character_processor(config: Coqpit) -> str:
        # setup custom characters if set in config file.
        # TODO: implement CharacterProcessor
        if config.characters is not None:
            symbols, phonemes = make_symbols(**config.characters.to_dict())
        else:
            from TTS.tts.utils.text.symbols import phonemes, symbols
        model_characters = phonemes if config.use_phonemes else symbols
        return model_characters

    @staticmethod
    def get_speaker_manager(
        config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None
    ) -> SpeakerManager:
        speaker_manager = SpeakerManager()
        if restore_path:
            speakers_file = os.path.join(os.path.dirname(restore_path), "speakers.json")
            if not os.path.exists(speakers_file):
                print(
                    "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
                )
                speakers_file = config.external_speaker_embedding_file

            if config.use_external_speaker_embedding_file:
                speaker_manager.load_x_vectors_file(speakers_file)
            else:
                speaker_manager.load_ids_file(speakers_file)
        elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
            speaker_manager.load_x_vectors_file(config.external_speaker_embedding_file)
        else:
            speaker_manager.parse_speakers_from_items(data_train)
            file_path = os.path.join(out_path, "speakers.json")
            speaker_manager.save_ids_file(file_path)
        return speaker_manager

    @staticmethod
    def get_scheduler(
        config: Coqpit, optimizer: torch.optim.Optimizer
    ) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
        lr_scheduler = config.lr_scheduler
        lr_scheduler_params = config.lr_scheduler_params
        if lr_scheduler is None:
            return None
        if lr_scheduler.lower() == "noamlr":
            from TTS.utils.training import NoamLR

            scheduler = NoamLR
        else:
            # built-in schedulers live under torch.optim.lr_scheduler
            scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler)
        return scheduler(optimizer, **lr_scheduler_params)

    @staticmethod
    def get_criterion(config: Coqpit) -> nn.Module:
        return setup_loss(config)

    def restore_model(
        self,
        config: Coqpit,
        restore_path: str,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scaler: torch.cuda.amp.GradScaler = None,
    ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
        print(" > Restoring from %s ..." % os.path.basename(restore_path))
        checkpoint = torch.load(restore_path)
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            if "scaler" in checkpoint and config.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except (KeyError, RuntimeError):
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], config)
            model.load_state_dict(model_dict)
            del model_dict

        for group in optimizer.param_groups:
            group["lr"] = self.config.lr
        print(
            " > Model restored from step %d" % checkpoint["step"],
        )
        restore_step = checkpoint["step"]
        return model, optimizer, scaler, restore_step

    def _get_loader(
        self,
        r: int,
        ap: AudioProcessor,
        is_eval: bool,
        data_items: List,
        verbose: bool,
        speaker_mapping: Union[Dict, List],
    ) -> DataLoader:
        if is_eval and not self.config.run_eval:
            loader = None
        else:
            dataset = TTSDataset(
                outputs_per_step=r,
                text_cleaner=self.config.text_cleaner,
                compute_linear_spec=self.config.model.lower() == "tacotron",
                meta_data=data_items,
                ap=ap,
                tp=self.config.characters,
                add_blank=self.config["add_blank"],
                batch_group_size=0 if is_eval else self.config.batch_group_size * self.config.batch_size,
                min_seq_len=self.config.min_seq_len,
                max_seq_len=self.config.max_seq_len,
                phoneme_cache_path=self.config.phoneme_cache_path,
                use_phonemes=self.config.use_phonemes,
                phoneme_language=self.config.phoneme_language,
                enable_eos_bos=self.config.enable_eos_bos_chars,
                use_noise_augment=not is_eval,
                verbose=verbose,
                speaker_mapping=speaker_mapping
                if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
                else None,
            )

            if self.config.use_phonemes and self.config.compute_input_seq_cache:
                # precompute phonemes to have a better estimate of sequence lengths.
                dataset.compute_input_seq(self.config.num_loader_workers)
            dataset.sort_items()

            sampler = DistributedSampler(dataset) if self.num_gpus > 1 else None
            loader = DataLoader(
                dataset,
                batch_size=self.config.eval_batch_size if is_eval else self.config.batch_size,
                shuffle=False,
                collate_fn=dataset.collate_fn,
                drop_last=False,
                sampler=sampler,
                num_workers=self.config.num_val_loader_workers if is_eval else self.config.num_loader_workers,
                pin_memory=False,
            )
        return loader

    def get_train_dataloader(
        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
    ) -> DataLoader:
        return self._get_loader(r, ap, False, data_items, verbose, speaker_mapping)

    def get_eval_dataloader(
        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
    ) -> DataLoader:
        return self._get_loader(r, ap, True, data_items, verbose, speaker_mapping)

    def format_batch(self, batch: List) -> Dict:
        # setup input batch
        text_input = batch[0]
        text_lengths = batch[1]
        speaker_names = batch[2]
        linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
        mel_input = batch[4]
        mel_lengths = batch[5]
        stop_targets = batch[6]
        item_idx = batch[7]
        speaker_embeddings = batch[8]
        attn_mask = batch[9]
        max_text_length = torch.max(text_lengths.float())
        max_spec_length = torch.max(mel_lengths.float())

        # convert speaker names to ids
        if self.config.use_speaker_embedding:
            if self.config.use_external_speaker_embedding_file:
                speaker_embeddings = batch[8]
                speaker_ids = None
            else:
                speaker_ids = [self.speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
                speaker_ids = torch.LongTensor(speaker_ids)
                speaker_embeddings = None
        else:
            speaker_embeddings = None
            speaker_ids = None

        # compute durations from attention masks
        durations = None
        if attn_mask is not None:
            durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
            for idx, am in enumerate(attn_mask):
                # compute raw durations
                c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
                # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
                c_idxs, counts = torch.unique(c_idxs, return_counts=True)
                dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
                dur[c_idxs] = counts
                # smooth the durations and set any 0 duration to 1
                # by cutting off from the largest duration indices.
                extra_frames = dur.sum() - mel_lengths[idx]
                largest_idxs = torch.argsort(-dur)[:extra_frames]
                dur[largest_idxs] -= 1
                assert (
                    dur.sum() == mel_lengths[idx]
                ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
                durations[idx, : text_lengths[idx]] = dur

        # set stop targets view, we predict a single stop token per iteration.
        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)

        # dispatch batch to GPU
        if self.use_cuda:
            text_input = text_input.cuda(non_blocking=True)
            text_lengths = text_lengths.cuda(non_blocking=True)
            mel_input = mel_input.cuda(non_blocking=True)
            mel_lengths = mel_lengths.cuda(non_blocking=True)
            linear_input = linear_input.cuda(non_blocking=True) if self.config.model.lower() in ["tacotron"] else None
            stop_targets = stop_targets.cuda(non_blocking=True)
            attn_mask = attn_mask.cuda(non_blocking=True) if attn_mask is not None else None
            durations = durations.cuda(non_blocking=True) if attn_mask is not None else None
            if speaker_ids is not None:
                speaker_ids = speaker_ids.cuda(non_blocking=True)
            if speaker_embeddings is not None:
                speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)

        return {
            "text_input": text_input,
            "text_lengths": text_lengths,
            "mel_input": mel_input,
            "mel_lengths": mel_lengths,
            "linear_input": linear_input,
            "stop_targets": stop_targets,
            "attn_mask": attn_mask,
            "durations": durations,
            "speaker_ids": speaker_ids,
            "x_vectors": speaker_embeddings,
            "max_text_length": max_text_length,
            "max_spec_length": max_spec_length,
            "item_idx": item_idx,
        }

    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
        self.on_train_step_start()
        step_start_time = time.time()

        # format data
        batch = self.format_batch(batch)
        loader_time = time.time() - loader_start_time

        # zero-out optimizer
        self.optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
            outputs, loss_dict = self.model.train_step(batch, self.criterion)

        # check nan loss
        if torch.isnan(loss_dict["loss"]).any():
            raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")

        # optimizer step
        if self.config.mixed_precision:
            # model optimizer step in mixed precision mode
            self.scaler.scale(loss_dict["loss"]).backward()
            self.scaler.unscale_(self.optimizer)
            grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            # main model optimizer step
            loss_dict["loss"].backward()
            grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
            self.optimizer.step()

        step_time = time.time() - step_start_time

        # setup lr
        if self.config.lr_scheduler:
            self.scheduler.step()

        # detach loss values
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        self.keep_avg_train.update_values(update_train_values)

        # print training progress
        current_lr = self.optimizer.param_groups[0]["lr"]
        if self.total_steps_done % self.config.print_step == 0:
            log_dict = {
                "max_spec_length": [batch["max_spec_length"], 1],  # value, precision
                "max_text_length": [batch["max_text_length"], 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            self.c_logger.print_train_step(
                batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
            )

        if self.args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if self.total_steps_done % self.config.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time,
                }
                iter_stats.update(loss_dict)
                self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)

            if self.total_steps_done % self.config.save_step == 0:
                if self.config.checkpoint:
                    # save model
                    save_checkpoint(
                        self.model,
                        self.optimizer,
                        self.total_steps_done,
                        self.epochs_done,
                        self.config.r,
                        self.output_path,
                        model_loss=loss_dict["loss"],
                        characters=self.model_characters,
                        scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
                    )
                # training visualizations
                figures, audios = self.model.train_log(self.ap, batch, outputs)
                self.tb_logger.tb_train_figures(self.total_steps_done, figures)
                self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
        self.total_steps_done += 1
        self.on_train_step_end()
        return outputs, loss_dict

    def train_epoch(self) -> None:
        self.model.train()
        epoch_start_time = time.time()
        if self.use_cuda:
            batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
        else:
            batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
        self.c_logger.print_train_start()
        loader_start_time = time.time()
        for cur_step, batch in enumerate(self.train_loader):
            _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
        epoch_time = time.time() - epoch_start_time
        # Plot self.epochs_done Stats
        if self.args.rank == 0:
            epoch_stats = {"epoch_time": epoch_time}
            epoch_stats.update(self.keep_avg_train.avg_values)
            self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
            if self.config.tb_model_param_stats:
                self.tb_logger.tb_model_weights(self.model, self.total_steps_done)

    def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
        with torch.no_grad():
            step_start_time = time.time()

            with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
                outputs, loss_dict = self.model.eval_step(batch, self.criterion)

            step_time = time.time() - step_start_time

            # detach loss values
            loss_dict_new = dict()
            for key, value in loss_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict_new[key] = value
                else:
                    loss_dict_new[key] = value.item()
            loss_dict = loss_dict_new

            # update avg stats
            update_eval_values = dict()
            for key, value in loss_dict.items():
                update_eval_values["avg_" + key] = value
            update_eval_values["avg_step_time"] = step_time
            self.keep_avg_eval.update_values(update_eval_values)

            if self.config.print_eval:
                self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values)
        return outputs, loss_dict

    def eval_epoch(self) -> None:
        self.model.eval()
        self.c_logger.print_eval_start()
        loader_start_time = time.time()
        batch = None
        for cur_step, batch in enumerate(self.eval_loader):
            # format data
            batch = self.format_batch(batch)
            loader_time = time.time() - loader_start_time
            self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
            outputs, _ = self.eval_step(batch, cur_step)
        # Plot epoch stats and samples from the last batch.
        if self.args.rank == 0:
            figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
            self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
            self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)

    def test_run(
        self,
    ) -> None:
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        cond_inputs = self._get_cond_inputs()
        for idx, sen in enumerate(test_sentences):
            wav, alignment, model_outputs, _ = synthesis(
                self.model,
                sen,
                self.config,
                self.use_cuda,
                self.ap,
                speaker_id=cond_inputs["speaker_id"],
                x_vector=cond_inputs["x_vector"],
                style_wav=cond_inputs["style_wav"],
                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                use_griffin_lim=True,
                do_trim_silence=False,
            ).values()

            file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
            os.makedirs(file_path, exist_ok=True)
            file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
            self.ap.save_wav(wav, file_path)
            test_audios["{}-audio".format(idx)] = wav
            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)

        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
        self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)

    def _get_cond_inputs(self) -> Dict:
        # setup speaker_id
        speaker_id = 0 if self.config.use_speaker_embedding else None
        # setup x_vector
        x_vector = (
            self.speaker_manager.get_x_vectors_by_speaker(self.speaker_manager.speaker_ids[0])
            if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding
            else None
        )
        # setup style_mel
        if self.config.has("gst_style_input"):
            style_wav = self.config.gst_style_input
        else:
            style_wav = None
        if style_wav is None and "use_gst" in self.config and self.config.use_gst:
            # initialize GST with a zero-filled dict.
            style_wav = {}
            print("WARNING: No GST style wav was provided, so a zero tensor is used as the style input.")
            for i in range(self.config.gst["gst_num_style_tokens"]):
                style_wav[str(i)] = 0
        cond_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "x_vector": x_vector}
        return cond_inputs

    def fit(self) -> None:
        if self.restore_step != 0 or self.args.best_path:
            print(f" > Restoring best loss from {os.path.basename(self.args.best_path)} ...")
            self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
            print(f" > Starting with loaded last best loss {self.best_loss}.")

        # define data loaders
        self.train_loader = self.get_train_dataloader(
            self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
        )
        self.eval_loader = (
            self.get_eval_dataloader(
                self.config.r, self.ap, self.data_eval, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
            )
            if self.config.run_eval
            else None
        )

        self.total_steps_done = self.restore_step

        for epoch in range(0, self.config.epochs):
            self.on_epoch_start()
            self.keep_avg_train = KeepAverage()
            self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
            self.epochs_done = epoch
            self.c_logger.print_epoch_start(epoch, self.config.epochs)
            self.train_epoch()
            if self.config.run_eval:
                self.eval_epoch()
            if epoch >= self.config.test_delay_epochs:
                self.test_run()
            self.c_logger.print_epoch_end(
                epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
            )
            self.save_best_model()
            self.on_epoch_end()

    def save_best_model(self) -> None:
        self.best_loss = save_best_model(
            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
            self.best_loss,
            self.model,
            self.optimizer,
            self.total_steps_done,
            self.epochs_done,
            self.config.r,
            self.output_path,
            self.model_characters,
            keep_all_best=self.config.keep_all_best,
            keep_after=self.config.keep_after,
            scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
        )

    @staticmethod
    def _setup_logger_config(log_file: str) -> None:
        logging.basicConfig(
            level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
        )

    def on_epoch_start(self) -> None:  # pylint: disable=no-self-use
        if hasattr(self.model, "on_epoch_start"):
            self.model.on_epoch_start(self)

        if hasattr(self.criterion, "on_epoch_start"):
            self.criterion.on_epoch_start(self)

        if hasattr(self.optimizer, "on_epoch_start"):
            self.optimizer.on_epoch_start(self)

    def on_epoch_end(self) -> None:  # pylint: disable=no-self-use
        if hasattr(self.model, "on_epoch_end"):
            self.model.on_epoch_end(self)

        if hasattr(self.criterion, "on_epoch_end"):
            self.criterion.on_epoch_end(self)

        if hasattr(self.optimizer, "on_epoch_end"):
            self.optimizer.on_epoch_end(self)

    def on_train_step_start(self) -> None:  # pylint: disable=no-self-use
        if hasattr(self.model, "on_train_step_start"):
            self.model.on_train_step_start(self)

        if hasattr(self.criterion, "on_train_step_start"):
            self.criterion.on_train_step_start(self)

        if hasattr(self.optimizer, "on_train_step_start"):
            self.optimizer.on_train_step_start(self)

    def on_train_step_end(self) -> None:  # pylint: disable=no-self-use
        if hasattr(self.model, "on_train_step_end"):
            self.model.on_train_step_end(self)

        if hasattr(self.criterion, "on_train_step_end"):
            self.criterion.on_train_step_end(self)

        if hasattr(self.optimizer, "on_train_step_end"):
            self.optimizer.on_train_step_end(self)
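

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API). It shows how
# a TrainerTTS instance might be wired together from the command line, assuming
# the config loader import below and the logger constructor arguments, which
# may differ in your checkout; adapt them to your local setup.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from TTS.utils.io import load_config  # assumed location of the config loader

    # parse the CLI flags declared by TrainingArgs (Coqpit builds the argparse parser)
    train_args = TrainingArgs()
    train_args.parse_args()

    # read the training configuration referenced by --config_path
    train_config = load_config(train_args.config_path)
    out_path = train_config.output_path

    # console and Tensorboard loggers (constructor signatures assumed)
    console_logger = ConsoleLogger()
    tensorboard_logger = TensorboardLogger(out_path, model_name=train_config.model)

    trainer = TrainerTTS(
        args=train_args,
        config=train_config,
        c_logger=console_logger,
        tb_logger=tensorboard_logger,
        output_path=out_path,
    )
    trainer.fit()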