diff --git a/TTS/trainer.py b/TTS/trainer.py
index cfb72191..3beb281f 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -65,12 +65,12 @@ class TrainerTTS:
     use_cuda, num_gpus = setup_torch_training_env(True, False)

     def __init__(self,
-                 args,
-                 config,
-                 c_logger,
-                 tb_logger,
-                 model=None,
-                 output_path=None):
+                 args: Union[Coqpit, Namespace],
+                 config: Coqpit,
+                 c_logger: ConsoleLogger,
+                 tb_logger: TensorboardLogger,
+                 model: nn.Module = None,
+                 output_path: str = None) -> None:
         self.args = args
         self.config = config
         self.c_logger = c_logger
@@ -88,43 +88,52 @@ class TrainerTTS:
         self.keep_avg_train = None
         self.keep_avg_eval = None

+        log_file = os.path.join(self.output_path,
+                                f"trainer_{args.rank}_log.txt")
+        self._setup_logger_config(log_file)
+
         # model, audio processor, datasets, loss
         # init audio processor
-        self.ap = AudioProcessor(**config.audio.to_dict())
+        self.ap = AudioProcessor(**self.config.audio.to_dict())

         # init character processor
-        self.model_characters = self.init_character_processor()
+        self.model_characters = self.get_character_processor(self.config)

         # load dataset samples
-        self.data_train, self.data_eval = load_meta_data(config.datasets)
+        self.data_train, self.data_eval = load_meta_data(self.config.datasets)

         # default speaker manager
-        self.speaker_manager = self.init_speaker_manager()
+        self.speaker_manager = self.get_speaker_manager(
+            self.config, args.restore_path, self.config.output_path, self.data_train)

         # init TTS model
         if model is not None:
             self.model = model
         else:
-            self.model = self.init_model()
+            self.model = self.get_model(
+                len(self.model_characters), self.speaker_manager.num_speakers,
+                self.config, self.speaker_manager.x_vector_dim
+                if self.speaker_manager.x_vectors else None)

         # setup criterion
-        self.criterion = self.init_criterion()
+        self.criterion = self.get_criterion(self.config)
+
+        if self.use_cuda:
+            self.model.cuda()
+            self.criterion.cuda()

         # DISTRUBUTED
         if self.num_gpus > 1:
             init_distributed(args.rank, self.num_gpus, args.group_id,
-                             config.distributed["backend"],
-                             config.distributed["url"])
+                             self.config.distributed["backend"],
+                             self.config.distributed["url"])

         # scalers for mixed precision training
         self.scaler = torch.cuda.amp.GradScaler(
-        ) if config.mixed_precision else None
+        ) if self.config.mixed_precision and self.use_cuda else None

         # setup optimizer
-        self.optimizer = self.init_optimizer(self.model)
-
-        # setup scheduler
-        self.scheduler = self.init_scheduler(self.config, self.optimizer)
+        self.optimizer = self.get_optimizer(self.model, self.config)

         if self.args.restore_path:
             self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
@@ -144,64 +153,66 @@ class TrainerTTS:
         logging.info("\n > Model has {} parameters".format(num_params),
                      flush=True)

-    def init_model(self):
-        model = setup_model(
-            len(self.model_characters),
-            self.speaker_manager.num_speakers,
-            self.config,
-            self.speaker_manager.x_vector_dim
-            if self.speaker_manager.x_vectors else None,
-        )
+    @staticmethod
+    def get_model(num_chars: int, num_speakers: int, config: Coqpit,
+                  x_vector_dim: int) -> nn.Module:
+        model = setup_model(num_chars, num_speakers, config, x_vector_dim)
         return model

-    def init_optimizer(self, model):
-        optimizer_name = self.config.optimizer
-        optimizer_params = self.config.optimizer_params
+    @staticmethod
+    def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
+        optimizer_name = config.optimizer
+        optimizer_params = config.optimizer_params
         if optimizer_name.lower() == "radam":
             module = importlib.import_module("TTS.utils.radam")
importlib.import_module("TTS.utils.radam") optimizer = getattr(module, "RAdam") else: optimizer = getattr(torch.optim, optimizer_name) - return optimizer(model.parameters(), - lr=self.config.lr, - **optimizer_params) + return optimizer(model.parameters(), lr=config.lr, **optimizer_params) - def init_character_processor(self): + @staticmethod + def get_character_processor(config: Coqpit) -> str: # setup custom characters if set in config file. # TODO: implement CharacterProcessor - if self.config.characters is not None: - symbols, phonemes = make_symbols( - **self.config.characters.to_dict()) + if config.characters is not None: + symbols, phonemes = make_symbols(**config.characters.to_dict()) else: - from TTS.tts.utils.text.symbols import symbols, phonemes - model_characters = phonemes if self.config.use_phonemes else symbols + from TTS.tts.utils.text.symbols import phonemes, symbols + model_characters = phonemes if config.use_phonemes else symbols return model_characters - def init_speaker_manager(self, restore_path: str = "", out_path: str = ""): + @staticmethod + def get_speaker_manager(config: Coqpit, + restore_path: str = "", + out_path: str = "", + data_train: List = []) -> SpeakerManager: speaker_manager = SpeakerManager() - if restore_path: - speakers_file = os.path.join(os.path.dirname(restore_path), - "speaker.json") - if not os.path.exists(speakers_file): - logging.info( - "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file" - ) - speakers_file = self.config.external_speaker_embedding_file + if config.use_speaker_embedding: + if restore_path: + speakers_file = os.path.join(os.path.dirname(restore_path), + "speaker.json") + if not os.path.exists(speakers_file): + print( + "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file" + ) + speakers_file = config.external_speaker_embedding_file - if self.config.use_external_speaker_embedding_file: - speaker_manager.load_x_vectors_file(speakers_file) + if config.use_external_speaker_embedding_file: + speaker_manager.load_x_vectors_file(speakers_file) + else: + speaker_manager.load_ids_file(speakers_file) + elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file: + speaker_manager.load_x_vectors_file( + config.external_speaker_embedding_file) else: - self.speaker_manage.load_speaker_mapping(speakers_file) - elif self.config.use_external_speaker_embedding_file and self.config.external_speaker_embedding_file: - speaker_manager.load_x_vectors_file( - self.config.external_speaker_embedding_file) - else: - speaker_manager.parse_speakers_from_items(self.data_train) - file_path = os.path.join(out_path, "speakers.json") - speaker_manager.save_ids_file(file_path) + speaker_manager.parse_speakers_from_items(data_train) + file_path = os.path.join(out_path, "speakers.json") + speaker_manager.save_ids_file(file_path) return speaker_manager - def init_scheduler(self, config, optimizer): + @staticmethod + def get_scheduler(config: Coqpit, + optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler: lr_scheduler = config.lr_scheduler lr_scheduler_params = config.lr_scheduler_params if lr_scheduler is None: @@ -213,17 +224,20 @@ class TrainerTTS: scheduler = getattr(torch.optim, lr_scheduler) return scheduler(optimizer, **lr_scheduler_params) - def init_criterion(self): - return setup_loss(self.config) + @staticmethod + def get_criterion(config: Coqpit) -> nn.Module: + return setup_loss(config) - 
-    def restore_model(self,
-                      config,
-                      restore_path,
-                      model,
-                      optimizer,
-                      scaler=None):
-        logging.info(f" > Restoring from {os.path.basename(restore_path)}...")
-        checkpoint = torch.load(restore_path, map_location="cpu")
+    def restore_model(
+        self,
+        config: Coqpit,
+        restore_path: str,
+        model: nn.Module,
+        optimizer: torch.optim.Optimizer,
+        scaler: torch.cuda.amp.GradScaler = None
+    ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
+        print(" > Restoring from %s ..." % os.path.basename(restore_path))
+        checkpoint = torch.load(restore_path)
         try:
             logging.info(" > Restoring Model...")
             model.load_state_dict(checkpoint["model"])
@@ -242,20 +256,20 @@ class TrainerTTS:
             for group in optimizer.param_groups:
                 group["lr"] = self.config.lr

-        logging.info(" > Model restored from step %d" % checkpoint["step"],
-                     flush=True)
+        print(" > Model restored from step %d" % checkpoint["step"])
         restore_step = checkpoint["step"]
         return model, optimizer, scaler, restore_step

-    def _setup_loader(self, r, ap, is_eval, data_items, verbose,
-                      speaker_mapping):
+    def _get_loader(self, r: int, ap: AudioProcessor, is_eval: bool,
+                    data_items: List, verbose: bool,
+                    speaker_mapping: Union[Dict, List]) -> DataLoader:
         if is_eval and not self.config.run_eval:
             loader = None
         else:
             dataset = TTSDataset(
                 outputs_per_step=r,
                 text_cleaner=self.config.text_cleaner,
-                compute_linear_spec= 'tacotron' == self.config.model.lower(),
+                compute_linear_spec=self.config.model.lower() == "tacotron",
                 meta_data=data_items,
                 ap=ap,
                 tp=self.config.characters,
@@ -296,17 +310,19 @@ class TrainerTTS:
             )
         return loader

-    def setup_train_dataloader(self, r, ap, data_items, verbose,
-                               speaker_mapping):
-        return self._setup_loader(r, ap, False, data_items, verbose,
-                                  speaker_mapping)
+    def get_train_dataloader(self, r: int, ap: AudioProcessor,
+                             data_items: List, verbose: bool,
+                             speaker_mapping: Union[List, Dict]) -> DataLoader:
+        return self._get_loader(r, ap, False, data_items, verbose,
+                                speaker_mapping)

-    def setup_eval_dataloder(self, r, ap, data_items, verbose,
-                             speaker_mapping):
-        return self._setup_loader(r, ap, True, data_items, verbose,
-                                  speaker_mapping)
+    def get_eval_dataloader(self, r: int, ap: AudioProcessor, data_items: List,
+                            verbose: bool,
+                            speaker_mapping: Union[List, Dict]) -> DataLoader:
+        return self._get_loader(r, ap, True, data_items, verbose,
+                                speaker_mapping)

-    def format_batch(self, batch):
+    def format_batch(self, batch: List) -> Dict:
         # setup input batch
         text_input = batch[0]
         text_lengths = batch[1]
@@ -401,7 +417,8 @@ class TrainerTTS:
             "item_idx": item_idx
         }

-    def train_step(self, batch, batch_n_steps, step, loader_start_time):
+    def train_step(self, batch: Dict, batch_n_steps: int, step: int,
+                   loader_start_time: float) -> Tuple[Dict, Dict]:
         self.on_train_step_start()
         step_start_time = time.time()
@@ -515,7 +532,7 @@ class TrainerTTS:
         self.on_train_step_end()
         return outputs, loss_dict

-    def train_epoch(self):
+    def train_epoch(self) -> None:
         self.model.train()
         epoch_start_time = time.time()
         if self.use_cuda:
@@ -541,7 +558,7 @@ class TrainerTTS:
             self.tb_logger.tb_model_weights(self.model, self.total_steps_done)

-    def eval_step(self, batch, step):
+    def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
         with torch.no_grad():
             step_start_time = time.time()
@@ -572,17 +589,11 @@ class TrainerTTS:
                                          self.keep_avg_eval.avg_values)
         return outputs, loss_dict

-    def eval_epoch(self):
+    def eval_epoch(self) -> None:
         self.model.eval()
-        if self.use_cuda:
-            batch_num_steps = int(
-                len(self.train_loader.dataset) /
-                (self.config.batch_size * self.num_gpus))
-        else:
-            batch_num_steps = int(
-                len(self.train_loader.dataset) / self.config.batch_size)
         self.c_logger.print_eval_start()
         loader_start_time = time.time()
+        batch = None
         for cur_step, batch in enumerate(self.eval_loader):
             # format data
             batch = self.format_batch(batch)
@@ -597,8 +608,8 @@ class TrainerTTS:
                 {"EvalAudio": eval_audios},
                 self.ap.sample_rate)

-    def test_run(self, ):
-        logging.info(" | > Synthesizing test sentences.")
+    def test_run(self) -> None:
+        print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
         test_sentences = self.config.test_sentences
@@ -618,9 +629,11 @@ class TrainerTTS:
                 do_trim_silence=False,
             ).values()

-            file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
+            file_path = os.path.join(self.output_audio_path,
+                                     str(self.total_steps_done))
             os.makedirs(file_path, exist_ok=True)
-            file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
+            file_path = os.path.join(file_path,
+                                     "TestSentence_{}.wav".format(idx))
             self.ap.save_wav(wav, file_path)
             test_audios["{}-audio".format(idx)] = wav
             test_figures["{}-prediction".format(idx)] = plot_spectrogram(
@@ -629,16 +642,17 @@ class TrainerTTS:
                 alignment, output_fig=False)
         self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
-                                       self.config.audio["sample_rate"])
+                                      self.config.audio["sample_rate"])
         self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)

-    def _get_cond_inputs(self):
+    def _get_cond_inputs(self) -> Dict:
         # setup speaker_id
         speaker_id = 0 if self.config.use_speaker_embedding else None
         # setup x_vector
-        x_vector = self.speaker_manager.get_x_vectors_by_speaker(
-            self.speaker_manager.speaker_ids[0]
-        ) if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding else None
+        x_vector = (self.speaker_manager.get_x_vectors_by_speaker(
+            self.speaker_manager.speaker_ids[0])
+                    if self.config.use_external_speaker_embedding_file
+                    and self.config.use_speaker_embedding else None)
         # setup style_mel
         if self.config.has('gst_style_input'):
             style_wav = self.config.gst_style_input
@@ -647,35 +661,40 @@ class TrainerTTS:
         if style_wav is None and 'use_gst' in self.config and self.config.use_gst:
             # inicialize GST with zero dict.
             style_wav = {}
-            print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
+            print(
+                "WARNING: No GST style wav was provided, so a zero tensor is used instead!"
+            )
             for i in range(self.config.gst["gst_num_style_tokens"]):
                 style_wav[str(i)] = 0
-        cond_inputs = {'speaker_id': speaker_id, 'style_wav': style_wav, 'x_vector': x_vector}
+        cond_inputs = {
+            "speaker_id": speaker_id,
+            "style_wav": style_wav,
+            "x_vector": x_vector
+        }
         return cond_inputs

-    def fit(self):
+    def fit(self) -> None:
         if self.restore_step != 0 or self.args.best_path:
-            logging.info(" > Restoring best loss from "
-                         f"{os.path.basename(self.args.best_path)} ...")
+            print(" > Restoring best loss from "
+                  f"{os.path.basename(self.args.best_path)} ...")
             self.best_loss = torch.load(self.args.best_path,
                                         map_location="cpu")["model_loss"]
-            logging.info(
-                f" > Starting with loaded last best loss {self.best_loss}.")
+            print(f" > Starting with loaded last best loss {self.best_loss}.")

         # define data loaders
-        self.train_loader = self.setup_train_dataloader(
+        self.train_loader = self.get_train_dataloader(
             self.config.r,
             self.ap,
             self.data_train,
             verbose=True,
             speaker_mapping=self.speaker_manager.speaker_ids)
-        self.eval_loader = self.setup_eval_dataloder(
+        self.eval_loader = (self.get_eval_dataloader(
             self.config.r,
             self.ap,
             self.data_train,
             verbose=True,
-            speaker_mapping=self.speaker_manager.speaker_ids
-        ) if self.config.run_eval else None
+            speaker_mapping=self.speaker_manager.speaker_ids)
+                            if self.config.run_eval else None)

         self.total_steps_done = self.restore_step
@@ -697,10 +716,10 @@ class TrainerTTS:
             self.save_best_model()
             self.on_epoch_end()

-    def save_best_model(self):
+    def save_best_model(self) -> None:
         self.best_loss = save_best_model(
-            self.keep_avg_eval['avg_loss']
-            if self.keep_avg_eval else self.keep_avg_train['avg_loss'],
+            self.keep_avg_eval["avg_loss"]
+            if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
             self.best_loss,
             self.model,
             self.optimizer,
@@ -715,8 +734,16 @@ class TrainerTTS:
             if self.config.mixed_precision else None,
         )

-    def on_epoch_start(self):
-        if hasattr(self.model, 'on_epoch_start'):
+    @staticmethod
+    def _setup_logger_config(log_file: str) -> None:
+        logging.basicConfig(
+            level=logging.INFO,
+            format="",
+            handlers=[logging.FileHandler(log_file),
+                      logging.StreamHandler()])
+
+    def on_epoch_start(self) -> None:  # pylint: disable=no-self-use
+        if hasattr(self.model, "on_epoch_start"):
             self.model.on_epoch_start(self)

         if hasattr(self.criterion, "on_epoch_start"):
@@ -725,8 +752,8 @@ class TrainerTTS:
         if hasattr(self.optimizer, "on_epoch_start"):
             self.optimizer.on_epoch_start(self)

-    def on_epoch_end(self):
-        if hasattr(self.model, "on_epoch_start"):
+    def on_epoch_end(self) -> None:  # pylint: disable=no-self-use
+        if hasattr(self.model, "on_epoch_end"):
             self.model.on_epoch_end(self)

         if hasattr(self.criterion, "on_epoch_end"):
@@ -735,8 +762,8 @@ class TrainerTTS:
         if hasattr(self.optimizer, "on_epoch_end"):
             self.optimizer.on_epoch_end(self)

-    def on_train_step_start(self):
-        if hasattr(self.model, "on_epoch_start"):
+    def on_train_step_start(self) -> None:  # pylint: disable=no-self-use
+        if hasattr(self.model, "on_train_step_start"):
             self.model.on_train_step_start(self)

         if hasattr(self.criterion, "on_train_step_start"):
@@ -745,7 +772,7 @@ class TrainerTTS:
         if hasattr(self.optimizer, "on_train_step_start"):
             self.optimizer.on_train_step_start(self)

-    def on_train_step_end(self):
+    def on_train_step_end(self) -> None:  # pylint: disable=no-self-use
         if hasattr(self.model, "on_train_step_end"):
             self.model.on_train_step_end(self)
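
Note on the refactor: because get_model, get_optimizer, get_scheduler, get_criterion, get_character_processor, and get_speaker_manager are now @staticmethods that take an explicit Coqpit config instead of reading self.config, each factory can be exercised on its own, e.g. from a unit test, without constructing a full TrainerTTS. A minimal sketch of what that enables follows; the SimpleNamespace stand-in config and the nn.Linear dummy model are illustrative assumptions, not part of the diff, and a real run would pass the model's Coqpit config instead.

    # Sketch: call the new static factories without building a TrainerTTS.
    # Assumes the TTS package is importable; the stand-in config carries
    # only the fields that get_optimizer()/get_character_processor() read.
    from types import SimpleNamespace

    from torch import nn

    from TTS.trainer import TrainerTTS

    config = SimpleNamespace(
        optimizer="Adam",  # resolved via getattr(torch.optim, ...)
        optimizer_params={"betas": (0.9, 0.998), "weight_decay": 1e-6},
        lr=1e-3,
        characters=None,    # fall back to the default symbol set
        use_phonemes=False,
    )

    model = nn.Linear(80, 80)  # any nn.Module stands in for a TTS model
    optimizer = TrainerTTS.get_optimizer(model, config)
    model_characters = TrainerTTS.get_character_processor(config)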