feat(api): support specifying vocoders by name

This commit is contained in:
Enno Hermann 2024-12-04 10:44:49 +01:00
parent 5cfb4ecccd
commit 42ad9b00c6
1 changed files with 19 additions and 10 deletions

View File

@ -23,6 +23,7 @@ class TTS(nn.Module):
*,
model_path: Optional[str] = None,
config_path: Optional[str] = None,
vocoder_name: Optional[str] = None,
vocoder_path: Optional[str] = None,
vocoder_config_path: Optional[str] = None,
progress_bar: bool = True,
@ -58,6 +59,7 @@ class TTS(nn.Module):
model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
model_path (str, optional): Path to the model checkpoint. Defaults to None.
config_path (str, optional): Path to the model config. Defaults to None.
vocoder_name (str, optional): Pre-trained vocoder to use. Defaults to None, i.e. using the default vocoder.
vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None.
progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True.
@ -74,11 +76,12 @@ class TTS(nn.Module):
if model_name is not None and len(model_name) > 0:
if "tts_models" in model_name:
self.load_tts_model_by_name(model_name, gpu)
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
elif "voice_conversion_models" in model_name:
self.load_vc_model_by_name(model_name, gpu)
self.load_vc_model_by_name(model_name, gpu=gpu)
# To allow just TTS("xtts")
else:
self.load_model_by_name(model_name, gpu)
self.load_model_by_name(model_name, vocoder_name, gpu=gpu)
if model_path:
self.load_tts_model_by_path(
@ -129,7 +132,9 @@ class TTS(nn.Module):
def list_models():
return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()
def download_model_by_name(self, model_name: str):
def download_model_by_name(
self, model_name: str, vocoder_name: Optional[str] = None
) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]:
model_path, config_path, model_item = self.manager.download_model(model_name)
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
# return model directory if there are multiple files
@ -137,19 +142,21 @@ class TTS(nn.Module):
return None, None, None, None, model_path
if model_item.get("default_vocoder") is None:
return model_path, config_path, None, None, None
vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
if vocoder_name is None:
vocoder_name = model_item["default_vocoder"]
vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name)
return model_path, config_path, vocoder_path, vocoder_config_path, None
def load_model_by_name(self, model_name: str, gpu: bool = False):
def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
"""Load one of the 🐸TTS models by name.
Args:
model_name (str): Model name to load. You can list models by ```tts.models```.
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
self.load_tts_model_by_name(model_name, gpu)
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False):
"""Load one of the voice conversion models by name.
Args:
@ -162,7 +169,7 @@ class TTS(nn.Module):
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
)
def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
"""Load one of 🐸TTS models by name.
Args:
@ -174,7 +181,9 @@ class TTS(nn.Module):
self.synthesizer = None
self.model_name = model_name
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
model_name, vocoder_name
)
# init synthesizer
# None values are fetch from the model