mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'api_model_path' into dev
This commit is contained in:
commit
85b3a04b37
91
TTS/api.py
91
TTS/api.py
|
@ -7,7 +7,16 @@ from TTS.utils.synthesizer import Synthesizer
|
|||
class TTS:
|
||||
"""TODO: Add voice conversion and Capacitron support."""
|
||||
|
||||
def __init__(self, model_name: str = None, progress_bar: bool = True, gpu=False):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = None,
|
||||
model_path: str = None,
|
||||
config_path: str = None,
|
||||
vocoder_path: str = None,
|
||||
vocoder_config_path: str = None,
|
||||
progress_bar: bool = True,
|
||||
gpu=False,
|
||||
):
|
||||
"""🐸TTS python interface that allows to load and use the released models.
|
||||
|
||||
Example with a multi-speaker model:
|
||||
|
@ -20,8 +29,22 @@ class TTS:
|
|||
>>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
|
||||
>>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")
|
||||
|
||||
Example loading a model from a path:
|
||||
>>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False)
|
||||
>>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")
|
||||
|
||||
Example voice cloning with YourTTS in English, French and Portuguese:
|
||||
>>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
|
||||
>>> tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="thisisit.wav")
|
||||
>>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
|
||||
>>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")
|
||||
|
||||
Args:
|
||||
model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
|
||||
model_path (str, optional): Path to the model checkpoint. Defaults to None.
|
||||
config_path (str, optional): Path to the model config. Defaults to None.
|
||||
vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
|
||||
vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None.
|
||||
progress_bar (bool, optional): Whether to print a progress bar while downloading a model. Defaults to True.
|
||||
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
|
||||
"""
|
||||
|
@ -29,6 +52,10 @@ class TTS:
|
|||
self.synthesizer = None
|
||||
if model_name:
|
||||
self.load_model_by_name(model_name, gpu)
|
||||
if model_path:
|
||||
self.load_model_by_path(
|
||||
model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu
|
||||
)
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
|
@ -75,7 +102,17 @@ class TTS:
|
|||
return model_path, config_path, vocoder_path, vocoder_config_path
|
||||
|
||||
def load_model_by_name(self, model_name: str, gpu: bool = False):
|
||||
""" Load one of 🐸TTS models by name.
|
||||
|
||||
Args:
|
||||
model_name (str): Model name to load. You can list models by ```tts.models```.
|
||||
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
|
||||
|
||||
TODO: Add tests
|
||||
"""
|
||||
|
||||
model_path, config_path, vocoder_path, vocoder_config_path = self.download_model_by_name(model_name)
|
||||
|
||||
# init synthesizer
|
||||
# None values are fetched from the model
|
||||
self.synthesizer = Synthesizer(
|
||||
|
@ -90,8 +127,33 @@ class TTS:
|
|||
use_cuda=gpu,
|
||||
)
|
||||
|
||||
def _check_arguments(self, speaker: str = None, language: str = None):
|
||||
if self.is_multi_speaker and speaker is None:
|
||||
def load_model_by_path(
    self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False
):
    """Build the synthesizer from checkpoint/config files stored on disk.

    Only the TTS model and (optionally) a vocoder are configured here; the
    speakers file, languages file and speaker encoder are all passed as None.

    Args:
        model_path (str): Path to the model checkpoint.
        config_path (str): Path to the model config.
        vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
        vocoder_config (str, optional): Path to the vocoder config. Defaults to None.
        gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
    """
    # Collect the arguments first so the TTS-model and vocoder settings read as two groups.
    synthesizer_args = {
        "tts_checkpoint": model_path,
        "tts_config_path": config_path,
        "tts_speakers_file": None,
        "tts_languages_file": None,
        "vocoder_checkpoint": vocoder_path,
        "vocoder_config": vocoder_config,
        "encoder_checkpoint": None,
        "encoder_config": None,
        "use_cuda": gpu,
    }
    self.synthesizer = Synthesizer(**synthesizer_args)
