From 7fddabc8ac9c84a9d05ba7928f454e45558c9422 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 30 Jan 2023 13:35:48 +0100 Subject: [PATCH] Implement cloning in API --- TTS/api.py | 33 +++++++++++++++++++----- TTS/utils/synthesizer.py | 2 +- tests/inference_tests/test_python_api.py | 9 ++++++- 3 files changed, 35 insertions(+), 9 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 22b81ba4..6fa8c606 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -33,6 +33,12 @@ class TTS: >>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False) >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav") + Example voice cloning with YourTTS in English, French and Portuguese: + >>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True) + >>> tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="thisisit.wav") + >>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav") + >>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav") + Args: model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None. model_path (str, optional): Path to the model checkpoint. Defaults to None. @@ -144,8 +150,8 @@ class TTS: use_cuda=gpu, ) - def _check_arguments(self, speaker: str = None, language: str = None): - if self.is_multi_speaker and speaker is None: + def _check_arguments(self, speaker: str = None, language: str = None, speaker_wav: str = None): + if self.is_multi_speaker and (speaker is None and speaker_wav is None): raise ValueError("Model is multi-speaker but no speaker is provided.") if self.is_multi_lingual and language is None: raise ValueError("Model is multi-lingual but no language is provided.") @@ -154,7 +160,7 @@ class TTS: if not self.is_multi_lingual and language is not None: raise ValueError("Model is not multi-lingual but language is provided.") - def tts(self, text: str, speaker: str = None, language: str = None): + def tts(self, text: str, speaker: str = None, language: str = None, speaker_wav: str = None): """Convert text to speech. Args: @@ -166,14 +172,17 @@ class TTS: language (str, optional): Language code for multi-lingual models. You can check whether loaded model is multi-lingual `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. """ - self._check_arguments(speaker=speaker, language=language) + self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav) wav = self.synthesizer.tts( text=text, speaker_name=speaker, language_name=language, - speaker_wav=None, + speaker_wav=speaker_wav, reference_wav=None, style_wav=None, style_text=None, @@ -181,7 +190,14 @@ class TTS: ) return wav - def tts_to_file(self, text: str, speaker: str = None, language: str = None, file_path: str = "output.wav"): + def tts_to_file( + self, + text: str, + speaker: str = None, + language: str = None, + speaker_wav: str = None, + file_path: str = "output.wav", + ): """Convert text to speech. Args: @@ -193,8 +209,11 @@ class TTS: language (str, optional): Language code for multi-lingual models. You can check whether loaded model is multi-lingual `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. file_path (str, optional): Output file path. Defaults to "output.wav". """ - wav = self.tts(text=text, speaker=speaker, language=language) + wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav) self.synthesizer.save_wav(wav=wav, path=file_path) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 4a0ab038..498dc7ba 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -187,7 +187,7 @@ class Synthesizer(object): text (str): input text. speaker_name (str, optional): spekaer id for multi-speaker models. Defaults to "". language_name (str, optional): language id for multi-language models. Defaults to "". - speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None. + speaker_wav (Union[str, List[str]], optional): path to the speaker wav for voice cloning. Defaults to None. style_wav ([type], optional): style waveform for GST. Defaults to None. style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None. reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None. diff --git a/tests/inference_tests/test_python_api.py b/tests/inference_tests/test_python_api.py index fdd7e1cb..b306b5ea 100644 --- a/tests/inference_tests/test_python_api.py +++ b/tests/inference_tests/test_python_api.py @@ -1,10 +1,12 @@ import os import unittest -from tests import get_tests_output_path +from tests import get_tests_data_path, get_tests_output_path + from TTS.api import TTS OUTPUT_PATH = os.path.join(get_tests_output_path(), "test_python_api.wav") +cloning_test_wav_path = os.path.join(get_tests_data_path(), "ljspeech/wavs/LJ001-0028.wav") class TTSTest(unittest.TestCase): @@ -34,3 +36,8 @@ class TTSTest(unittest.TestCase): self.assertTrue(tts.is_multi_lingual) self.assertGreater(len(tts.speakers), 1) self.assertGreater(len(tts.languages), 1) + + def test_voice_cloning(): + tts = TTS() + tts.load_model_by_name("tts_models/multilingual/multi-dataset/your_tts") + tts.tts_to_file("Hello world!", speaker_wav=cloning_test_wav_path, language="en", file_path=OUTPUT_PATH)