From 42ad9b00c684666080c406840a8ccab5316734ce Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 4 Dec 2024 10:44:49 +0100 Subject: [PATCH] feat(api): support specifying vocoders by name --- TTS/api.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 62dab329..49b9a6b7 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -23,6 +23,7 @@ class TTS(nn.Module): *, model_path: Optional[str] = None, config_path: Optional[str] = None, + vocoder_name: Optional[str] = None, vocoder_path: Optional[str] = None, vocoder_config_path: Optional[str] = None, progress_bar: bool = True, @@ -58,6 +59,7 @@ class TTS(nn.Module): model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None. model_path (str, optional): Path to the model checkpoint. Defaults to None. config_path (str, optional): Path to the model config. Defaults to None. + vocoder_name (str, optional): Pre-trained vocoder to use. Defaults to None, i.e. using the default vocoder. vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None. vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None. progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True. @@ -74,11 +76,12 @@ class TTS(nn.Module): if model_name is not None and len(model_name) > 0: if "tts_models" in model_name: - self.load_tts_model_by_name(model_name, gpu) + self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu) elif "voice_conversion_models" in model_name: - self.load_vc_model_by_name(model_name, gpu) + self.load_vc_model_by_name(model_name, gpu=gpu) + # To allow just TTS("xtts") else: - self.load_model_by_name(model_name, gpu) + self.load_model_by_name(model_name, vocoder_name, gpu=gpu) if model_path: self.load_tts_model_by_path( @@ -129,7 +132,9 @@ class TTS(nn.Module): def list_models(): return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models() - def download_model_by_name(self, model_name: str): + def download_model_by_name( + self, model_name: str, vocoder_name: Optional[str] = None + ) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]: model_path, config_path, model_item = self.manager.download_model(model_name) if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)): # return model directory if there are multiple files @@ -137,19 +142,21 @@ class TTS(nn.Module): return None, None, None, None, model_path if model_item.get("default_vocoder") is None: return model_path, config_path, None, None, None - vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"]) + if vocoder_name is None: + vocoder_name = model_item["default_vocoder"] + vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name) return model_path, config_path, vocoder_path, vocoder_config_path, None - def load_model_by_name(self, model_name: str, gpu: bool = False): + def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False): """Load one of the 🐸TTS models by name. Args: model_name (str): Model name to load. You can list models by ```tts.models```. gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ - self.load_tts_model_by_name(model_name, gpu) + self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu) - def load_vc_model_by_name(self, model_name: str, gpu: bool = False): + def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False): """Load one of the voice conversion models by name. Args: @@ -162,7 +169,7 @@ class TTS(nn.Module): vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu ) - def load_tts_model_by_name(self, model_name: str, gpu: bool = False): + def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False): """Load one of 🐸TTS models by name. Args: @@ -174,7 +181,9 @@ class TTS(nn.Module): self.synthesizer = None self.model_name = model_name - model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name) + model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name( + model_name, vocoder_name + ) # init synthesizer # None values are fetch from the model