Add sentence splitting (#3227)

* Add sentence splitting

* update requirements

* update default args v2

* Add spanish

* Fix return gpt_latents

* Update requirements

* Fix requirements
Author: Julian Weber
Date: 2023-11-16 11:01:11 +01:00 (committed by GitHub)
Commit: 675f983550
Parent: 3c2d5a9e03
3 changed files with 190 additions and 114 deletions

Changed file: TTS/tts/layers/xtts/tokenizer.py

@@ -1,10 +1,10 @@
 import json
 import os
 import re
-from functools import cached_property
-import pypinyin
 import torch
+import pypinyin
+import textwrap
+from functools import cached_property
 from hangul_romanize import Transliter
 from hangul_romanize.rule import academic
 from num2words import num2words
@@ -12,6 +12,61 @@ from tokenizers import Tokenizer

 from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words

+from spacy.lang.en import English
+from spacy.lang.zh import Chinese
+from spacy.lang.ja import Japanese
+from spacy.lang.ar import Arabic
+from spacy.lang.es import Spanish
+
+
+def get_spacy_lang(lang):
+    if lang == "zh":
+        return Chinese()
+    elif lang == "ja":
+        return Japanese()
+    elif lang == "ar":
+        return Arabic()
+    elif lang == "es":
+        return Spanish()
+    else:
+        # For most languages, English does the job
+        return English()
+
+
+def split_sentence(text, lang, text_split_length=250):
+    """Preprocess the input text"""
+    text_splits = []
+    if text_split_length is not None and len(text) >= text_split_length:
+        text_splits.append("")
+        nlp = get_spacy_lang(lang)
+        nlp.add_pipe("sentencizer")
+        doc = nlp(text)
+        for sentence in doc.sents:
+            if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
+                # if the last sentence + the current sentence is less than the text_split_length
+                # then add the current sentence to the last sentence
+                text_splits[-1] += " " + str(sentence)
+                text_splits[-1] = text_splits[-1].lstrip()
+            elif len(str(sentence)) > text_split_length:
+                # if the current sentence is greater than the text_split_length
+                for line in textwrap.wrap(
+                    str(sentence),
+                    width=text_split_length,
+                    drop_whitespace=True,
+                    break_on_hyphens=False,
+                    tabsize=1,
+                ):
+                    text_splits.append(str(line))
+            else:
+                text_splits.append(str(sentence))
+
+        if len(text_splits) > 1:
+            if text_splits[0] == "":
+                del text_splits[0]
+    else:
+        text_splits = [text.lstrip()]
+
+    return text_splits
+
+
 _whitespace_re = re.compile(r"\s+")

 # List of (regular expression, replacement) pairs for abbreviations:
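For orientation, here is a minimal usage sketch of the new split_sentence helper. The sample text and the 60-character limit are illustrative only (the default limit is 250), and it assumes spaCy is installed per the updated requirements:

from TTS.tts.layers.xtts.tokenizer import split_sentence

# Illustrative input; the limit is lowered so the splitting is easy to see.
long_text = (
    "XTTS keeps prompts short. Overly long inputs used to risk truncation. "
    "Splitting the text into sentences lets the model synthesize each chunk in turn."
)
for chunk in split_sentence(long_text, "en", text_split_length=60):
    print(repr(chunk))

Sentences are packed together while they still fit within text_split_length; a single sentence longer than the limit is hard-wrapped with textwrap.wrap instead.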
@@ -464,7 +519,7 @@ def _expand_number(m, lang="en"):


 def expand_numbers_multilingual(text, lang="en"):
-    if lang == "zh" or lang == "zh-cn":
+    if lang == "zh":
         text = zh_num2words()(text)
     else:
         if lang in ["en", "ru"]:
@@ -525,7 +580,7 @@ def japanese_cleaners(text, katsu):
     return text


-def korean_cleaners(text):
+def korean_transliterate(text):
     r = Transliter(academic)
     return r.translit(text)
@@ -546,7 +601,7 @@ class VoiceBpeTokenizer:
         "it": 213,
         "pt": 203,
         "pl": 224,
-        "zh-cn": 82,
+        "zh": 82,
         "ar": 166,
         "cs": 186,
         "ru": 182,
@@ -571,19 +626,20 @@ class VoiceBpeTokenizer:
             )

     def preprocess_text(self, txt, lang):
-        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}:
+        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "ko"}:
             txt = multilingual_cleaners(txt, lang)
-            if lang in {"zh", "zh-cn"}:
+            if lang == "zh":
                 txt = chinese_transliterate(txt)
+            if lang == "ko":
+                txt = korean_transliterate(txt)
         elif lang == "ja":
             txt = japanese_cleaners(txt, self.katsu)
-        elif lang == "ko":
-            txt = korean_cleaners(txt)
         else:
             raise NotImplementedError(f"Language '{lang}' is not supported.")
         return txt

     def encode(self, txt, lang):
+        lang = lang.split("-")[0]  # remove the region
         self.check_input_length(txt, lang)
         txt = self.preprocess_text(txt, lang)
         txt = f"[{lang}]{txt}"

Changed file: TTS/tts/models/xtts.py

@@ -10,7 +10,7 @@ from coqpit import Coqpit

 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
-from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
+from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec
@@ -420,9 +420,9 @@ class Xtts(BaseTTS):
         ref_audio_path,
         language,
         # GPT inference
-        temperature=0.65,
-        length_penalty=1,
-        repetition_penalty=2.0,
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
         top_k=50,
         top_p=0.85,
         do_sample=True,
@@ -502,24 +502,30 @@ class Xtts(BaseTTS):
         gpt_cond_latent,
         speaker_embedding,
         # GPT inference
-        temperature=0.65,
-        length_penalty=1,
-        repetition_penalty=2.0,
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
         top_k=50,
         top_p=0.85,
         do_sample=True,
         num_beams=1,
         speed=1.0,
+        enable_text_splitting=False,
         **hf_generate_kwargs,
     ):
+        language = language.split("-")[0]  # remove the country code
         length_scale = 1.0 / max(speed, 0.05)
-        text = text.strip().lower()
-        text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
-        # print(" > Input text: ", text)
-        # print(" > Input text preprocessed: ",self.tokenizer.preprocess_text(text, language))
-        # print(" > Input tokens: ", text_tokens)
-        # print(" > Decoded text: ", self.tokenizer.decode(text_tokens[0].cpu().numpy()))
-        assert (
-            text_tokens.shape[-1] < self.args.gpt_max_text_tokens
-        ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
+        if enable_text_splitting:
+            text = split_sentence(text, language, self.tokenizer.char_limits[language])
+        else:
+            text = [text]
+
+        wavs = []
+        gpt_latents_list = []
+        for sent in text:
+            sent = sent.strip().lower()
+            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
+
+            assert (
+                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
+            ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
@@ -562,11 +568,12 @@ class Xtts(BaseTTS):
                     mode="linear"
                 ).transpose(1, 2)

-        wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
+            gpt_latents_list.append(gpt_latents.cpu())
+            wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())

         return {
-            "wav": wav.cpu().numpy().squeeze(),
-            "gpt_latents": gpt_latents,
+            "wav": torch.cat(wavs, dim=0).numpy(),
+            "gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(),
             "speaker_embedding": speaker_embedding,
         }
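For reference, a hedged end-to-end sketch of the new flag on the inference path. The checkpoint and reference-audio paths are placeholders and the loading boilerplate follows the usual XTTS pattern; only the enable_text_splitting argument and the returned keys come from this diff.

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Placeholder paths: point these at a real XTTS checkpoint and a reference clip.
config = XttsConfig()
config.load_json("/path/to/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/checkpoint_dir/", eval=True)

gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["/path/to/reference.wav"])

out = model.inference(
    "A long paragraph made of several sentences ...",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    enable_text_splitting=True,  # new in this commit: split long inputs into sentences
)
print(out["wav"].shape, out["gpt_latents"].shape)

With the flag set, the text is split with split_sentence, each chunk is synthesized separately, and the per-chunk waveforms and GPT latents are concatenated before being returned.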
@@ -606,18 +613,30 @@ class Xtts(BaseTTS):
         stream_chunk_size=20,
         overlap_wav_len=1024,
         # GPT inference
-        temperature=0.65,
-        length_penalty=1,
-        repetition_penalty=2.0,
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
         top_k=50,
         top_p=0.85,
         do_sample=True,
         speed=1.0,
+        enable_text_splitting=False,
         **hf_generate_kwargs,
     ):
+        language = language.split("-")[0]  # remove the country code
         length_scale = 1.0 / max(speed, 0.05)
-        text = text.strip().lower()
-        text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
+        if enable_text_splitting:
+            text = split_sentence(text, language, self.tokenizer.char_limits[language])
+        else:
+            text = [text]
+
+        for sent in text:
+            sent = sent.strip().lower()
+            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
+
+            assert (
+                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
+            ), " ❗ XTTS can only generate text with a maximum of 400 tokens."

             fake_inputs = self.gpt.compute_embeddings(
                 gpt_cond_latent.to(self.device),
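The streaming path gains the same flag. Continuing from the loading sketch above (model, gpt_cond_latent, and speaker_embedding as before; the input text is again a placeholder), a hedged usage sketch:

import torch

chunks = model.inference_stream(
    "Another long paragraph made of several sentences ...",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    enable_text_splitting=True,  # same meaning as in inference()
)
wav_chunks = []
for i, chunk in enumerate(chunks):
    print(f"chunk {i}: {chunk.shape[-1]} samples")
    wav_chunks.append(chunk)
wav = torch.cat(wav_chunks, dim=0)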

Changed file: requirements.txt

@@ -54,3 +54,4 @@ encodec==0.1.*
 # deps for XTTS
 unidecode==1.3.*
 num2words
+spacy[ja]>=3
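The new spacy[ja]>=3 pin brings in spaCy plus the Japanese tokenizer backend (SudachiPy) that the blank Japanese pipeline in get_spacy_lang needs. A minimal sanity check, with illustrative sample sentences, might look like this:

from spacy.lang.en import English
from spacy.lang.ja import Japanese

# Constructing Japanese() exercises the SudachiPy backend pulled in by the [ja] extra.
for nlp, sample in [(English(), "Hello there. How are you?"), (Japanese(), "こんにちは。元気ですか。")]:
    nlp.add_pipe("sentencizer")
    print([str(s) for s in nlp(sample).sents])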