From 2c38ef8441d5162cc6eb76d94625386fb5543bc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 5 Jun 2021 11:46:53 +0200 Subject: [PATCH] use get_speaker_manager in Trainer and save speakers.json file when needed --- TTS/trainer.py | 22 ++-------------------- TTS/tts/utils/speakers.py | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/TTS/trainer.py b/TTS/trainer.py index f837ce7f..564c4c26 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -21,7 +21,7 @@ from TTS.tts.datasets import TTSDataset, load_meta_data from TTS.tts.layers import setup_loss from TTS.tts.models import setup_model from TTS.tts.utils.io import save_best_model, save_checkpoint -from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -186,25 +186,7 @@ class TrainerTTS: def get_speaker_manager( config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None ) -> SpeakerManager: - speaker_manager = SpeakerManager() - if restore_path: - speakers_file = os.path.join(os.path.dirname(restore_path), "speaker.json") - if not os.path.exists(speakers_file): - print( - "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file" - ) - speakers_file = config.external_speaker_embedding_file - - if config.use_external_speaker_embedding_file: - speaker_manager.load_d_vectors_file(speakers_file) - else: - speaker_manager.load_ids_file(speakers_file) - elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file: - speaker_manager.load_d_vectors_file(config.external_speaker_embedding_file) - else: - speaker_manager.parse_speakers_from_items(data_train) - file_path = os.path.join(out_path, "speakers.json") - speaker_manager.save_ids_file(file_path) + speaker_manager = get_speaker_manager(config, restore_path, data_train, out_path) return speaker_manager @staticmethod diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 546d483d..0f43bf97 100755 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -34,16 +34,16 @@ def save_speaker_mapping(out_path, speaker_mapping): json.dump(speaker_mapping, f, indent=4) -def get_speaker_manager(c, args, meta_data_train): +def get_speaker_manager(c, restore_path, meta_data_train, out_path=None): """Inititalize and return a `SpeakerManager` based on config values""" speaker_manager = SpeakerManager() if c.use_speaker_embedding: speaker_manager.set_speaker_ids_from_data(meta_data_train) - if args.restore_path: + if restore_path: # restoring speaker manager from a previous run. if c.use_external_speaker_embedding_file: # restore speaker manager with the embedding file - speakers_file = os.path.dirname(args.restore_path) + speakers_file = os.path.dirname(restore_path) if not os.path.exists(speakers_file): print( "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file" @@ -55,7 +55,7 @@ def get_speaker_manager(c, args, meta_data_train): speaker_manager.load_d_vectors_file(c.external_speaker_embedding_file) speaker_manager.set_d_vectors_from_file(speakers_file) elif not c.use_external_speaker_embedding_file: # restor speaker manager with speaker ID file. - speakers_file = os.path.dirname(args.restore_path) + speakers_file = os.path.dirname(restore_path) speaker_ids_from_data = speaker_manager.speaker_ids speaker_manager.set_speaker_ids_from_file(speakers_file) assert all( @@ -73,6 +73,14 @@ def get_speaker_manager(c, args, meta_data_train): speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids) ) ) + # save file if path is defined + if out_path: + out_file_path = os.path.join(out_path, "speaker.json") + print(" > Saving `speaker.json` to {out_file_path}.") + if c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: + speaker_manager.save_d_vectors_to_file(out_file_path) + else: + speaker_manager.save_speaker_ids_to_file(out_file_path) return speaker_manager