mirror of https://github.com/coqui-ai/TTS.git
Add sentence splitting (#3227)
* Add sentence spliting * update requirements * update default args v2 * Add spanish * Fix return gpt_latents * Update requirements * Fix requirements
This commit is contained in:
parent
3c2d5a9e03
commit
675f983550
|
@ -1,10 +1,10 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from functools import cached_property
|
||||
|
||||
import pypinyin
|
||||
import torch
|
||||
import pypinyin
|
||||
import textwrap
|
||||
|
||||
from functools import cached_property
|
||||
from hangul_romanize import Transliter
|
||||
from hangul_romanize.rule import academic
|
||||
from num2words import num2words
|
||||
|
@ -12,6 +12,61 @@ from tokenizers import Tokenizer
|
|||
|
||||
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
|
||||
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.zh import Chinese
|
||||
from spacy.lang.ja import Japanese
|
||||
from spacy.lang.ar import Arabic
|
||||
from spacy.lang.es import Spanish
|
||||
|
||||
|
||||
def get_spacy_lang(lang):
|
||||
if lang == "zh":
|
||||
return Chinese()
|
||||
elif lang == "ja":
|
||||
return Japanese()
|
||||
elif lang == "ar":
|
||||
return Arabic()
|
||||
elif lang == "es":
|
||||
return Spanish()
|
||||
else:
|
||||
# For most languages, Enlish does the job
|
||||
return English()
|
||||
|
||||
def split_sentence(text, lang, text_split_length=250):
|
||||
"""Preprocess the input text"""
|
||||
text_splits = []
|
||||
if text_split_length is not None and len(text) >= text_split_length:
|
||||
text_splits.append("")
|
||||
nlp = get_spacy_lang(lang)
|
||||
nlp.add_pipe("sentencizer")
|
||||
doc = nlp(text)
|
||||
for sentence in doc.sents:
|
||||
if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
|
||||
# if the last sentence + the current sentence is less than the text_split_length
|
||||
# then add the current sentence to the last sentence
|
||||
text_splits[-1] += " " + str(sentence)
|
||||
text_splits[-1] = text_splits[-1].lstrip()
|
||||
elif len(str(sentence)) > text_split_length:
|
||||
# if the current sentence is greater than the text_split_length
|
||||
for line in textwrap.wrap(
|
||||
str(sentence),
|
||||
width=text_split_length,
|
||||
drop_whitespace=True,
|
||||
break_on_hyphens=False,
|
||||
tabsize=1,
|
||||
):
|
||||
text_splits.append(str(line))
|
||||
else:
|
||||
text_splits.append(str(sentence))
|
||||
|
||||
if len(text_splits) > 1:
|
||||
if text_splits[0] == "":
|
||||
del text_splits[0]
|
||||
else:
|
||||
text_splits = [text.lstrip()]
|
||||
|
||||
return text_splits
|
||||
|
||||
_whitespace_re = re.compile(r"\s+")
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
|
@ -464,7 +519,7 @@ def _expand_number(m, lang="en"):
|
|||
|
||||
|
||||
def expand_numbers_multilingual(text, lang="en"):
|
||||
if lang == "zh" or lang == "zh-cn":
|
||||
if lang == "zh":
|
||||
text = zh_num2words()(text)
|
||||
else:
|
||||
if lang in ["en", "ru"]:
|
||||
|
@ -525,7 +580,7 @@ def japanese_cleaners(text, katsu):
|
|||
return text
|
||||
|
||||
|
||||
def korean_cleaners(text):
|
||||
def korean_transliterate(text):
|
||||
r = Transliter(academic)
|
||||
return r.translit(text)
|
||||
|
||||
|
@ -546,7 +601,7 @@ class VoiceBpeTokenizer:
|
|||
"it": 213,
|
||||
"pt": 203,
|
||||
"pl": 224,
|
||||
"zh-cn": 82,
|
||||
"zh": 82,
|
||||
"ar": 166,
|
||||
"cs": 186,
|
||||
"ru": 182,
|
||||
|
@ -571,19 +626,20 @@ class VoiceBpeTokenizer:
|
|||
)
|
||||
|
||||
def preprocess_text(self, txt, lang):
|
||||
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}:
|
||||
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "ko"}:
|
||||
txt = multilingual_cleaners(txt, lang)
|
||||
if lang in {"zh", "zh-cn"}:
|
||||
if lang == "zh":
|
||||
txt = chinese_transliterate(txt)
|
||||
if lang == "ko":
|
||||
txt = korean_transliterate(txt)
|
||||
elif lang == "ja":
|
||||
txt = japanese_cleaners(txt, self.katsu)
|
||||
elif lang == "ko":
|
||||
txt = korean_cleaners(txt)
|
||||
else:
|
||||
raise NotImplementedError(f"Language '{lang}' is not supported.")
|
||||
return txt
|
||||
|
||||
def encode(self, txt, lang):
|
||||
lang = lang.split("-")[0] # remove the region
|
||||
self.check_input_length(txt, lang)
|
||||
txt = self.preprocess_text(txt, lang)
|
||||
txt = f"[{lang}]{txt}"
|
||||
|
|
|
@ -10,7 +10,7 @@ from coqpit import Coqpit
|
|||
from TTS.tts.layers.xtts.gpt import GPT
|
||||
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
||||
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
|
||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
@ -420,9 +420,9 @@ class Xtts(BaseTTS):
|
|||
ref_audio_path,
|
||||
language,
|
||||
# GPT inference
|
||||
temperature=0.65,
|
||||
length_penalty=1,
|
||||
repetition_penalty=2.0,
|
||||
temperature=0.75,
|
||||
length_penalty=1.0,
|
||||
repetition_penalty=10.0,
|
||||
top_k=50,
|
||||
top_p=0.85,
|
||||
do_sample=True,
|
||||
|
@ -502,24 +502,30 @@ class Xtts(BaseTTS):
|
|||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
# GPT inference
|
||||
temperature=0.65,
|
||||
length_penalty=1,
|
||||
repetition_penalty=2.0,
|
||||
temperature=0.75,
|
||||
length_penalty=1.0,
|
||||
repetition_penalty=10.0,
|
||||
top_k=50,
|
||||
top_p=0.85,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
speed=1.0,
|
||||
enable_text_splitting=False,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
language = language.split("-")[0] # remove the country code
|
||||
length_scale = 1.0 / max(speed, 0.05)
|
||||
text = text.strip().lower()
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
||||
if enable_text_splitting:
|
||||
text = split_sentence(text, language, self.tokenizer.char_limits[language])
|
||||
else:
|
||||
text = [text]
|
||||
|
||||
wavs = []
|
||||
gpt_latents_list = []
|
||||
for sent in text:
|
||||
sent = sent.strip().lower()
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
|
||||
|
||||
# print(" > Input text: ", text)
|
||||
# print(" > Input text preprocessed: ",self.tokenizer.preprocess_text(text, language))
|
||||
# print(" > Input tokens: ", text_tokens)
|
||||
# print(" > Decoded text: ", self.tokenizer.decode(text_tokens[0].cpu().numpy()))
|
||||
assert (
|
||||
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
|
||||
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
|
||||
|
@ -562,11 +568,12 @@ class Xtts(BaseTTS):
|
|||
mode="linear"
|
||||
).transpose(1, 2)
|
||||
|
||||
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
|
||||
gpt_latents_list.append(gpt_latents.cpu())
|
||||
wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())
|
||||
|
||||
return {
|
||||
"wav": wav.cpu().numpy().squeeze(),
|
||||
"gpt_latents": gpt_latents,
|
||||
"wav": torch.cat(wavs, dim=0).numpy(),
|
||||
"gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(),
|
||||
"speaker_embedding": speaker_embedding,
|
||||
}
|
||||
|
||||
|
@ -606,18 +613,30 @@ class Xtts(BaseTTS):
|
|||
stream_chunk_size=20,
|
||||
overlap_wav_len=1024,
|
||||
# GPT inference
|
||||
temperature=0.65,
|
||||
length_penalty=1,
|
||||
repetition_penalty=2.0,
|
||||
temperature=0.75,
|
||||
length_penalty=1.0,
|
||||
repetition_penalty=10.0,
|
||||
top_k=50,
|
||||
top_p=0.85,
|
||||
do_sample=True,
|
||||
speed=1.0,
|
||||
enable_text_splitting=False,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
language = language.split("-")[0] # remove the country code
|
||||
length_scale = 1.0 / max(speed, 0.05)
|
||||
text = text.strip().lower()
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
||||
if enable_text_splitting:
|
||||
text = split_sentence(text, language, self.tokenizer.char_limits[language])
|
||||
else:
|
||||
text = [text]
|
||||
|
||||
for sent in text:
|
||||
sent = sent.strip().lower()
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
|
||||
|
||||
assert (
|
||||
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
|
||||
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
|
||||
|
||||
fake_inputs = self.gpt.compute_embeddings(
|
||||
gpt_cond_latent.to(self.device),
|
||||
|
|
|
@ -54,3 +54,4 @@ encodec==0.1.*
|
|||
# deps for XTTS
|
||||
unidecode==1.3.*
|
||||
num2words
|
||||
spacy[ja]>=3
|
Loading…
Reference in New Issue