mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'dev' into main
Commit 33b5e87b56
README.md

@@ -187,18 +187,21 @@ More details about the docker images (like GPU support) can be found [here](http
 ### 🐍 Python API
 
 #### Running a multi-speaker and multi-lingual model
 
 ```python
+import torch
 from TTS.api import TTS
 
-# Running a multi-speaker and multi-lingual model
+# Get device
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # List available 🐸TTS models and choose the first one
-model_name = TTS.list_models()[0]
+model_name = TTS().list_models()[0]
 # Init TTS
-tts = TTS(model_name)
+tts = TTS(model_name).to(device)
 
 # Run TTS
 # ❗ Since this model is multi-speaker and multi-lingual, we must set the target speaker and the language
 # Text to speech with a numpy output
 wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0])
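The hunk above replaces the old `gpu=...` constructor flag with PyTorch-style device placement. As a minimal sketch of the new pattern end to end (the explicit model name is borrowed from the YourTTS examples later in this diff, not from this hunk):

```python
import torch
from TTS.api import TTS

device = "cuda" if torch.cuda.is_available() else "cpu"

# Selecting a model by explicit name is more predictable than
# TTS().list_models()[0], whose ordering is not guaranteed.
tts = TTS("tts_models/multilingual/multi-dataset/your_tts").to(device)
wav = tts.tts("This is a test!", speaker=tts.speakers[0], language=tts.languages[0])
```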
@@ -210,13 +213,13 @@ tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.langu
 ```python
 # Init TTS with the target model name
-tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
+tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False).to(device)
 
 # Run TTS
 tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
 
 # Example voice cloning with YourTTS in English, French and Portuguese
-tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
+tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False).to(device)
 tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
 tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr-fr", file_path="output.wav")
 tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt-br", file_path="output.wav")
@@ -227,7 +230,7 @@ tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav",
 Converting the voice in `source_wav` to the voice of `target_wav`
 
 ```python
-tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True)
+tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False).to("cuda")
 tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
 ```
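Note that `.to("cuda")` in the hunk above hard-requires a GPU at runtime. A hedged, portable variant of the same voice-conversion call falls back to CPU when CUDA is unavailable:

```python
import torch
from TTS.api import TTS

# Same call as in the diff, but device selection is guarded.
device = "cuda" if torch.cuda.is_available() else "cpu"
tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False).to(device)
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
```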
@@ -256,7 +259,7 @@ These models will follow the naming convention `coqui_studio/en/<studio_speaker_
 # XTTS model
 models = TTS(cs_api_model="XTTS").list_models()
 # Init TTS with the target studio speaker
-tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_bar=False, gpu=False)
+tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_bar=False)
 # Run TTS
 tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH)
TTS/VERSION

@@ -1 +1 @@
-0.16.5
+0.16.6
TTS/bin/synthesize.py

@@ -169,6 +169,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         help="Output wav file path.",
     )
     parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False)
+    parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu")
     parser.add_argument(
         "--vocoder_path",
         type=str,
@@ -391,6 +392,10 @@ If you don't specify any models, then it uses LJSpeech based English model.
     if args.encoder_path is not None:
         encoder_path = args.encoder_path
         encoder_config_path = args.encoder_config_path
 
+    device = args.device
+    if args.use_cuda:
+        device = "cuda"
+
     # load models
     synthesizer = Synthesizer(
@@ -406,8 +411,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         vc_config_path,
         model_dir,
         args.voice_dir,
-        args.use_cuda,
-    )
+    ).to(device)
 
     # query speaker ids of a multi-speaker model.
     if args.list_speaker_idxs:
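Taken together, these three hunks keep `--use_cuda` working while introducing `--device`: the new flag defaults to `"cpu"`, and the legacy boolean overrides it. A standalone sketch of that precedence (the parser below is trimmed to the two flags; it is not the full synthesize.py CLI):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False)
parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu")

args = parser.parse_args(["--device", "mps"])

# Same resolution order as the diff: --use_cuda wins, for backward compatibility.
device = args.device
if args.use_cuda:
    device = "cuda"

print(device)  # -> "mps"
```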
TTS/tts/utils/text/belarusian/phonemizer.py (new file)

@@ -0,0 +1,34 @@
+import os
+
+finder = None
+
+
+def init():
+    try:
+        import jpype
+        import jpype.imports
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError("Belarusian phonemizer requires to install module 'jpype1' manually. Try `pip install jpype1`.")
+
+    try:
+        jar_path = os.environ["BEL_FANETYKA_JAR"]
+    except KeyError:
+        raise KeyError("You need to define 'BEL_FANETYKA_JAR' environment variable as path to the fanetyka.jar file")
+
+    jpype.startJVM(classpath=[jar_path])
+
+    # import the Java modules
+    from org.alex73.korpus.base import GrammarDB2, GrammarFinder
+
+    grammar_db = GrammarDB2.initializeFromJar()
+    global finder
+    finder = GrammarFinder(grammar_db)
+
+
+def belarusian_text_to_phonemes(text: str) -> str:
+    # Initialize only on first run
+    if finder is None:
+        init()
+
+    from org.alex73.fanetyka.impl import FanetykaText
+
+    return str(FanetykaText(finder, text).ipa)
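A hedged usage sketch for this new module: it assumes `jpype1` is installed and a local `fanetyka.jar` (from the org.alex73 Korpus project) is available, exactly as the error messages above spell out. The jar path below is illustrative only:

```python
import os

# Hypothetical path; point this at your own copy of fanetyka.jar.
os.environ["BEL_FANETYKA_JAR"] = "/opt/fanetyka/fanetyka.jar"

from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes

# The first call runs init(): it starts the JVM and loads the grammar
# database, so expect a noticeable one-time delay.
print(belarusian_text_to_phonemes("Фанетычны канвертар"))
```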
TTS/tts/utils/text/phonemizers/__init__.py

@@ -1,4 +1,5 @@
 from TTS.tts.utils.text.phonemizers.bangla_phonemizer import BN_Phonemizer
+from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer
 from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
 from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak
 from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut

@@ -35,6 +36,7 @@ DEF_LANG_TO_PHONEMIZER["en"] = DEF_LANG_TO_PHONEMIZER["en-us"]
 DEF_LANG_TO_PHONEMIZER["zh-cn"] = ZH_CN_Phonemizer.name()
 DEF_LANG_TO_PHONEMIZER["ko-kr"] = KO_KR_Phonemizer.name()
 DEF_LANG_TO_PHONEMIZER["bn"] = BN_Phonemizer.name()
+DEF_LANG_TO_PHONEMIZER["be"] = BEL_Phonemizer.name()
 
 
 # JA phonemizer has deal breaking dependencies like MeCab for some systems.

@@ -68,6 +70,8 @@ def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer:
         return KO_KR_Phonemizer(**kwargs)
     if name == "bn_phonemizer":
         return BN_Phonemizer(**kwargs)
+    if name == "be_phonemizer":
+        return BEL_Phonemizer(**kwargs)
     raise ValueError(f"Phonemizer {name} not found")
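With these registrations in place, Belarusian resolves like any other supported language. A small sketch of the lookup path (the expected values follow directly from the code added above):

```python
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name

name = DEF_LANG_TO_PHONEMIZER["be"]       # "be_phonemizer"
phonemizer = get_phonemizer_by_name(name)
print(phonemizer.supported_languages())   # {"be": "Belarusian"}
```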
TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py (new file)

@@ -0,0 +1,55 @@
+from typing import Dict
+
+from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
+from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes
+
+_DEF_BE_PUNCS = ",!."  # TODO
+
+
+class BEL_Phonemizer(BasePhonemizer):
+    """🐸TTS be phonemizer using functions in `TTS.tts.utils.text.belarusian.phonemizer`
+
+    Args:
+        punctuations (str):
+            Set of characters to be treated as punctuation. Defaults to `_DEF_BE_PUNCS`.
+
+        keep_puncs (bool):
+            If True, keep the punctuations after phonemization. Defaults to True.
+    """
+
+    language = "be"
+
+    def __init__(self, punctuations=_DEF_BE_PUNCS, keep_puncs=True, **kwargs):  # pylint: disable=unused-argument
+        super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs)
+
+    @staticmethod
+    def name():
+        return "be_phonemizer"
+
+    @staticmethod
+    def phonemize_be(text: str, separator: str = "|") -> str:  # pylint: disable=unused-argument
+        return belarusian_text_to_phonemes(text)
+
+    def _phonemize(self, text, separator):
+        return self.phonemize_be(text, separator)
+
+    @staticmethod
+    def supported_languages() -> Dict:
+        return {"be": "Belarusian"}
+
+    def version(self) -> str:
+        return "0.0.1"
+
+    def is_available(self) -> bool:
+        return True
+
+
+if __name__ == "__main__":
+    txt = "тэст"
+    e = BEL_Phonemizer()
+    print(e.supported_languages())
+    print(e.version())
+    print(e.language)
+    print(e.name())
+    print(e.is_available())
+    print("`" + e.phonemize(txt) + "`")
docs/requirements.txt

@@ -1,6 +1,6 @@
 furo
 myst-parser == 2.0.0
-sphinx == 7.0.1
+sphinx == 7.2.5
 sphinx_inline_tabs
 sphinx_copybutton
 linkify-it-py
docs/source/docker_images.md

@@ -43,7 +43,7 @@ Start the container and get a shell inside it.
 ```bash
 docker run --rm -it -p 5002:5002 --entrypoint /bin/bash ghcr.io/coqui-ai/tts-cpu
 python3 TTS/server/server.py --list_models #To get the list of available models
-python3 TTS/server/server.py --model_name tts_models/en/vctk/vits
+python3 TTS/server/server.py --model_name tts_models/en/vctk/vits
 ```
 
 ### GPU version
docs/source/implementing_a_new_model.md

@@ -36,7 +36,7 @@
 There is also the `callback` interface by which you can manipulate both the model and the `Trainer` states. Callbacks give you
 an infinite flexibility to add custom behaviours for your model and training routines.
 
-For more details, see {ref}`BaseTTS <Base TTS Model>` and :obj:`TTS.utils.callbacks`.
+For more details, see {ref}`BaseTTS <Base tts Model>` and :obj:`TTS.utils.callbacks`.
 
 6. Optionally, define `MyModelArgs`.

@@ -204,5 +204,3 @@ class MyModel(BaseTTS):
         pass
 
 ```
-
-
docs/source/inference.md

@@ -117,7 +117,7 @@ You can run a multi-speaker and multi-lingual model in Python as
 from TTS.api import TTS
 
 # List available 🐸TTS models and choose the first one
-model_name = TTS.list_models()[0]
+model_name = TTS().list_models()[0]
 # Init TTS
 tts = TTS(model_name)
 # Run TTS

@@ -132,7 +132,7 @@ tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.langu
 ```python
 # Init TTS with the target model name
-tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
+tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False)
 # Run TTS
 tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
 ```

@@ -140,7 +140,7 @@ tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
 #### Example voice cloning with YourTTS in English, French and Portuguese:
 
 ```python
-tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
+tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False).to("cuda")
 tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
 tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="output.wav")
 tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav")

@@ -149,7 +149,7 @@ tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav",
 #### Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav`
 
 ```python
-tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True)
+tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False).to("cuda")
 tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
 ```

@@ -177,7 +177,7 @@ You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API toke
 # The name format is coqui_studio/en/<studio_speaker_name>/coqui_studio
 models = TTS().list_models()
 # Init TTS with the target studio speaker
-tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_bar=False, gpu=False)
+tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_bar=False)
 # Run TTS
 tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH)
 # Run TTS with emotion and speed control

@@ -222,7 +222,7 @@ You can find the list of language ISO codes [here](https://dl.fbaipublicfiles.co
 ```python
 from TTS.api import TTS
-api = TTS(model_name="tts_models/eng/fairseq/vits", gpu=True)
+api = TTS(model_name="tts_models/eng/fairseq/vits").to("cuda")
 api.tts_to_file("This is a test.", file_path="output.wav")
 
 # TTS with on the fly voice conversion
docs/source/main_classes/model_api.md

@@ -5,18 +5,18 @@ Model API provides you a set of functions that easily make your model compatible
 ## Base TTS Model
 
 ```{eval-rst}
-.. autoclass:: TTS.model.BaseModel
+.. autoclass:: TTS.model.BaseTrainerModel
     :members:
 ```
 
-## Base `tts` Model
+## Base tts Model
 
 ```{eval-rst}
 .. autoclass:: TTS.tts.models.base_tts.BaseTTS
     :members:
 ```
 
-## Base `vocoder` Model
+## Base vocoder Model
 
 ```{eval-rst}
 .. autoclass:: TTS.vocoder.models.base_vocoder.BaseVocoder
docs/source/models/bark.md

@@ -91,12 +91,6 @@ tts --model_name tts_models/multilingual/multi-dataset/bark \
     :members:
 ```
 
-## BarkArgs
-```{eval-rst}
-.. autoclass:: TTS.tts.models.bark.BarkArgs
-    :members:
-```
-
 ## Bark Model
 ```{eval-rst}
 .. autoclass:: TTS.tts.models.bark.Bark
recipes/bel-alex73/train_glowtts.py

@@ -60,7 +60,7 @@ config = GlowTTSConfig(
     output_path=output_path,
     add_blank=True,
     datasets=[dataset_config],
-    characters=characters,
+    # characters=characters,
     enable_eos_bos_chars=True,
     mixed_precision=False,
     save_step=10000,

@@ -69,6 +69,8 @@ config = GlowTTSConfig(
     text_cleaner="no_cleaners",
     audio=audio_config,
     test_sentences=[],
+    use_phonemes=True,
+    phoneme_language="be",
 )
 
 if __name__ == "__main__":
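Switching the recipe from a raw character set to phonemes means `phoneme_language="be"` is resolved through the `DEF_LANG_TO_PHONEMIZER["be"]` entry added earlier in this commit. A trimmed sketch of the relevant config fields (other fields keep their defaults; this is not the full recipe):

```python
from TTS.tts.configs.glow_tts_config import GlowTTSConfig

config = GlowTTSConfig(
    text_cleaner="no_cleaners",
    use_phonemes=True,
    phoneme_language="be",  # maps to "be_phonemizer" via DEF_LANG_TO_PHONEMIZER
)
print(config.phoneme_language)
```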
requirements.txt

@@ -2,7 +2,7 @@
 numpy==1.22.0;python_version<="3.10"
 numpy==1.24.3;python_version>"3.10"
 cython==0.29.30
-scipy>=1.4.0
+scipy>=1.11.2
 torch>=1.7
 torchaudio
 soundfile
tests/text_tests/test_belarusian_phonemizer.py (new file)

@@ -0,0 +1,29 @@
+import os
+import warnings
+import unittest
+
+from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes
+
+_TEST_CASES = """
+Фанетычны канвертар/fanʲɛˈtɨt͡ʂnɨ kanˈvʲɛrtar
+Гэтак мы працавалі/ˈɣɛtak ˈmɨ prat͡saˈvalʲi
+"""
+
+
+class TestText(unittest.TestCase):
+    def test_belarusian_text_to_phonemes(self):
+        try:
+            os.environ["BEL_FANETYKA_JAR"]
+        except KeyError:
+            warnings.warn(
+                "You need to define 'BEL_FANETYKA_JAR' environment variable as path to the fanetyka.jar file to test Belarusian phonemizer",
+                Warning)
+            return
+
+        for line in _TEST_CASES.strip().split("\n"):
+            text, phonemes = line.split("/")
+            self.assertEqual(belarusian_text_to_phonemes(text), phonemes)
+
+
+if __name__ == "__main__":
+    unittest.main()
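The test warns and returns early unless `BEL_FANETYKA_JAR` is set. A hedged sketch of running it locally; both the jar path and the test module path here are assumptions:

```python
import os
import subprocess

# Hypothetical jar location; adjust to your environment.
env = {**os.environ, "BEL_FANETYKA_JAR": "/opt/fanetyka/fanetyka.jar"}
subprocess.run(
    ["python", "-m", "unittest", "tests.text_tests.test_belarusian_phonemizer"],
    env=env,
    check=True,
)
```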