mirror of https://github.com/coqui-ai/TTS.git
feat(api): allow mixing TTS and vocoder model name and path
This commit is contained in:
parent
1a4e58d0ce
commit
85dbb3b8b3
38
TTS/api.py
38
TTS/api.py
|
@ -75,8 +75,12 @@ class TTS(nn.Module):
|
|||
self.synthesizer = None
|
||||
self.voice_converter = None
|
||||
self.model_name = ""
|
||||
|
||||
self.vocoder_path = vocoder_path
|
||||
self.vocoder_config_path = vocoder_config_path
|
||||
self.encoder_path = encoder_path
|
||||
self.encoder_config_path = encoder_config_path
|
||||
|
||||
if gpu:
|
||||
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
|
||||
|
||||
|
@ -90,9 +94,7 @@ class TTS(nn.Module):
|
|||
self.load_model_by_name(model_name, vocoder_name, gpu=gpu)
|
||||
|
||||
if model_path:
|
||||
self.load_tts_model_by_path(
|
||||
model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu
|
||||
)
|
||||
self.load_tts_model_by_path(model_path, config_path, gpu=gpu)
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
|
@ -140,18 +142,22 @@ class TTS(nn.Module):
|
|||
|
||||
def download_model_by_name(
|
||||
self, model_name: str, vocoder_name: Optional[str] = None
|
||||
) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]:
|
||||
) -> tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
model_path, config_path, model_item = self.manager.download_model(model_name)
|
||||
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
|
||||
# return model directory if there are multiple files
|
||||
# we assume that the model knows how to load itself
|
||||
return None, None, None, None, model_path
|
||||
return None, None, model_path
|
||||
if model_item.get("default_vocoder") is None:
|
||||
return model_path, config_path, None, None, None
|
||||
return model_path, config_path, None
|
||||
if vocoder_name is None:
|
||||
vocoder_name = model_item["default_vocoder"]
|
||||
vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name)
|
||||
return model_path, config_path, vocoder_path, vocoder_config_path, None
|
||||
# A local vocoder model will take precedence if specified via vocoder_path
|
||||
if self.vocoder_path is None or self.vocoder_config_path is None:
|
||||
self.vocoder_path = vocoder_path
|
||||
self.vocoder_config_path = vocoder_config_path
|
||||
return model_path, config_path, None
|
||||
|
||||
def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
|
||||
"""Load one of the 🐸TTS models by name.
|
||||
|
@ -170,7 +176,7 @@ class TTS(nn.Module):
|
|||
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
|
||||
"""
|
||||
self.model_name = model_name
|
||||
model_path, config_path, _, _, model_dir = self.download_model_by_name(model_name)
|
||||
model_path, config_path, model_dir = self.download_model_by_name(model_name)
|
||||
self.voice_converter = Synthesizer(
|
||||
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
|
||||
)
|
||||
|
@ -187,9 +193,7 @@ class TTS(nn.Module):
|
|||
self.synthesizer = None
|
||||
self.model_name = model_name
|
||||
|
||||
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
|
||||
model_name, vocoder_name
|
||||
)
|
||||
model_path, config_path, model_dir = self.download_model_by_name(model_name, vocoder_name)
|
||||
|
||||
# init synthesizer
|
||||
# None values are fetch from the model
|
||||
|
@ -198,17 +202,15 @@ class TTS(nn.Module):
|
|||
tts_config_path=config_path,
|
||||
tts_speakers_file=None,
|
||||
tts_languages_file=None,
|
||||
vocoder_checkpoint=vocoder_path,
|
||||
vocoder_config=vocoder_config_path,
|
||||
vocoder_checkpoint=self.vocoder_path,
|
||||
vocoder_config=self.vocoder_config_path,
|
||||
encoder_checkpoint=self.encoder_path,
|
||||
encoder_config=self.encoder_config_path,
|
||||
model_dir=model_dir,
|
||||
use_cuda=gpu,
|
||||
)
|
||||
|
||||
def load_tts_model_by_path(
|
||||
self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False
|
||||
):
|
||||
def load_tts_model_by_path(self, model_path: str, config_path: str, *, gpu: bool = False) -> None:
|
||||
"""Load a model from a path.
|
||||
|
||||
Args:
|
||||
|
@ -224,8 +226,8 @@ class TTS(nn.Module):
|
|||
tts_config_path=config_path,
|
||||
tts_speakers_file=None,
|
||||
tts_languages_file=None,
|
||||
vocoder_checkpoint=vocoder_path,
|
||||
vocoder_config=vocoder_config,
|
||||
vocoder_checkpoint=self.vocoder_path,
|
||||
vocoder_config=self.vocoder_config_path,
|
||||
encoder_checkpoint=self.encoder_path,
|
||||
encoder_config=self.encoder_config_path,
|
||||
use_cuda=gpu,
|
||||
|
|
Loading…
Reference in New Issue