diff --git a/TTS/api.py b/TTS/api.py
index 99c3e522..850f0681 100644
--- a/TTS/api.py
+++ b/TTS/api.py
@@ -7,7 +7,16 @@ from TTS.utils.synthesizer import Synthesizer
 class TTS:
     """TODO: Add voice conversion and Capacitron support."""
 
-    def __init__(self, model_name: str = None, progress_bar: bool = True, gpu=False):
+    def __init__(
+        self,
+        model_name: str = None,
+        model_path: str = None,
+        config_path: str = None,
+        vocoder_path: str = None,
+        vocoder_config_path: str = None,
+        progress_bar: bool = True,
+        gpu=False,
+    ):
         """🐸TTS python interface that allows to load and use the released models.
 
         Example with a multi-speaker model:
@@ -20,8 +29,22 @@ class TTS:
             >>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
             >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")
 
+        Example loading a model from a path:
+            >>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False)
+            >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")
+
+        Example voice cloning with YourTTS in English, French and Portuguese:
+            >>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
+            >>> tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="thisisit.wav")
+            >>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
+            >>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")
+
         Args:
             model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
+            model_path (str, optional): Path to the model checkpoint. Defaults to None.
+            config_path (str, optional): Path to the model config. Defaults to None.
+            vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
+            vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None.
             progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True.
             gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
         """
@@ -29,6 +52,10 @@ class TTS:
         self.synthesizer = None
         if model_name:
             self.load_model_by_name(model_name, gpu)
+        if model_path:
+            self.load_model_by_path(
+                model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu
+            )
 
     @property
     def models(self):
@@ -75,7 +102,17 @@ class TTS:
         return model_path, config_path, vocoder_path, vocoder_config_path
 
     def load_model_by_name(self, model_name: str, gpu: bool = False):
+        """Load one of 🐸TTS models by name.
+
+        Args:
+            model_name (str): Model name to load. You can list models by ```tts.models```.
+            gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
+
+        TODO: Add tests
+        """
+
         model_path, config_path, vocoder_path, vocoder_config_path = self.download_model_by_name(model_name)
+
         # init synthesizer
         # None values are fetch from the model
         self.synthesizer = Synthesizer(
@@ -90,8 +127,33 @@ class TTS:
             use_cuda=gpu,
         )
 
-    def _check_arguments(self, speaker: str = None, language: str = None):
-        if self.is_multi_speaker and speaker is None:
+    def load_model_by_path(
+        self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False
+    ):
+        """Load a model from a path.
+
+        Args:
+            model_path (str): Path to the model checkpoint.
+            config_path (str): Path to the model config.
+            vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
+            vocoder_config (str, optional): Path to the vocoder config. Defaults to None.
+            gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
+        """
+
+        self.synthesizer = Synthesizer(
+            tts_checkpoint=model_path,
+            tts_config_path=config_path,
+            tts_speakers_file=None,
+            tts_languages_file=None,
+            vocoder_checkpoint=vocoder_path,
+            vocoder_config=vocoder_config,
+            encoder_checkpoint=None,
+            encoder_config=None,
+            use_cuda=gpu,
+        )
+
+    def _check_arguments(self, speaker: str = None, language: str = None, speaker_wav: str = None):
+        if self.is_multi_speaker and (speaker is None and speaker_wav is None):
             raise ValueError("Model is multi-speaker but no speaker is provided.")
         if self.is_multi_lingual and language is None:
             raise ValueError("Model is multi-lingual but no language is provided.")
@@ -100,7 +162,7 @@ class TTS:
         if not self.is_multi_lingual and language is not None:
             raise ValueError("Model is not multi-lingual but language is provided.")
 
-    def tts(self, text: str, speaker: str = None, language: str = None):
+    def tts(self, text: str, speaker: str = None, language: str = None, speaker_wav: str = None):
         """Convert text to speech.
 
         Args:
@@ -112,14 +174,17 @@ class TTS:
             language (str, optional):
                 Language code for multi-lingual models. You can check whether loaded model is multi-lingual
                 `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
+            speaker_wav (str, optional):
+                Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
+                Defaults to None.
         """
-        self._check_arguments(speaker=speaker, language=language)
+        self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav)
 
         wav = self.synthesizer.tts(
             text=text,
             speaker_name=speaker,
             language_name=language,
-            speaker_wav=None,
+            speaker_wav=speaker_wav,
             reference_wav=None,
             style_wav=None,
             style_text=None,
@@ -127,7 +192,14 @@ class TTS:
         )
         return wav
 
-    def tts_to_file(self, text: str, speaker: str = None, language: str = None, file_path: str = "output.wav"):
+    def tts_to_file(
+        self,
+        text: str,
+        speaker: str = None,
+        language: str = None,
+        speaker_wav: str = None,
+        file_path: str = "output.wav",
+    ):
         """Convert text to speech.
 
         Args:
@@ -139,8 +211,11 @@ class TTS:
             language (str, optional):
                 Language code for multi-lingual models. You can check whether loaded model is multi-lingual
                 `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
+            speaker_wav (str, optional):
+                Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
+                Defaults to None.
             file_path (str, optional):
                 Output file path. Defaults to "output.wav".
         """
-        wav = self.tts(text=text, speaker=speaker, language=language)
+        wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav)
         self.synthesizer.save_wav(wav=wav, path=file_path)
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index 2cef8d70..2310baf9 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -187,7 +187,7 @@ class Synthesizer(object):
             text (str): input text.
             speaker_name (str, optional): spekaer id for multi-speaker models. Defaults to "".
             language_name (str, optional): language id for multi-language models. Defaults to "".
-            speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
+            speaker_wav (Union[str, List[str]], optional): path to the speaker wav for voice cloning. Defaults to None.
             style_wav ([type], optional): style waveform for GST. Defaults to None.
             style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None.
             reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None.
diff --git a/tests/inference_tests/test_python_api.py b/tests/inference_tests/test_python_api.py
index fdd7e1cb..71690440 100644
--- a/tests/inference_tests/test_python_api.py
+++ b/tests/inference_tests/test_python_api.py
@@ -1,10 +1,12 @@
 import os
 import unittest
 
-from tests import get_tests_output_path
+from tests import get_tests_data_path, get_tests_output_path
+
 from TTS.api import TTS
 
 OUTPUT_PATH = os.path.join(get_tests_output_path(), "test_python_api.wav")
+cloning_test_wav_path = os.path.join(get_tests_data_path(), "ljspeech/wavs/LJ001-0028.wav")
 
 
 class TTSTest(unittest.TestCase):
@@ -34,3 +36,9 @@
         self.assertTrue(tts.is_multi_lingual)
         self.assertGreater(len(tts.speakers), 1)
         self.assertGreater(len(tts.languages), 1)
+
+    @staticmethod
+    def test_voice_cloning():
+        tts = TTS()
+        tts.load_model_by_name("tts_models/multilingual/multi-dataset/your_tts")
+        tts.tts_to_file("Hello world!", speaker_wav=cloning_test_wav_path, language="en", file_path=OUTPUT_PATH)
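
For reference, a minimal usage sketch of the two loading paths this diff adds: voice cloning with a released multi-lingual model, and loading a local checkpoint (optionally with a vocoder) by path. The file locations below are placeholders, not paths shipped with the repo.

from TTS.api import TTS

# Voice cloning with a released model (downloaded on first use).
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)
tts.tts_to_file(
    "This is voice cloning.",
    speaker_wav="samples/reference.wav",  # placeholder reference clip
    language="en",
    file_path="cloned.wav",
)

# Loading a local checkpoint, config and vocoder by path.
local_tts = TTS(
    model_path="run/checkpoint_100000.pth",  # placeholder paths
    config_path="run/config.json",
    vocoder_path="vocoder/checkpoint_500000.pth",
    vocoder_config_path="vocoder/config.json",
    progress_bar=False,
    gpu=False,
)
local_tts.tts_to_file(text="Hello from a local model.", file_path="local.wav")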
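
The reworked _check_arguments accepts speaker_wav as an alternative to speaker for multi-speaker models, so an error is raised only when neither is given. A sketch of the expected behaviour, using the model name from the diff and a placeholder reference wav:

from TTS.api import TTS

tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)

try:
    # Multi-speaker model with neither `speaker` nor `speaker_wav` set.
    tts.tts("This should fail.", language="en")
except ValueError as err:
    print(err)  # "Model is multi-speaker but no speaker is provided."

# A reference wav alone now satisfies the multi-speaker check.
wav = tts.tts("This should work.", speaker_wav="samples/reference.wav", language="en")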
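
The new test downloads the YourTTS model, so it is comparatively slow; assuming the suite is run with pytest, it can be selected on its own with:

python -m pytest tests/inference_tests/test_python_api.py -k test_voice_cloning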