diff --git a/TTS/api.py b/TTS/api.py index ed7e6e6b..22b81ba4 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -12,6 +12,8 @@ class TTS: model_name: str = None, model_path: str = None, config_path: str = None, + vocoder_path: str = None, + vocoder_config_path: str = None, progress_bar: bool = True, gpu=False, ): @@ -33,6 +35,10 @@ class TTS: Args: model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None. + model_path (str, optional): Path to the model checkpoint. Defaults to None. + config_path (str, optional): Path to the model config. Defaults to None. + vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None. + vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None. progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True. gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ @@ -41,7 +47,9 @@ class TTS: if model_name: self.load_model_by_name(model_name, gpu) if model_path: - self.load_model_by_path(model_path, config_path, gpu) + self.load_model_by_path( + model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu + ) @property def models(self): @@ -89,6 +97,14 @@ class TTS: def load_model_by_name(self, model_name: str, gpu: bool = False): model_path, config_path, vocoder_path, vocoder_config_path = self.download_model_by_name(model_name) + """ Load one of 🐸TTS models by name. + + Args: + model_name (str): Model name to load. You can list models by ```tts.models```. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + + TODO: Add tests + """ # init synthesizer # None values are fetch from the model self.synthesizer = Synthesizer( @@ -103,14 +119,26 @@ class TTS: use_cuda=gpu, ) - def load_model_by_path(self, model_path: str, config_path: str, gpu: bool = False): + def load_model_by_path( + self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False + ): + """Load a model from a path. + + Args: + model_path (str): Path to the model checkpoint. + config_path (str): Path to the model config. + vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None. + vocoder_config (str, optional): Path to the vocoder config. Defaults to None. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + self.synthesizer = Synthesizer( tts_checkpoint=model_path, tts_config_path=config_path, tts_speakers_file=None, tts_languages_file=None, - vocoder_checkpoint=None, - vocoder_config=None, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config, encoder_checkpoint=None, encoder_config=None, use_cuda=gpu,