mirror of https://github.com/coqui-ai/TTS.git
Add sentence splitting (#3227)
* Add sentence spliting * update requirements * update default args v2 * Add spanish * Fix return gpt_latents * Update requirements * Fix requirements
This commit is contained in:
parent
3c2d5a9e03
commit
675f983550
|
@ -1,10 +1,10 @@
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from functools import cached_property
|
|
||||||
|
|
||||||
import pypinyin
|
|
||||||
import torch
|
import torch
|
||||||
|
import pypinyin
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
from functools import cached_property
|
||||||
from hangul_romanize import Transliter
|
from hangul_romanize import Transliter
|
||||||
from hangul_romanize.rule import academic
|
from hangul_romanize.rule import academic
|
||||||
from num2words import num2words
|
from num2words import num2words
|
||||||
|
@ -12,6 +12,61 @@ from tokenizers import Tokenizer
|
||||||
|
|
||||||
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
|
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
|
||||||
|
|
||||||
|
from spacy.lang.en import English
|
||||||
|
from spacy.lang.zh import Chinese
|
||||||
|
from spacy.lang.ja import Japanese
|
||||||
|
from spacy.lang.ar import Arabic
|
||||||
|
from spacy.lang.es import Spanish
|
||||||
|
|
||||||
|
|
||||||
|
def get_spacy_lang(lang):
|
||||||
|
if lang == "zh":
|
||||||
|
return Chinese()
|
||||||
|
elif lang == "ja":
|
||||||
|
return Japanese()
|
||||||
|
elif lang == "ar":
|
||||||
|
return Arabic()
|
||||||
|
elif lang == "es":
|
||||||
|
return Spanish()
|
||||||
|
else:
|
||||||
|
# For most languages, Enlish does the job
|
||||||
|
return English()
|
||||||
|
|
||||||
|
def split_sentence(text, lang, text_split_length=250):
|
||||||
|
"""Preprocess the input text"""
|
||||||
|
text_splits = []
|
||||||
|
if text_split_length is not None and len(text) >= text_split_length:
|
||||||
|
text_splits.append("")
|
||||||
|
nlp = get_spacy_lang(lang)
|
||||||
|
nlp.add_pipe("sentencizer")
|
||||||
|
doc = nlp(text)
|
||||||
|
for sentence in doc.sents:
|
||||||
|
if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
|
||||||
|
# if the last sentence + the current sentence is less than the text_split_length
|
||||||
|
# then add the current sentence to the last sentence
|
||||||
|
text_splits[-1] += " " + str(sentence)
|
||||||
|
text_splits[-1] = text_splits[-1].lstrip()
|
||||||
|
elif len(str(sentence)) > text_split_length:
|
||||||
|
# if the current sentence is greater than the text_split_length
|
||||||
|
for line in textwrap.wrap(
|
||||||
|
str(sentence),
|
||||||
|
width=text_split_length,
|
||||||
|
drop_whitespace=True,
|
||||||
|
break_on_hyphens=False,
|
||||||
|
tabsize=1,
|
||||||
|
):
|
||||||
|
text_splits.append(str(line))
|
||||||
|
else:
|
||||||
|
text_splits.append(str(sentence))
|
||||||
|
|
||||||
|
if len(text_splits) > 1:
|
||||||
|
if text_splits[0] == "":
|
||||||
|
del text_splits[0]
|
||||||
|
else:
|
||||||
|
text_splits = [text.lstrip()]
|
||||||
|
|
||||||
|
return text_splits
|
||||||
|
|
||||||
_whitespace_re = re.compile(r"\s+")
|
_whitespace_re = re.compile(r"\s+")
|
||||||
|
|
||||||
# List of (regular expression, replacement) pairs for abbreviations:
|
# List of (regular expression, replacement) pairs for abbreviations:
|
||||||
|
@ -464,7 +519,7 @@ def _expand_number(m, lang="en"):
|
||||||
|
|
||||||
|
|
||||||
def expand_numbers_multilingual(text, lang="en"):
|
def expand_numbers_multilingual(text, lang="en"):
|
||||||
if lang == "zh" or lang == "zh-cn":
|
if lang == "zh":
|
||||||
text = zh_num2words()(text)
|
text = zh_num2words()(text)
|
||||||
else:
|
else:
|
||||||
if lang in ["en", "ru"]:
|
if lang in ["en", "ru"]:
|
||||||
|
@ -525,7 +580,7 @@ def japanese_cleaners(text, katsu):
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def korean_cleaners(text):
|
def korean_transliterate(text):
|
||||||
r = Transliter(academic)
|
r = Transliter(academic)
|
||||||
return r.translit(text)
|
return r.translit(text)
|
||||||
|
|
||||||
|
@ -546,7 +601,7 @@ class VoiceBpeTokenizer:
|
||||||
"it": 213,
|
"it": 213,
|
||||||
"pt": 203,
|
"pt": 203,
|
||||||
"pl": 224,
|
"pl": 224,
|
||||||
"zh-cn": 82,
|
"zh": 82,
|
||||||
"ar": 166,
|
"ar": 166,
|
||||||
"cs": 186,
|
"cs": 186,
|
||||||
"ru": 182,
|
"ru": 182,
|
||||||
|
@ -571,19 +626,20 @@ class VoiceBpeTokenizer:
|
||||||
)
|
)
|
||||||
|
|
||||||
def preprocess_text(self, txt, lang):
|
def preprocess_text(self, txt, lang):
|
||||||
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}:
|
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "ko"}:
|
||||||
txt = multilingual_cleaners(txt, lang)
|
txt = multilingual_cleaners(txt, lang)
|
||||||
if lang in {"zh", "zh-cn"}:
|
if lang == "zh":
|
||||||
txt = chinese_transliterate(txt)
|
txt = chinese_transliterate(txt)
|
||||||
|
if lang == "ko":
|
||||||
|
txt = korean_transliterate(txt)
|
||||||
elif lang == "ja":
|
elif lang == "ja":
|
||||||
txt = japanese_cleaners(txt, self.katsu)
|
txt = japanese_cleaners(txt, self.katsu)
|
||||||
elif lang == "ko":
|
|
||||||
txt = korean_cleaners(txt)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Language '{lang}' is not supported.")
|
raise NotImplementedError(f"Language '{lang}' is not supported.")
|
||||||
return txt
|
return txt
|
||||||
|
|
||||||
def encode(self, txt, lang):
|
def encode(self, txt, lang):
|
||||||
|
lang = lang.split("-")[0] # remove the region
|
||||||
self.check_input_length(txt, lang)
|
self.check_input_length(txt, lang)
|
||||||
txt = self.preprocess_text(txt, lang)
|
txt = self.preprocess_text(txt, lang)
|
||||||
txt = f"[{lang}]{txt}"
|
txt = f"[{lang}]{txt}"
|
||||||
|
|
|
@ -10,7 +10,7 @@ from coqpit import Coqpit
|
||||||
from TTS.tts.layers.xtts.gpt import GPT
|
from TTS.tts.layers.xtts.gpt import GPT
|
||||||
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
||||||
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
||||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
|
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
|
|
||||||
|
@ -420,9 +420,9 @@ class Xtts(BaseTTS):
|
||||||
ref_audio_path,
|
ref_audio_path,
|
||||||
language,
|
language,
|
||||||
# GPT inference
|
# GPT inference
|
||||||
temperature=0.65,
|
temperature=0.75,
|
||||||
length_penalty=1,
|
length_penalty=1.0,
|
||||||
repetition_penalty=2.0,
|
repetition_penalty=10.0,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
top_p=0.85,
|
top_p=0.85,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
|
@ -502,24 +502,30 @@ class Xtts(BaseTTS):
|
||||||
gpt_cond_latent,
|
gpt_cond_latent,
|
||||||
speaker_embedding,
|
speaker_embedding,
|
||||||
# GPT inference
|
# GPT inference
|
||||||
temperature=0.65,
|
temperature=0.75,
|
||||||
length_penalty=1,
|
length_penalty=1.0,
|
||||||
repetition_penalty=2.0,
|
repetition_penalty=10.0,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
top_p=0.85,
|
top_p=0.85,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
speed=1.0,
|
speed=1.0,
|
||||||
|
enable_text_splitting=False,
|
||||||
**hf_generate_kwargs,
|
**hf_generate_kwargs,
|
||||||
):
|
):
|
||||||
|
language = language.split("-")[0] # remove the country code
|
||||||
length_scale = 1.0 / max(speed, 0.05)
|
length_scale = 1.0 / max(speed, 0.05)
|
||||||
text = text.strip().lower()
|
if enable_text_splitting:
|
||||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
text = split_sentence(text, language, self.tokenizer.char_limits[language])
|
||||||
|
else:
|
||||||
|
text = [text]
|
||||||
|
|
||||||
|
wavs = []
|
||||||
|
gpt_latents_list = []
|
||||||
|
for sent in text:
|
||||||
|
sent = sent.strip().lower()
|
||||||
|
text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
# print(" > Input text: ", text)
|
|
||||||
# print(" > Input text preprocessed: ",self.tokenizer.preprocess_text(text, language))
|
|
||||||
# print(" > Input tokens: ", text_tokens)
|
|
||||||
# print(" > Decoded text: ", self.tokenizer.decode(text_tokens[0].cpu().numpy()))
|
|
||||||
assert (
|
assert (
|
||||||
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
|
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
|
||||||
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
|
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
|
||||||
|
@ -562,11 +568,12 @@ class Xtts(BaseTTS):
|
||||||
mode="linear"
|
mode="linear"
|
||||||
).transpose(1, 2)
|
).transpose(1, 2)
|
||||||
|
|
||||||
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
|
gpt_latents_list.append(gpt_latents.cpu())
|
||||||
|
wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"wav": wav.cpu().numpy().squeeze(),
|
"wav": torch.cat(wavs, dim=0).numpy(),
|
||||||
"gpt_latents": gpt_latents,
|
"gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(),
|
||||||
"speaker_embedding": speaker_embedding,
|
"speaker_embedding": speaker_embedding,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -606,18 +613,30 @@ class Xtts(BaseTTS):
|
||||||
stream_chunk_size=20,
|
stream_chunk_size=20,
|
||||||
overlap_wav_len=1024,
|
overlap_wav_len=1024,
|
||||||
# GPT inference
|
# GPT inference
|
||||||
temperature=0.65,
|
temperature=0.75,
|
||||||
length_penalty=1,
|
length_penalty=1.0,
|
||||||
repetition_penalty=2.0,
|
repetition_penalty=10.0,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
top_p=0.85,
|
top_p=0.85,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
speed=1.0,
|
speed=1.0,
|
||||||
|
enable_text_splitting=False,
|
||||||
**hf_generate_kwargs,
|
**hf_generate_kwargs,
|
||||||
):
|
):
|
||||||
|
language = language.split("-")[0] # remove the country code
|
||||||
length_scale = 1.0 / max(speed, 0.05)
|
length_scale = 1.0 / max(speed, 0.05)
|
||||||
text = text.strip().lower()
|
if enable_text_splitting:
|
||||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
text = split_sentence(text, language, self.tokenizer.char_limits[language])
|
||||||
|
else:
|
||||||
|
text = [text]
|
||||||
|
|
||||||
|
for sent in text:
|
||||||
|
sent = sent.strip().lower()
|
||||||
|
text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
|
||||||
|
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
|
||||||
|
|
||||||
fake_inputs = self.gpt.compute_embeddings(
|
fake_inputs = self.gpt.compute_embeddings(
|
||||||
gpt_cond_latent.to(self.device),
|
gpt_cond_latent.to(self.device),
|
||||||
|
|
|
@ -54,3 +54,4 @@ encodec==0.1.*
|
||||||
# deps for XTTS
|
# deps for XTTS
|
||||||
unidecode==1.3.*
|
unidecode==1.3.*
|
||||||
num2words
|
num2words
|
||||||
|
spacy[ja]>=3
|
Loading…
Reference in New Issue