From e0f621180f328eac461e9fee978073db3c7b8421 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Mon, 2 Dec 2024 22:34:19 +0100 Subject: [PATCH] refactor(bin.synthesize): use Python API for CLI --- TTS/api.py | 16 +++-- TTS/bin/synthesize.py | 125 +++++++++++---------------------- tests/zoo_tests/test_models.py | 33 ++++----- 3 files changed, 65 insertions(+), 109 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index be6141d3..83189482 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -109,7 +109,11 @@ class TTS(nn.Module): @property def is_multi_speaker(self) -> bool: - if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager: + if ( + self.synthesizer is not None + and hasattr(self.synthesizer.tts_model, "speaker_manager") + and self.synthesizer.tts_model.speaker_manager + ): return self.synthesizer.tts_model.speaker_manager.num_speakers > 1 return False @@ -123,7 +127,11 @@ class TTS(nn.Module): and ("xtts" in self.config.model or "languages" in self.config and len(self.config.languages) > 1) ): return True - if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager: + if ( + self.synthesizer is not None + and hasattr(self.synthesizer.tts_model, "language_manager") + and self.synthesizer.tts_model.language_manager + ): return self.synthesizer.tts_model.language_manager.num_languages > 1 return False @@ -306,10 +314,6 @@ class TTS(nn.Module): speaker_name=speaker, language_name=language, speaker_wav=speaker_wav, - reference_wav=None, - style_wav=None, - style_text=None, - reference_speaker_name=None, split_sentences=split_sentences, **kwargs, ) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 59ceb1db..885f6d6f 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -9,8 +9,6 @@ import sys from argparse import RawTextHelpFormatter # pylint: disable=redefined-outer-name, unused-argument -from pathlib import Path - from TTS.utils.generic_utils import ConsoleFormatter, setup_logger logger = logging.getLogger(__name__) @@ -312,7 +310,8 @@ def parse_args() -> argparse.Namespace: return args -def main(): +def main() -> None: + """Entry point for `tts` command line interface.""" setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) args = parse_args() @@ -320,12 +319,11 @@ def main(): with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout): # Late-import to make things load faster + from TTS.api import TTS from TTS.utils.manage import ModelManager - from TTS.utils.synthesizer import Synthesizer # load model manager - path = Path(__file__).parent / "../.models.json" - manager = ModelManager(path, progress_bar=args.progress_bar) + manager = ModelManager(models_file=TTS.get_models_file_path(), progress_bar=args.progress_bar) tts_path = None tts_config_path = None @@ -339,12 +337,12 @@ def main(): vc_config_path = None model_dir = None - # CASE1 #list : list pre-trained TTS models + # 1) List pre-trained TTS models if args.list_models: manager.list_models() sys.exit() - # CASE2 #info : model info for pre-trained TTS models + # 2) Info about pre-trained TTS models (without loading a model) if args.model_info_by_idx: model_query = args.model_info_by_idx manager.model_info_by_idx(model_query) @@ -355,91 +353,50 @@ def main(): manager.model_info_by_full_name(model_query_full_name) sys.exit() - # CASE3: load pre-trained model paths - if args.model_name is not None and not args.model_path: - model_path, config_path, model_item = manager.download_model(args.model_name) - # tts model - if model_item["model_type"] == "tts_models": - tts_path = model_path - tts_config_path = config_path - if args.vocoder_name is None and "default_vocoder" in model_item: - args.vocoder_name = model_item["default_vocoder"] - - # voice conversion model - if model_item["model_type"] == "voice_conversion_models": - vc_path = model_path - vc_config_path = config_path - - # tts model with multiple files to be loaded from the directory path - if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list): - model_dir = model_path - tts_path = None - tts_config_path = None - args.vocoder_name = None - - # load vocoder - if args.vocoder_name is not None and not args.vocoder_path: - vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) - - # CASE4: set custom model paths - if args.model_path is not None: - tts_path = args.model_path - tts_config_path = args.config_path - speakers_file_path = args.speakers_file_path - language_ids_file_path = args.language_ids_file_path - - if args.vocoder_path is not None: - vocoder_path = args.vocoder_path - vocoder_config_path = args.vocoder_config_path - - if args.encoder_path is not None: - encoder_path = args.encoder_path - encoder_config_path = args.encoder_config_path - + # 3) Load a model for further info or TTS/VC device = args.device if args.use_cuda: device = "cuda" - - # load models - synthesizer = Synthesizer( - tts_checkpoint=tts_path, - tts_config_path=tts_config_path, - tts_speakers_file=speakers_file_path, - tts_languages_file=language_ids_file_path, - vocoder_checkpoint=vocoder_path, - vocoder_config=vocoder_config_path, - encoder_checkpoint=encoder_path, - encoder_config=encoder_config_path, - vc_checkpoint=vc_path, - vc_config=vc_config_path, - model_dir=model_dir, - voice_dir=args.voice_dir, + # A local model will take precedence if specified via modeL_path + model_name = args.model_name if args.model_path is None else None + api = TTS( + model_name=model_name, + model_path=args.model_path, + config_path=args.config_path, + vocoder_name=args.vocoder_name, + vocoder_path=args.vocoder_path, + vocoder_config_path=args.vocoder_config_path, + encoder_path=args.encoder_path, + encoder_config_path=args.encoder_config_path, + speakers_file_path=args.speakers_file_path, + language_ids_file_path=args.language_ids_file_path, + progress_bar=args.progress_bar, ).to(device) # query speaker ids of a multi-speaker model. if args.list_speaker_idxs: - if synthesizer.tts_model.speaker_manager is None: + if not api.is_multi_speaker: logger.info("Model only has a single speaker.") return logger.info( "Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." ) - logger.info(list(synthesizer.tts_model.speaker_manager.name_to_id.keys())) + logger.info(api.speakers) return # query langauge ids of a multi-lingual model. if args.list_language_idxs: - if synthesizer.tts_model.language_manager is None: + if not api.is_multi_lingual: logger.info("Monolingual model.") return logger.info( "Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." ) - logger.info(synthesizer.tts_model.language_manager.name_to_id) + logger.info(api.languages) return # check the arguments against a multi-speaker model. - if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav): + if api.is_multi_speaker and (not args.speaker_idx and not args.speaker_wav): logger.error( "Looks like you use a multi-speaker model. Define `--speaker_idx` to " "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`." @@ -450,31 +407,29 @@ def main(): if args.text: logger.info("Text: %s", args.text) - # kick it - if tts_path is not None: - wav = synthesizer.tts( - args.text, - speaker_name=args.speaker_idx, - language_name=args.language_idx, + if args.text is not None: + api.tts_to_file( + text=args.text, + speaker=args.speaker_idx, + language=args.language_idx, speaker_wav=args.speaker_wav, + pipe_out=pipe_out, + file_path=args.out_path, reference_wav=args.reference_wav, style_wav=args.capacitron_style_wav, style_text=args.capacitron_style_text, reference_speaker_name=args.reference_speaker_idx, + voice_dir=args.voice_dir, ) - elif vc_path is not None: - wav = synthesizer.voice_conversion( + logger.info("Saved TTS output to %s", args.out_path) + elif args.source_wav is not None and args.target_wav is not None: + api.voice_conversion_to_file( source_wav=args.source_wav, target_wav=args.target_wav, + file_path=args.out_path, + pipe_out=pipe_out, ) - elif model_dir is not None: - wav = synthesizer.tts( - args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav - ) - - # save the results - synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out) - logger.info("Saved output to %s", args.out_path) + logger.info("Saved VC output to %s", args.out_path) if __name__ == "__main__": diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index b9444239..f38880b5 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -34,30 +34,27 @@ def run_models(offset=0, step=1): # download and run the model speaker_files = glob.glob(local_download_dir + "/speaker*") language_files = glob.glob(local_download_dir + "/language*") - language_id = "" + speaker_arg = "" + language_arg = "" if len(speaker_files) > 0: # multi-speaker model if "speaker_ids" in speaker_files[0]: speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0]) elif "speakers" in speaker_files[0]: speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0]) - - # multi-lingual model - Assuming multi-lingual models are also multi-speaker - if len(language_files) > 0 and "language_ids" in language_files[0]: - language_manager = LanguageManager(language_ids_file_path=language_files[0]) - language_id = language_manager.language_names[0] - - speaker_id = list(speaker_manager.name_to_id.keys())[0] - run_cli( - f"tts --model_name {model_name} " - f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" --no-progress_bar' - ) - else: - # single-speaker model - run_cli( - f"tts --model_name {model_name} " - f'--text "This is an example." --out_path "{output_path}" --no-progress_bar' - ) + speakers = list(speaker_manager.name_to_id.keys()) + if len(speakers) > 1: + speaker_arg = f'--speaker_idx "{speakers[0]}"' + if len(language_files) > 0 and "language_ids" in language_files[0]: + # multi-lingual model + language_manager = LanguageManager(language_ids_file_path=language_files[0]) + languages = language_manager.language_names + if len(languages) > 1: + language_arg = f'--language_idx "{languages[0]}"' + run_cli( + f'tts --model_name {model_name} --text "This is an example." ' + f'--out_path "{output_path}" {speaker_arg} {language_arg} --no-progress_bar' + ) # remove downloaded models shutil.rmtree(local_download_dir) shutil.rmtree(get_user_data_dir("tts"))