From a05177ce713a73112b40cb283ad9d3328a08f7cc Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 4 Dec 2024 16:11:43 +0100 Subject: [PATCH] chore(api): add type hints --- TTS/api.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 86593f3f..d16012f8 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -1,3 +1,5 @@ +"""Coqui TTS Python API.""" + import logging import tempfile import warnings @@ -30,7 +32,7 @@ class TTS(nn.Module): encoder_config_path: Optional[str] = None, progress_bar: bool = True, gpu: bool = False, - ): + ) -> None: """🐸TTS python interface that allows to load and use the released models. Example with a multi-speaker model: @@ -97,17 +99,17 @@ class TTS(nn.Module): self.load_tts_model_by_path(model_path, config_path, gpu=gpu) @property - def models(self): + def models(self) -> list[str]: return self.manager.list_tts_models() @property - def is_multi_speaker(self): + def is_multi_speaker(self) -> bool: if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager: return self.synthesizer.tts_model.speaker_manager.num_speakers > 1 return False @property - def is_multi_lingual(self): + def is_multi_lingual(self) -> bool: # Not sure what sets this to None, but applied a fix to prevent crashing. if ( isinstance(self.model_name, str) @@ -121,23 +123,23 @@ class TTS(nn.Module): return False @property - def speakers(self): + def speakers(self) -> list[str]: if not self.is_multi_speaker: return None return self.synthesizer.tts_model.speaker_manager.speaker_names @property - def languages(self): + def languages(self) -> list[str]: if not self.is_multi_lingual: return None return self.synthesizer.tts_model.language_manager.language_names @staticmethod - def get_models_file_path(): + def get_models_file_path() -> Path: return Path(__file__).parent / ".models.json" @staticmethod - def list_models(): + def list_models() -> list[str]: return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models() def download_model_by_name( @@ -159,7 +161,7 @@ class TTS(nn.Module): self.vocoder_config_path = vocoder_config_path return model_path, config_path, None - def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False): + def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None: """Load one of the 🐸TTS models by name. Args: @@ -168,7 +170,7 @@ class TTS(nn.Module): """ self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu) - def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False): + def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False) -> None: """Load one of the voice conversion models by name. Args: @@ -181,7 +183,7 @@ class TTS(nn.Module): vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu ) - def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False): + def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None: """Load one of 🐸TTS models by name. Args: @@ -235,11 +237,11 @@ class TTS(nn.Module): def _check_arguments( self, - speaker: str = None, - language: str = None, - speaker_wav: str = None, - emotion: str = None, - speed: float = None, + speaker: Optional[str] = None, + language: Optional[str] = None, + speaker_wav: Optional[str] = None, + emotion: Optional[str] = None, + speed: Optional[float] = None, **kwargs, ) -> None: """Check if the arguments are valid for the model.""" @@ -320,7 +322,7 @@ class TTS(nn.Module): file_path: str = "output.wav", split_sentences: bool = True, **kwargs, - ): + ) -> str: """Convert text to speech. Args: @@ -451,7 +453,7 @@ class TTS(nn.Module): file_path: str = "output.wav", speaker: str = None, split_sentences: bool = True, - ): + ) -> str: """Convert text to speech with voice conversion and save to file. Check `tts_with_vc` for more details. @@ -479,3 +481,4 @@ class TTS(nn.Module): text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences ) save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate) + return file_path