mirror of https://github.com/coqui-ai/TTS.git
feat(api): support specifying vocoders by name
This commit is contained in:
parent
5cfb4ecccd
commit
42ad9b00c6
29
TTS/api.py
29
TTS/api.py
|
@ -23,6 +23,7 @@ class TTS(nn.Module):
|
|||
*,
|
||||
model_path: Optional[str] = None,
|
||||
config_path: Optional[str] = None,
|
||||
vocoder_name: Optional[str] = None,
|
||||
vocoder_path: Optional[str] = None,
|
||||
vocoder_config_path: Optional[str] = None,
|
||||
progress_bar: bool = True,
|
||||
|
@ -58,6 +59,7 @@ class TTS(nn.Module):
|
|||
model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
|
||||
model_path (str, optional): Path to the model checkpoint. Defaults to None.
|
||||
config_path (str, optional): Path to the model config. Defaults to None.
|
||||
vocoder_name (str, optional): Pre-trained vocoder to use. Defaults to None, i.e. using the default vocoder.
|
||||
vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
|
||||
vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None.
|
||||
progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True.
|
||||
|
@ -74,11 +76,12 @@ class TTS(nn.Module):
|
|||
|
||||
if model_name is not None and len(model_name) > 0:
|
||||
if "tts_models" in model_name:
|
||||
self.load_tts_model_by_name(model_name, gpu)
|
||||
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
|
||||
elif "voice_conversion_models" in model_name:
|
||||
self.load_vc_model_by_name(model_name, gpu)
|
||||
self.load_vc_model_by_name(model_name, gpu=gpu)
|
||||
# To allow just TTS("xtts")
|
||||
else:
|
||||
self.load_model_by_name(model_name, gpu)
|
||||
self.load_model_by_name(model_name, vocoder_name, gpu=gpu)
|
||||
|
||||
if model_path:
|
||||
self.load_tts_model_by_path(
|
||||
|
@ -129,7 +132,9 @@ class TTS(nn.Module):
|
|||
def list_models():
|
||||
return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()
|
||||
|
||||
def download_model_by_name(self, model_name: str):
|
||||
def download_model_by_name(
|
||||
self, model_name: str, vocoder_name: Optional[str] = None
|
||||
) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]:
|
||||
model_path, config_path, model_item = self.manager.download_model(model_name)
|
||||
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
|
||||
# return model directory if there are multiple files
|
||||
|
@ -137,19 +142,21 @@ class TTS(nn.Module):
|
|||
return None, None, None, None, model_path
|
||||
if model_item.get("default_vocoder") is None:
|
||||
return model_path, config_path, None, None, None
|
||||
vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
|
||||
if vocoder_name is None:
|
||||
vocoder_name = model_item["default_vocoder"]
|
||||
vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name)
|
||||
return model_path, config_path, vocoder_path, vocoder_config_path, None
|
||||
|
||||
def load_model_by_name(self, model_name: str, gpu: bool = False):
|
||||
def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
|
||||
"""Load one of the 🐸TTS models by name.
|
||||
|
||||
Args:
|
||||
model_name (str): Model name to load. You can list models by ```tts.models```.
|
||||
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
|
||||
"""
|
||||
self.load_tts_model_by_name(model_name, gpu)
|
||||
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
|
||||
|
||||
def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
|
||||
def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False):
|
||||
"""Load one of the voice conversion models by name.
|
||||
|
||||
Args:
|
||||
|
@ -162,7 +169,7 @@ class TTS(nn.Module):
|
|||
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
|
||||
)
|
||||
|
||||
def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
|
||||
def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
|
||||
"""Load one of 🐸TTS models by name.
|
||||
|
||||
Args:
|
||||
|
@ -174,7 +181,9 @@ class TTS(nn.Module):
|
|||
self.synthesizer = None
|
||||
self.model_name = model_name
|
||||
|
||||
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)
|
||||
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
|
||||
model_name, vocoder_name
|
||||
)
|
||||
|
||||
# init synthesizer
|
||||
# None values are fetch from the model
|
||||
|
|
Loading…
Reference in New Issue