From 13334d507ca81d4dd02444b94349019b72e67d30 Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Mon, 23 Jan 2023 13:45:45 +0100
Subject: [PATCH] Load model from path

---
 TTS/api.py | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/TTS/api.py b/TTS/api.py
index 99c3e522..da571414 100644
--- a/TTS/api.py
+++ b/TTS/api.py
@@ -7,7 +7,7 @@ from TTS.utils.synthesizer import Synthesizer
 class TTS:
     """TODO: Add voice conversion and Capacitron support."""
 
-    def __init__(self, model_name: str = None, progress_bar: bool = True, gpu=False):
+    def __init__(self, model_name: str = None, model_path: str = None, config_path: str = None, progress_bar: bool = True, gpu=False):
         """🐸TTS python interface that allows to load and use the released models.
 
         Example with a multi-speaker model:
@@ -20,6 +20,10 @@ class TTS:
         >>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
         >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")
 
+        Example loading a model from a path:
+            >>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False)
+            >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")
+
         Args:
             model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
             progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True.
@@ -29,6 +33,8 @@ class TTS:
         self.synthesizer = None
         if model_name:
             self.load_model_by_name(model_name, gpu)
+        if model_path:
+            self.load_model_by_path(model_path, config_path, gpu)
 
     @property
     def models(self):
@@ -90,6 +96,19 @@ class TTS:
             use_cuda=gpu,
         )
 
+    def load_model_by_path(self, model_path: str, config_path: str, gpu: bool = False):
+        self.synthesizer = Synthesizer(
+            tts_checkpoint=model_path,
+            tts_config_path=config_path,
+            tts_speakers_file=None,
+            tts_languages_file=None,
+            vocoder_checkpoint=None,
+            vocoder_config=None,
+            encoder_checkpoint=None,
+            encoder_config=None,
+            use_cuda=gpu,
+        )
+
     def _check_arguments(self, speaker: str = None, language: str = None):
         if self.is_multi_speaker and speaker is None:
             raise ValueError("Model is multi-speaker but no speaker is provided.")
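
Note: besides the constructor arguments shown in the docstring, the new load_model_by_path method can also be called on an already constructed TTS object to load a local checkpoint later. The sketch below follows the patch's own doctest style; the checkpoint and config paths are the same placeholders used in the docstring, it assumes a single-speaker model (so no speaker/language arguments are needed), and no vocoder is configured because load_model_by_path passes None for all vocoder fields.

    >>> from TTS.api import TTS
    >>> # No model_name or model_path given, so nothing is loaded at construction time.
    >>> tts = TTS(progress_bar=False, gpu=False)
    >>> # Load a local checkpoint and its config explicitly (placeholder paths).
    >>> tts.load_model_by_path("/path/to/checkpoint_100000.pth", "/path/to/config.json", gpu=False)
    >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")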