mirror of https://github.com/coqui-ai/TTS.git
feat(api): support passing a custom speaker encoder by path
This commit is contained in:
parent
5daed879e0
commit
1a4e58d0ce
14
TTS/api.py
14
TTS/api.py
|
@ -26,6 +26,8 @@ class TTS(nn.Module):
|
||||||
vocoder_name: Optional[str] = None,
|
vocoder_name: Optional[str] = None,
|
||||||
vocoder_path: Optional[str] = None,
|
vocoder_path: Optional[str] = None,
|
||||||
vocoder_config_path: Optional[str] = None,
|
vocoder_config_path: Optional[str] = None,
|
||||||
|
encoder_path: Optional[str] = None,
|
||||||
|
encoder_config_path: Optional[str] = None,
|
||||||
progress_bar: bool = True,
|
progress_bar: bool = True,
|
||||||
gpu: bool = False,
|
gpu: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -62,6 +64,8 @@ class TTS(nn.Module):
|
||||||
vocoder_name (str, optional): Pre-trained vocoder to use. Defaults to None, i.e. using the default vocoder.
|
vocoder_name (str, optional): Pre-trained vocoder to use. Defaults to None, i.e. using the default vocoder.
|
||||||
vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
|
vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
|
||||||
vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None.
|
vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None.
|
||||||
|
encoder_path: Path to speaker encoder checkpoint. Default to None.
|
||||||
|
encoder_config_path: Path to speaker encoder config file. Defaults to None.
|
||||||
progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True.
|
progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True.
|
||||||
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
|
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
@ -71,6 +75,8 @@ class TTS(nn.Module):
|
||||||
self.synthesizer = None
|
self.synthesizer = None
|
||||||
self.voice_converter = None
|
self.voice_converter = None
|
||||||
self.model_name = ""
|
self.model_name = ""
|
||||||
|
self.encoder_path = encoder_path
|
||||||
|
self.encoder_config_path = encoder_config_path
|
||||||
if gpu:
|
if gpu:
|
||||||
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
|
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
|
||||||
|
|
||||||
|
@ -194,8 +200,8 @@ class TTS(nn.Module):
|
||||||
tts_languages_file=None,
|
tts_languages_file=None,
|
||||||
vocoder_checkpoint=vocoder_path,
|
vocoder_checkpoint=vocoder_path,
|
||||||
vocoder_config=vocoder_config_path,
|
vocoder_config=vocoder_config_path,
|
||||||
encoder_checkpoint=None,
|
encoder_checkpoint=self.encoder_path,
|
||||||
encoder_config=None,
|
encoder_config=self.encoder_config_path,
|
||||||
model_dir=model_dir,
|
model_dir=model_dir,
|
||||||
use_cuda=gpu,
|
use_cuda=gpu,
|
||||||
)
|
)
|
||||||
|
@ -220,8 +226,8 @@ class TTS(nn.Module):
|
||||||
tts_languages_file=None,
|
tts_languages_file=None,
|
||||||
vocoder_checkpoint=vocoder_path,
|
vocoder_checkpoint=vocoder_path,
|
||||||
vocoder_config=vocoder_config,
|
vocoder_config=vocoder_config,
|
||||||
encoder_checkpoint=None,
|
encoder_checkpoint=self.encoder_path,
|
||||||
encoder_config=None,
|
encoder_config=self.encoder_config_path,
|
||||||
use_cuda=gpu,
|
use_cuda=gpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue