chore(api): add type hints

This commit is contained in:
Enno Hermann 2024-12-04 16:11:43 +01:00
parent 85dbb3b8b3
commit a05177ce71
1 changed files with 21 additions and 18 deletions

View File

@ -1,3 +1,5 @@
"""Coqui TTS Python API."""
import logging
import tempfile
import warnings
@ -30,7 +32,7 @@ class TTS(nn.Module):
encoder_config_path: Optional[str] = None,
progress_bar: bool = True,
gpu: bool = False,
):
) -> None:
"""🐸TTS python interface that allows to load and use the released models.
Example with a multi-speaker model:
@ -97,17 +99,17 @@ class TTS(nn.Module):
self.load_tts_model_by_path(model_path, config_path, gpu=gpu)
@property
def models(self):
def models(self) -> list[str]:
return self.manager.list_tts_models()
@property
def is_multi_speaker(self):
def is_multi_speaker(self) -> bool:
if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager:
return self.synthesizer.tts_model.speaker_manager.num_speakers > 1
return False
@property
def is_multi_lingual(self):
def is_multi_lingual(self) -> bool:
# Not sure what sets this to None, but applied a fix to prevent crashing.
if (
isinstance(self.model_name, str)
@ -121,23 +123,23 @@ class TTS(nn.Module):
return False
@property
def speakers(self):
def speakers(self) -> list[str]:
if not self.is_multi_speaker:
return None
return self.synthesizer.tts_model.speaker_manager.speaker_names
@property
def languages(self):
def languages(self) -> list[str]:
if not self.is_multi_lingual:
return None
return self.synthesizer.tts_model.language_manager.language_names
@staticmethod
def get_models_file_path():
def get_models_file_path() -> Path:
return Path(__file__).parent / ".models.json"
@staticmethod
def list_models():
def list_models() -> list[str]:
return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()
def download_model_by_name(
@ -159,7 +161,7 @@ class TTS(nn.Module):
self.vocoder_config_path = vocoder_config_path
return model_path, config_path, None
def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
"""Load one of the 🐸TTS models by name.
Args:
@ -168,7 +170,7 @@ class TTS(nn.Module):
"""
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False):
def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False) -> None:
"""Load one of the voice conversion models by name.
Args:
@ -181,7 +183,7 @@ class TTS(nn.Module):
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
)
def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False):
def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
"""Load one of 🐸TTS models by name.
Args:
@ -235,11 +237,11 @@ class TTS(nn.Module):
def _check_arguments(
self,
speaker: str = None,
language: str = None,
speaker_wav: str = None,
emotion: str = None,
speed: float = None,
speaker: Optional[str] = None,
language: Optional[str] = None,
speaker_wav: Optional[str] = None,
emotion: Optional[str] = None,
speed: Optional[float] = None,
**kwargs,
) -> None:
"""Check if the arguments are valid for the model."""
@ -320,7 +322,7 @@ class TTS(nn.Module):
file_path: str = "output.wav",
split_sentences: bool = True,
**kwargs,
):
) -> str:
"""Convert text to speech.
Args:
@ -451,7 +453,7 @@ class TTS(nn.Module):
file_path: str = "output.wav",
speaker: str = None,
split_sentences: bool = True,
):
) -> str:
"""Convert text to speech with voice conversion and save to file.
Check `tts_with_vc` for more details.
@ -479,3 +481,4 @@ class TTS(nn.Module):
text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences
)
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
return file_path