diff --git a/TTS/trainer.py b/TTS/trainer.py index 6087f1bc..63b9cd42 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -150,7 +150,7 @@ class TrainerTTS: # count model size num_params = count_parameters(self.model) - logging.info("\n > Model has {} parameters".format(num_params)) + print("\n > Model has {} parameters".format(num_params)) @staticmethod def get_model(num_chars: int, num_speakers: int, config: Coqpit, @@ -186,7 +186,6 @@ class TrainerTTS: out_path: str = "", data_train: List = []) -> SpeakerManager: speaker_manager = SpeakerManager() - if config.use_speaker_embedding: if restore_path: speakers_file = os.path.join(os.path.dirname(restore_path), "speaker.json") @@ -196,16 +195,6 @@ class TrainerTTS: ) speakers_file = config.external_speaker_embedding_file - if config.use_external_speaker_embedding_file: - speaker_manager.load_x_vectors_file(speakers_file) - else: - speaker_manager.load_ids_file(speakers_file) - elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file: - speaker_manager.load_x_vectors_file( - config.external_speaker_embedding_file) - else: - speaker_manager.parse_speakers_from_items(data_train) - file_path = os.path.join(out_path, "speakers.json") speaker_manager.save_ids_file(file_path) return speaker_manager @@ -238,15 +227,15 @@ class TrainerTTS: print(" > Restoring from %s ..." % os.path.basename(restore_path)) checkpoint = torch.load(restore_path) try: - logging.info(" > Restoring Model...") + print(" > Restoring Model...") model.load_state_dict(checkpoint["model"]) - logging.info(" > Restoring Optimizer...") + print(" > Restoring Optimizer...") optimizer.load_state_dict(checkpoint["optimizer"]) if "scaler" in checkpoint and config.mixed_precision: - logging.info(" > Restoring AMP Scaler...") + print(" > Restoring AMP Scaler...") scaler.load_state_dict(checkpoint["scaler"]) except (KeyError, RuntimeError): - logging.info(" > Partial model initialization...") + print(" > Partial model initialization...") model_dict = model.state_dict() model_dict = set_init_dict(model_dict, checkpoint["model"], config) model.load_state_dict(model_dict)