feat(api): allow mixing TTS and vocoder model name and path

This commit is contained in:
Enno Hermann 2024-12-04 16:02:07 +01:00
parent 1a4e58d0ce
commit 85dbb3b8b3
1 changed file with 20 additions and 18 deletions

View File

@@ -75,8 +75,12 @@ class TTS(nn.Module):
self.synthesizer = None
self.voice_converter = None
self.model_name = ""
self.vocoder_path = vocoder_path
self.vocoder_config_path = vocoder_config_path
self.encoder_path = encoder_path
self.encoder_config_path = encoder_config_path
if gpu:
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
@@ -90,9 +94,7 @@ class TTS(nn.Module):
self.load_model_by_name(model_name, vocoder_name, gpu=gpu)
if model_path:
self.load_tts_model_by_path(
model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu
)
self.load_tts_model_by_path(model_path, config_path, gpu=gpu)
@property
def models(self):
@@ -140,18 +142,22 @@ class TTS(nn.Module):
def download_model_by_name(
self, model_name: str, vocoder_name: Optional[str] = None
) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]:
) -> tuple[Optional[str], Optional[str], Optional[str]]:
model_path, config_path, model_item = self.manager.download_model(model_name)
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
# return model directory if there are multiple files
# we assume that the model knows how to load itself
return None, None, None, None, model_path
return None, None, model_path
if model_item.get("default_vocoder") is None:
return model_path, config_path, None, None, None
return model_path, config_path, None
if vocoder_name is None:
vocoder_name = model_item["default_vocoder"]
vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name)
return model_path, config_path, vocoder_path, vocoder_config_path, None
# A local vocoder model will take precedence if specified via vocoder_path
if self.vocoder_path is None or self.vocoder_config_path is None:
self.vocoder_path = vocoder_path
self.vocoder_config_path = vocoder_config_path
return model_path, config_path, None
def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
"""Load one of the 🐸TTS models by name.
@@ -170,7 +176,7 @@ class TTS(nn.Module):
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
self.model_name = model_name
model_path, config_path, _, _, model_dir = self.download_model_by_name(model_name)
model_path, config_path, model_dir = self.download_model_by_name(model_name)
self.voice_converter = Synthesizer(
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
)
@@ -187,9 +193,7 @@ class TTS(nn.Module):
self.synthesizer = None
self.model_name = model_name
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
model_name, vocoder_name
)
model_path, config_path, model_dir = self.download_model_by_name(model_name, vocoder_name)
# init synthesizer
# None values are fetched from the model
@@ -198,17 +202,15 @@ class TTS(nn.Module):
tts_config_path=config_path,
tts_speakers_file=None,
tts_languages_file=None,
vocoder_checkpoint=vocoder_path,
vocoder_config=vocoder_config_path,
vocoder_checkpoint=self.vocoder_path,
vocoder_config=self.vocoder_config_path,
encoder_checkpoint=self.encoder_path,
encoder_config=self.encoder_config_path,
model_dir=model_dir,
use_cuda=gpu,
)
def load_tts_model_by_path(
self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False
):
def load_tts_model_by_path(self, model_path: str, config_path: str, *, gpu: bool = False) -> None:
"""Load a model from a path.
Args:
@@ -224,8 +226,8 @@ class TTS(nn.Module):
tts_config_path=config_path,
tts_speakers_file=None,
tts_languages_file=None,
vocoder_checkpoint=vocoder_path,
vocoder_config=vocoder_config,
vocoder_checkpoint=self.vocoder_path,
vocoder_config=self.vocoder_config_path,
encoder_checkpoint=self.encoder_path,
encoder_config=self.encoder_config_path,
use_cuda=gpu,