diff --git a/TTS/api.py b/TTS/api.py index 49b9a6b7..7ca79405 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -26,6 +26,8 @@ class TTS(nn.Module): vocoder_name: Optional[str] = None, vocoder_path: Optional[str] = None, vocoder_config_path: Optional[str] = None, + encoder_path: Optional[str] = None, + encoder_config_path: Optional[str] = None, progress_bar: bool = True, gpu: bool = False, ): @@ -62,6 +64,8 @@ class TTS(nn.Module): vocoder_name (str, optional): Pre-trained vocoder to use. Defaults to None, i.e. using the default vocoder. vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None. vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None. + encoder_path (str, optional): Path to the speaker encoder checkpoint. Defaults to None. + encoder_config_path (str, optional): Path to the speaker encoder config file. Defaults to None. progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True. gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ @@ -71,6 +75,8 @@ class TTS(nn.Module): self.synthesizer = None self.voice_converter = None self.model_name = "" + self.encoder_path = encoder_path + self.encoder_config_path = encoder_config_path if gpu: warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.") @@ -194,8 +200,8 @@ class TTS(nn.Module): tts_languages_file=None, vocoder_checkpoint=vocoder_path, vocoder_config=vocoder_config_path, - encoder_checkpoint=None, - encoder_config=None, + encoder_checkpoint=self.encoder_path, + encoder_config=self.encoder_config_path, model_dir=model_dir, use_cuda=gpu, ) @@ -220,8 +226,8 @@ class TTS(nn.Module): tts_languages_file=None, vocoder_checkpoint=vocoder_path, vocoder_config=vocoder_config, - encoder_checkpoint=None, - encoder_config=None, + encoder_checkpoint=self.encoder_path, + encoder_config=self.encoder_config_path, use_cuda=gpu, )