diff --git a/TTS/api.py b/TTS/api.py index 7ca79405..86593f3f 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -75,8 +75,12 @@ class TTS(nn.Module): self.synthesizer = None self.voice_converter = None self.model_name = "" + + self.vocoder_path = vocoder_path + self.vocoder_config_path = vocoder_config_path self.encoder_path = encoder_path self.encoder_config_path = encoder_config_path + if gpu: warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.") @@ -90,9 +94,7 @@ class TTS(nn.Module): self.load_model_by_name(model_name, vocoder_name, gpu=gpu) if model_path: - self.load_tts_model_by_path( - model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu - ) + self.load_tts_model_by_path(model_path, config_path, gpu=gpu) @property def models(self): @@ -140,18 +142,22 @@ class TTS(nn.Module): def download_model_by_name( self, model_name: str, vocoder_name: Optional[str] = None - ) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]: + ) -> tuple[Optional[str], Optional[str], Optional[str]]: model_path, config_path, model_item = self.manager.download_model(model_name) if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)): # return model directory if there are multiple files # we assume that the model knows how to load itself - return None, None, None, None, model_path + return None, None, model_path if model_item.get("default_vocoder") is None: - return model_path, config_path, None, None, None + return model_path, config_path, None if vocoder_name is None: vocoder_name = model_item["default_vocoder"] vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name) - return model_path, config_path, vocoder_path, vocoder_config_path, None + # A local vocoder model will take precedence if specified via vocoder_path + if self.vocoder_path is None or self.vocoder_config_path is None: + self.vocoder_path = vocoder_path + self.vocoder_config_path = vocoder_config_path + return model_path, config_path, None def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False): """Load one of the 🐸TTS models by name. @@ -170,7 +176,7 @@ class TTS(nn.Module): gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ self.model_name = model_name - model_path, config_path, _, _, model_dir = self.download_model_by_name(model_name) + model_path, config_path, model_dir = self.download_model_by_name(model_name) self.voice_converter = Synthesizer( vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu ) @@ -187,9 +193,7 @@ class TTS(nn.Module): self.synthesizer = None self.model_name = model_name - model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name( - model_name, vocoder_name - ) + model_path, config_path, model_dir = self.download_model_by_name(model_name, vocoder_name) # init synthesizer # None values are fetch from the model @@ -198,17 +202,15 @@ class TTS(nn.Module): tts_config_path=config_path, tts_speakers_file=None, tts_languages_file=None, - vocoder_checkpoint=vocoder_path, - vocoder_config=vocoder_config_path, + vocoder_checkpoint=self.vocoder_path, + vocoder_config=self.vocoder_config_path, encoder_checkpoint=self.encoder_path, encoder_config=self.encoder_config_path, model_dir=model_dir, use_cuda=gpu, ) - def load_tts_model_by_path( - self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False - ): + def load_tts_model_by_path(self, model_path: str, config_path: str, *, gpu: bool = False) -> None: """Load a model from a path. Args: @@ -224,8 +226,8 @@ class TTS(nn.Module): tts_config_path=config_path, tts_speakers_file=None, tts_languages_file=None, - vocoder_checkpoint=vocoder_path, - vocoder_config=vocoder_config, + vocoder_checkpoint=self.vocoder_path, + vocoder_config=self.vocoder_config_path, encoder_checkpoint=self.encoder_path, encoder_config=self.encoder_config_path, use_cuda=gpu,