feat(api): allow mixing TTS and vocoder model name and path

Enno Hermann 2024-12-04 16:02:07 +01:00
parent 1a4e58d0ce
commit 85dbb3b8b3
1 changed file with 20 additions and 18 deletions
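This change lets callers combine a 🐸TTS model selected by name with a locally stored vocoder checkpoint passed by path. A minimal usage sketch, assuming the package is installed; the model name is one example from the released model list, and the vocoder paths are hypothetical placeholders:

from TTS.api import TTS

# Pretrained TTS model by name, custom vocoder by local path (paths are placeholders).
tts = TTS(
    model_name="tts_models/en/ljspeech/tacotron2-DDC",
    vocoder_path="/path/to/local/vocoder/model.pth",
    vocoder_config_path="/path/to/local/vocoder/config.json",
)
tts.tts_to_file(text="Hello world!", file_path="output.wav")

As the diff below notes, a vocoder passed this way takes precedence over the model's default vocoder.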


@@ -75,8 +75,12 @@ class TTS(nn.Module):
         self.synthesizer = None
         self.voice_converter = None
         self.model_name = ""
+        self.vocoder_path = vocoder_path
+        self.vocoder_config_path = vocoder_config_path
         self.encoder_path = encoder_path
         self.encoder_config_path = encoder_config_path
         if gpu:
             warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
@@ -90,9 +94,7 @@ class TTS(nn.Module):
                 self.load_model_by_name(model_name, vocoder_name, gpu=gpu)
         if model_path:
-            self.load_tts_model_by_path(
-                model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu
-            )
+            self.load_tts_model_by_path(model_path, config_path, gpu=gpu)

     @property
     def models(self):
@@ -140,18 +142,22 @@ class TTS(nn.Module):
     def download_model_by_name(
         self, model_name: str, vocoder_name: Optional[str] = None
-    ) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]:
+    ) -> tuple[Optional[str], Optional[str], Optional[str]]:
         model_path, config_path, model_item = self.manager.download_model(model_name)
         if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
             # return model directory if there are multiple files
             # we assume that the model knows how to load itself
-            return None, None, None, None, model_path
+            return None, None, model_path
         if model_item.get("default_vocoder") is None:
-            return model_path, config_path, None, None, None
+            return model_path, config_path, None
         if vocoder_name is None:
             vocoder_name = model_item["default_vocoder"]
         vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name)
-        return model_path, config_path, vocoder_path, vocoder_config_path, None
+        # A local vocoder model will take precedence if specified via vocoder_path
+        if self.vocoder_path is None or self.vocoder_config_path is None:
+            self.vocoder_path = vocoder_path
+            self.vocoder_config_path = vocoder_config_path
+        return model_path, config_path, None

     def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
         """Load one of the 🐸TTS models by name.
@@ -170,7 +176,7 @@ class TTS(nn.Module):
             gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
         """
         self.model_name = model_name
-        model_path, config_path, _, _, model_dir = self.download_model_by_name(model_name)
+        model_path, config_path, model_dir = self.download_model_by_name(model_name)
         self.voice_converter = Synthesizer(
             vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
         )
@@ -187,9 +193,7 @@ class TTS(nn.Module):
         self.synthesizer = None
         self.model_name = model_name
-        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
-            model_name, vocoder_name
-        )
+        model_path, config_path, model_dir = self.download_model_by_name(model_name, vocoder_name)

         # init synthesizer
         # None values are fetch from the model
@@ -198,17 +202,15 @@ class TTS(nn.Module):
             tts_config_path=config_path,
             tts_speakers_file=None,
             tts_languages_file=None,
-            vocoder_checkpoint=vocoder_path,
-            vocoder_config=vocoder_config_path,
+            vocoder_checkpoint=self.vocoder_path,
+            vocoder_config=self.vocoder_config_path,
             encoder_checkpoint=self.encoder_path,
             encoder_config=self.encoder_config_path,
             model_dir=model_dir,
             use_cuda=gpu,
         )

-    def load_tts_model_by_path(
-        self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False
-    ):
+    def load_tts_model_by_path(self, model_path: str, config_path: str, *, gpu: bool = False) -> None:
         """Load a model from a path.

         Args:
@@ -224,8 +226,8 @@ class TTS(nn.Module):
             tts_config_path=config_path,
             tts_speakers_file=None,
             tts_languages_file=None,
-            vocoder_checkpoint=vocoder_path,
-            vocoder_config=vocoder_config,
+            vocoder_checkpoint=self.vocoder_path,
+            vocoder_config=self.vocoder_config_path,
             encoder_checkpoint=self.encoder_path,
             encoder_config=self.encoder_config_path,
             use_cuda=gpu,
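
For fully local models, the vocoder checkpoint is now also supplied through the constructor, since load_tts_model_by_path() no longer takes vocoder arguments. A sketch of that call, with hypothetical file paths:

from TTS.api import TTS

# All paths are hypothetical placeholders for local checkpoints and configs;
# the stored vocoder_path/vocoder_config_path are picked up by load_tts_model_by_path().
tts = TTS(
    model_path="/path/to/tts/model.pth",
    config_path="/path/to/tts/config.json",
    vocoder_path="/path/to/vocoder/model.pth",
    vocoder_config_path="/path/to/vocoder/config.json",
)
wav = tts.tts("Hello world!")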