mirror of https://github.com/coqui-ai/TTS.git
chore(api): add type hints
This commit is contained in:
parent
85dbb3b8b3
commit
a05177ce71
39
TTS/api.py
39
TTS/api.py
|
@ -1,3 +1,5 @@
|
|||
"""Coqui TTS Python API."""
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
import warnings
|
||||
|
@ -30,7 +32,7 @@ class TTS(nn.Module):
|
|||
encoder_config_path: Optional[str] = None,
|
||||
progress_bar: bool = True,
|
||||
gpu: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""🐸TTS python interface that allows to load and use the released models.
|
||||
|
||||
Example with a multi-speaker model:
|
||||
|
@ -97,17 +99,17 @@ class TTS(nn.Module):
|
|||
self.load_tts_model_by_path(model_path, config_path, gpu=gpu)
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
def models(self) -> list[str]:
|
||||
return self.manager.list_tts_models()
|
||||
|
||||
@property
|
||||
def is_multi_speaker(self):
|
||||
def is_multi_speaker(self) -> bool:
|
||||
if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager:
|
||||
return self.synthesizer.tts_model.speaker_manager.num_speakers > 1
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_multi_lingual(self):
|
||||
def is_multi_lingual(self) -> bool:
|
||||
# Not sure what sets this to None, but applied a fix to prevent crashing.
|
||||
if (
|
||||
isinstance(self.model_name, str)
|
||||
|
@ -121,23 +123,23 @@ class TTS(nn.Module):
|
|||
return False
|
||||
|
||||
@property
|
||||
def speakers(self):
|
||||
def speakers(self) -> list[str]:
|
||||
if not self.is_multi_speaker:
|
||||
return None
|
||||
return self.synthesizer.tts_model.speaker_manager.speaker_names
|
||||
|
||||
@property
|
||||
def languages(self):
|
||||
def languages(self) -> list[str]:
|
||||
if not self.is_multi_lingual:
|
||||
return None
|
||||
return self.synthesizer.tts_model.language_manager.language_names
|
||||
|
||||
@staticmethod
|
||||
def get_models_file_path():
|
||||
def get_models_file_path() -> Path:
|
||||
return Path(__file__).parent / ".models.json"
|
||||
|
||||
@staticmethod
|
||||
def list_models():
|
||||
def list_models() -> list[str]:
|
||||
return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()
|
||||
|
||||
def download_model_by_name(
|
||||
|
@ -159,7 +161,7 @@ class TTS(nn.Module):
|
|||
self.vocoder_config_path = vocoder_config_path
|
||||
return model_path, config_path, None
|
||||
|
||||
def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
|
||||
def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
|
||||
"""Load one of the 🐸TTS models by name.
|
||||
|
||||
Args:
|
||||
|
@ -168,7 +170,7 @@ class TTS(nn.Module):
|
|||
"""
|
||||
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
|
||||
|
||||
def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False):
|
||||
def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False) -> None:
|
||||
"""Load one of the voice conversion models by name.
|
||||
|
||||
Args:
|
||||
|
@ -181,7 +183,7 @@ class TTS(nn.Module):
|
|||
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
|
||||
)
|
||||
|
||||
def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
|
||||
def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
|
||||
"""Load one of 🐸TTS models by name.
|
||||
|
||||
Args:
|
||||
|
@ -235,11 +237,11 @@ class TTS(nn.Module):
|
|||
|
||||
def _check_arguments(
|
||||
self,
|
||||
speaker: str = None,
|
||||
language: str = None,
|
||||
speaker_wav: str = None,
|
||||
emotion: str = None,
|
||||
speed: float = None,
|
||||
speaker: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
speaker_wav: Optional[str] = None,
|
||||
emotion: Optional[str] = None,
|
||||
speed: Optional[float] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Check if the arguments are valid for the model."""
|
||||
|
@ -320,7 +322,7 @@ class TTS(nn.Module):
|
|||
file_path: str = "output.wav",
|
||||
split_sentences: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
) -> str:
|
||||
"""Convert text to speech.
|
||||
|
||||
Args:
|
||||
|
@ -451,7 +453,7 @@ class TTS(nn.Module):
|
|||
file_path: str = "output.wav",
|
||||
speaker: str = None,
|
||||
split_sentences: bool = True,
|
||||
):
|
||||
) -> str:
|
||||
"""Convert text to speech with voice conversion and save to file.
|
||||
|
||||
Check `tts_with_vc` for more details.
|
||||
|
@ -479,3 +481,4 @@ class TTS(nn.Module):
|
|||
text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences
|
||||
)
|
||||
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
|
||||
return file_path
|
||||
|
|
Loading…
Reference in New Issue