diff --git a/TTS/api.py b/TTS/api.py index d16012f8..90f167dc 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -30,6 +30,8 @@ class TTS(nn.Module): vocoder_config_path: Optional[str] = None, encoder_path: Optional[str] = None, encoder_config_path: Optional[str] = None, + speakers_file_path: Optional[str] = None, + language_ids_file_path: Optional[str] = None, progress_bar: bool = True, gpu: bool = False, ) -> None: @@ -68,8 +70,10 @@ class TTS(nn.Module): vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None. encoder_path: Path to speaker encoder checkpoint. Default to None. encoder_config_path: Path to speaker encoder config file. Defaults to None. - progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True. - gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + speakers_file_path: JSON file for multi-speaker model. Defaults to None. + language_ids_file_path: JSON file for multilingual model. Defaults to None + progress_bar (bool, optional): Whether to print a progress bar while downloading a model. Defaults to True. + gpu (bool, optional): Enable/disable GPU. Defaults to False. DEPRECATED, use TTS(...).to("cuda") """ super().__init__() self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar) @@ -82,6 +86,8 @@ class TTS(nn.Module): self.vocoder_config_path = vocoder_config_path self.encoder_path = encoder_path self.encoder_config_path = encoder_config_path + self.speakers_file_path = speakers_file_path + self.language_ids_file_path = language_ids_file_path if gpu: warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.") @@ -226,8 +232,8 @@ class TTS(nn.Module): self.synthesizer = Synthesizer( tts_checkpoint=model_path, tts_config_path=config_path, - tts_speakers_file=None, - tts_languages_file=None, + tts_speakers_file=self.speakers_file_path, + tts_languages_file=self.language_ids_file_path, vocoder_checkpoint=self.vocoder_path, vocoder_config=self.vocoder_config_path, encoder_checkpoint=self.encoder_path,