|
||||
|
||||
def _check_arguments(self, speaker: str = None, language: str = None, speaker_wav: str = None):
|
||||
if self.is_multi_speaker and (speaker is None and speaker_wav is None):
|
||||
raise ValueError("Model is multi-speaker but no speaker is provided.")
|
||||
if self.is_multi_lingual and language is None:
|
||||
raise ValueError("Model is multi-lingual but no language is provided.")
|
||||
|
@ -100,7 +162,7 @@ class TTS:
|
|||
if not self.is_multi_lingual and language is not None:
|
||||
raise ValueError("Model is not multi-lingual but language is provided.")
|
||||
|
||||
def tts(self, text: str, speaker: str = None, language: str = None):
|
||||
def tts(self, text: str, speaker: str = None, language: str = None, speaker_wav: str = None):
|
||||
"""Convert text to speech.
|
||||
|
||||
Args:
|
||||
|
@ -112,14 +174,17 @@ class TTS:
|
|||
language (str, optional):
|
||||
Language code for multi-lingual models. You can check whether loaded model is multi-lingual
|
||||
`tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
|
||||
speaker_wav (str, optional):
|
||||
Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
|
||||
Defaults to None.
|
||||
"""
|
||||
self._check_arguments(speaker=speaker, language=language)
|
||||
self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav)
|
||||
|
||||
wav = self.synthesizer.tts(
|
||||
text=text,
|
||||
speaker_name=speaker,
|
||||
language_name=language,
|
||||
speaker_wav=None,
|
||||
speaker_wav=speaker_wav,
|
||||
reference_wav=None,
|
||||
style_wav=None,
|
||||
style_text=None,
|
||||
|
@ -127,7 +192,14 @@ class TTS:
|
|||
)
|
||||
return wav
|
||||
|
||||
def tts_to_file(self, text: str, speaker: str = None, language: str = None, file_path: str = "output.wav"):
|
||||
def tts_to_file(
|
||||
self,
|
||||
text: str,
|
||||
speaker: str = None,
|
||||
language: str = None,
|
||||
speaker_wav: str = None,
|
||||
file_path: str = "output.wav",
|
||||
):
|
||||
"""Convert text to speech and save the result to a wav file.
|
||||
|
||||
Args:
|
||||
|
@ -139,8 +211,11 @@ class TTS:
|
|||
language (str, optional):
|
||||
Language code for multi-lingual models. You can check whether loaded model is multi-lingual
|
||||
`tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
|
||||
speaker_wav (str, optional):
|
||||
Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
|
||||
Defaults to None.
|
||||
file_path (str, optional):
|
||||
Output file path. Defaults to "output.wav".
|
||||
"""
|
||||
wav = self.tts(text=text, speaker=speaker, language=language)
|
||||
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav)
|
||||
self.synthesizer.save_wav(wav=wav, path=file_path)
|
||||
|
|
|
@ -187,7 +187,7 @@ class Synthesizer(object):
|
|||
text (str): input text.
|
||||
speaker_name (str, optional): speaker id for multi-speaker models. Defaults to "".
|
||||
language_name (str, optional): language id for multi-language models. Defaults to "".
|
||||
speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
|
||||
speaker_wav (Union[str, List[str]], optional): path to the speaker wav for voice cloning. Defaults to None.
|
||||
style_wav ([type], optional): style waveform for GST. Defaults to None.
|
||||
style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None.
|
||||
reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None.
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
from tests import get_tests_output_path
|
||||
from tests import get_tests_data_path, get_tests_output_path
|
||||
|
||||
from TTS.api import TTS
|
||||
|
||||
OUTPUT_PATH = os.path.join(get_tests_output_path(), "test_python_api.wav")
|
||||
cloning_test_wav_path = os.path.join(get_tests_data_path(), "ljspeech/wavs/LJ001-0028.wav")
|
||||
|
||||
|
||||
class TTSTest(unittest.TestCase):
|
||||
|
@ -34,3 +36,9 @@ class TTSTest(unittest.TestCase):
|
|||
self.assertTrue(tts.is_multi_lingual)
|
||||
self.assertGreater(len(tts.speakers), 1)
|
||||
self.assertGreater(len(tts.languages), 1)
|
||||
|
||||
def test_voice_cloning(self):
    """Synthesize a short English sentence with YourTTS, cloning the voice of a reference wav.

    This is a regular instance method on purpose: test runners call test methods on a
    ``TestCase`` instance without extra arguments, so a ``@staticmethod`` that still
    declares ``self`` fails with ``TypeError`` before the test body ever runs.
    """
    tts = TTS()
    tts.load_model_by_name("tts_models/multilingual/multi-dataset/your_tts")
    tts.tts_to_file("Hello world!", speaker_wav=cloning_test_wav_path, language="en", file_path=OUTPUT_PATH)
    # The synthesized audio is saved to `file_path`; verify the file is actually there.
    self.assertTrue(os.path.isfile(OUTPUT_PATH))
|
||||
|
|
Loading…
Reference in New Issue