mirror of https://github.com/coqui-ai/TTS.git
chore(api): add type hints
This commit is contained in:
parent
85dbb3b8b3
commit
a05177ce71
39
TTS/api.py
39
TTS/api.py
|
@@ -1,3 +1,5 @@
|
||||||
|
"""Coqui TTS Python API."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
import warnings
|
import warnings
|
||||||
|
@@ -30,7 +32,7 @@ class TTS(nn.Module):
|
||||||
encoder_config_path: Optional[str] = None,
|
encoder_config_path: Optional[str] = None,
|
||||||
progress_bar: bool = True,
|
progress_bar: bool = True,
|
||||||
gpu: bool = False,
|
gpu: bool = False,
|
||||||
):
|
) -> None:
|
||||||
"""🐸TTS python interface that allows to load and use the released models.
|
"""🐸TTS python interface that allows to load and use the released models.
|
||||||
|
|
||||||
Example with a multi-speaker model:
|
Example with a multi-speaker model:
|
||||||
|
@@ -97,17 +99,17 @@ class TTS(nn.Module):
|
||||||
self.load_tts_model_by_path(model_path, config_path, gpu=gpu)
|
self.load_tts_model_by_path(model_path, config_path, gpu=gpu)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def models(self):
|
def models(self) -> list[str]:
|
||||||
return self.manager.list_tts_models()
|
return self.manager.list_tts_models()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_multi_speaker(self):
|
def is_multi_speaker(self) -> bool:
|
||||||
if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager:
|
if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager:
|
||||||
return self.synthesizer.tts_model.speaker_manager.num_speakers > 1
|
return self.synthesizer.tts_model.speaker_manager.num_speakers > 1
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_multi_lingual(self):
|
def is_multi_lingual(self) -> bool:
|
||||||
# Not sure what sets this to None, but applied a fix to prevent crashing.
|
# Not sure what sets this to None, but applied a fix to prevent crashing.
|
||||||
if (
|
if (
|
||||||
isinstance(self.model_name, str)
|
isinstance(self.model_name, str)
|
||||||
|
@@ -121,23 +123,23 @@ class TTS(nn.Module):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def speakers(self):
|
def speakers(self) -> list[str]:
|
||||||
if not self.is_multi_speaker:
|
if not self.is_multi_speaker:
|
||||||
return None
|
return None
|
||||||
return self.synthesizer.tts_model.speaker_manager.speaker_names
|
return self.synthesizer.tts_model.speaker_manager.speaker_names
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def languages(self):
|
def languages(self) -> list[str]:
|
||||||
if not self.is_multi_lingual:
|
if not self.is_multi_lingual:
|
||||||
return None
|
return None
|
||||||
return self.synthesizer.tts_model.language_manager.language_names
|
return self.synthesizer.tts_model.language_manager.language_names
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_models_file_path():
|
def get_models_file_path() -> Path:
|
||||||
return Path(__file__).parent / ".models.json"
|
return Path(__file__).parent / ".models.json"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_models():
|
def list_models() -> list[str]:
|
||||||
return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()
|
return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()
|
||||||
|
|
||||||
def download_model_by_name(
|
def download_model_by_name(
|
||||||
|
@@ -159,7 +161,7 @@ class TTS(nn.Module):
|
||||||
self.vocoder_config_path = vocoder_config_path
|
self.vocoder_config_path = vocoder_config_path
|
||||||
return model_path, config_path, None
|
return model_path, config_path, None
|
||||||
|
|
||||||
def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
|
def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
|
||||||
"""Load one of the 🐸TTS models by name.
|
"""Load one of the 🐸TTS models by name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -168,7 +170,7 @@ class TTS(nn.Module):
|
||||||
"""
|
"""
|
||||||
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
|
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
|
||||||
|
|
||||||
def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False):
|
def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False) -> None:
|
||||||
"""Load one of the voice conversion models by name.
|
"""Load one of the voice conversion models by name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -181,7 +183,7 @@ class TTS(nn.Module):
|
||||||
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
|
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
|
def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
|
||||||
"""Load one of 🐸TTS models by name.
|
"""Load one of 🐸TTS models by name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -235,11 +237,11 @@ class TTS(nn.Module):
|
||||||
|
|
||||||
def _check_arguments(
|
def _check_arguments(
|
||||||
self,
|
self,
|
||||||
speaker: str = None,
|
speaker: Optional[str] = None,
|
||||||
language: str = None,
|
language: Optional[str] = None,
|
||||||
speaker_wav: str = None,
|
speaker_wav: Optional[str] = None,
|
||||||
emotion: str = None,
|
emotion: Optional[str] = None,
|
||||||
speed: float = None,
|
speed: Optional[float] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Check if the arguments are valid for the model."""
|
"""Check if the arguments are valid for the model."""
|
||||||
|
@@ -320,7 +322,7 @@ class TTS(nn.Module):
|
||||||
file_path: str = "output.wav",
|
file_path: str = "output.wav",
|
||||||
split_sentences: bool = True,
|
split_sentences: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> str:
|
||||||
"""Convert text to speech.
|
"""Convert text to speech.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -451,7 +453,7 @@ class TTS(nn.Module):
|
||||||
file_path: str = "output.wav",
|
file_path: str = "output.wav",
|
||||||
speaker: str = None,
|
speaker: str = None,
|
||||||
split_sentences: bool = True,
|
split_sentences: bool = True,
|
||||||
):
|
) -> str:
|
||||||
"""Convert text to speech with voice conversion and save to file.
|
"""Convert text to speech with voice conversion and save to file.
|
||||||
|
|
||||||
Check `tts_with_vc` for more details.
|
Check `tts_with_vc` for more details.
|
||||||
|
@@ -479,3 +481,4 @@ class TTS(nn.Module):
|
||||||
text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences
|
text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences
|
||||||
)
|
)
|
||||||
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
|
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
|
||||||
|
return file_path
|
||||||
|
|
Loading…
Reference in New Issue