Add sentence splitting (#3227)

* Add sentence spliting

* update requirements

* update default args v2

* Add spanish

* Fix return gpt_latents

* Update requirements

* Fix requirements
This commit is contained in:
Julian Weber 2023-11-16 11:01:11 +01:00 committed by GitHub
parent 3c2d5a9e03
commit 675f983550
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 190 additions and 114 deletions

View File

@ -1,10 +1,10 @@
import json
import os
import re
from functools import cached_property
import pypinyin
import torch
import pypinyin
import textwrap
from functools import cached_property
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
from num2words import num2words
@ -12,6 +12,61 @@ from tokenizers import Tokenizer
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
from spacy.lang.en import English
from spacy.lang.zh import Chinese
from spacy.lang.ja import Japanese
from spacy.lang.ar import Arabic
from spacy.lang.es import Spanish
def get_spacy_lang(lang):
if lang == "zh":
return Chinese()
elif lang == "ja":
return Japanese()
elif lang == "ar":
return Arabic()
elif lang == "es":
return Spanish()
else:
# For most languages, Enlish does the job
return English()
def split_sentence(text, lang, text_split_length=250):
"""Preprocess the input text"""
text_splits = []
if text_split_length is not None and len(text) >= text_split_length:
text_splits.append("")
nlp = get_spacy_lang(lang)
nlp.add_pipe("sentencizer")
doc = nlp(text)
for sentence in doc.sents:
if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
# if the last sentence + the current sentence is less than the text_split_length
# then add the current sentence to the last sentence
text_splits[-1] += " " + str(sentence)
text_splits[-1] = text_splits[-1].lstrip()
elif len(str(sentence)) > text_split_length:
# if the current sentence is greater than the text_split_length
for line in textwrap.wrap(
str(sentence),
width=text_split_length,
drop_whitespace=True,
break_on_hyphens=False,
tabsize=1,
):
text_splits.append(str(line))
else:
text_splits.append(str(sentence))
if len(text_splits) > 1:
if text_splits[0] == "":
del text_splits[0]
else:
text_splits = [text.lstrip()]
return text_splits
_whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations:
@ -464,7 +519,7 @@ def _expand_number(m, lang="en"):
def expand_numbers_multilingual(text, lang="en"):
if lang == "zh" or lang == "zh-cn":
if lang == "zh":
text = zh_num2words()(text)
else:
if lang in ["en", "ru"]:
@ -525,7 +580,7 @@ def japanese_cleaners(text, katsu):
return text
def korean_cleaners(text):
def korean_transliterate(text):
r = Transliter(academic)
return r.translit(text)
@ -546,7 +601,7 @@ class VoiceBpeTokenizer:
"it": 213,
"pt": 203,
"pl": 224,
"zh-cn": 82,
"zh": 82,
"ar": 166,
"cs": 186,
"ru": 182,
@ -571,19 +626,20 @@ class VoiceBpeTokenizer:
)
def preprocess_text(self, txt, lang):
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}:
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "ko"}:
txt = multilingual_cleaners(txt, lang)
if lang in {"zh", "zh-cn"}:
if lang == "zh":
txt = chinese_transliterate(txt)
if lang == "ko":
txt = korean_transliterate(txt)
elif lang == "ja":
txt = japanese_cleaners(txt, self.katsu)
elif lang == "ko":
txt = korean_cleaners(txt)
else:
raise NotImplementedError(f"Language '{lang}' is not supported.")
return txt
def encode(self, txt, lang):
lang = lang.split("-")[0] # remove the region
self.check_input_length(txt, lang)
txt = self.preprocess_text(txt, lang)
txt = f"[{lang}]{txt}"

View File

@ -10,7 +10,7 @@ from coqpit import Coqpit
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec
@ -420,9 +420,9 @@ class Xtts(BaseTTS):
ref_audio_path,
language,
# GPT inference
temperature=0.65,
length_penalty=1,
repetition_penalty=2.0,
temperature=0.75,
length_penalty=1.0,
repetition_penalty=10.0,
top_k=50,
top_p=0.85,
do_sample=True,
@ -502,71 +502,78 @@ class Xtts(BaseTTS):
gpt_cond_latent,
speaker_embedding,
# GPT inference
temperature=0.65,
length_penalty=1,
repetition_penalty=2.0,
temperature=0.75,
length_penalty=1.0,
repetition_penalty=10.0,
top_k=50,
top_p=0.85,
do_sample=True,
num_beams=1,
speed=1.0,
enable_text_splitting=False,
**hf_generate_kwargs,
):
language = language.split("-")[0] # remove the country code
length_scale = 1.0 / max(speed, 0.05)
text = text.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
if enable_text_splitting:
text = split_sentence(text, language, self.tokenizer.char_limits[language])
else:
text = [text]
# print(" > Input text: ", text)
# print(" > Input text preprocessed: ",self.tokenizer.preprocess_text(text, language))
# print(" > Input tokens: ", text_tokens)
# print(" > Decoded text: ", self.tokenizer.decode(text_tokens[0].cpu().numpy()))
assert (
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
wavs = []
gpt_latents_list = []
for sent in text:
sent = sent.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
with torch.no_grad():
gpt_codes = self.gpt.generate(
cond_latents=gpt_cond_latent,
text_inputs=text_tokens,
input_tokens=None,
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=self.gpt_batch_size,
num_beams=num_beams,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
output_attentions=False,
**hf_generate_kwargs,
)
expected_output_len = torch.tensor(
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
)
assert (
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
gpt_latents = self.gpt(
text_tokens,
text_len,
gpt_codes,
expected_output_len,
cond_latents=gpt_cond_latent,
return_attentions=False,
return_latent=True,
)
with torch.no_grad():
gpt_codes = self.gpt.generate(
cond_latents=gpt_cond_latent,
text_inputs=text_tokens,
input_tokens=None,
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=self.gpt_batch_size,
num_beams=num_beams,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
output_attentions=False,
**hf_generate_kwargs,
)
expected_output_len = torch.tensor(
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
)
if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2),
scale_factor=length_scale,
mode="linear"
).transpose(1, 2)
text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
gpt_latents = self.gpt(
text_tokens,
text_len,
gpt_codes,
expected_output_len,
cond_latents=gpt_cond_latent,
return_attentions=False,
return_latent=True,
)
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2),
scale_factor=length_scale,
mode="linear"
).transpose(1, 2)
gpt_latents_list.append(gpt_latents.cpu())
wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())
return {
"wav": wav.cpu().numpy().squeeze(),
"gpt_latents": gpt_latents,
"wav": torch.cat(wavs, dim=0).numpy(),
"gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(),
"speaker_embedding": speaker_embedding,
}
@ -606,66 +613,78 @@ class Xtts(BaseTTS):
stream_chunk_size=20,
overlap_wav_len=1024,
# GPT inference
temperature=0.65,
length_penalty=1,
repetition_penalty=2.0,
temperature=0.75,
length_penalty=1.0,
repetition_penalty=10.0,
top_k=50,
top_p=0.85,
do_sample=True,
speed=1.0,
enable_text_splitting=False,
**hf_generate_kwargs,
):
language = language.split("-")[0] # remove the country code
length_scale = 1.0 / max(speed, 0.05)
text = text.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
if enable_text_splitting:
text = split_sentence(text, language, self.tokenizer.char_limits[language])
else:
text = [text]
fake_inputs = self.gpt.compute_embeddings(
gpt_cond_latent.to(self.device),
text_tokens,
)
gpt_generator = self.gpt.get_generator(
fake_inputs=fake_inputs,
top_k=top_k,
top_p=top_p,
temperature=temperature,
do_sample=do_sample,
num_beams=1,
num_return_sequences=1,
length_penalty=float(length_penalty),
repetition_penalty=float(repetition_penalty),
output_attentions=False,
output_hidden_states=True,
**hf_generate_kwargs,
)
for sent in text:
sent = sent.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
last_tokens = []
all_latents = []
wav_gen_prev = None
wav_overlap = None
is_end = False
assert (
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
while not is_end:
try:
x, latent = next(gpt_generator)
last_tokens += [x]
all_latents += [latent]
except StopIteration:
is_end = True
fake_inputs = self.gpt.compute_embeddings(
gpt_cond_latent.to(self.device),
text_tokens,
)
gpt_generator = self.gpt.get_generator(
fake_inputs=fake_inputs,
top_k=top_k,
top_p=top_p,
temperature=temperature,
do_sample=do_sample,
num_beams=1,
num_return_sequences=1,
length_penalty=float(length_penalty),
repetition_penalty=float(repetition_penalty),
output_attentions=False,
output_hidden_states=True,
**hf_generate_kwargs,
)
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2),
scale_factor=length_scale,
mode="linear"
).transpose(1, 2)
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
)
last_tokens = []
yield wav_chunk
last_tokens = []
all_latents = []
wav_gen_prev = None
wav_overlap = None
is_end = False
while not is_end:
try:
x, latent = next(gpt_generator)
last_tokens += [x]
all_latents += [latent]
except StopIteration:
is_end = True
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2),
scale_factor=length_scale,
mode="linear"
).transpose(1, 2)
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
)
last_tokens = []
yield wav_chunk
def forward(self):
raise NotImplementedError(

View File

@ -54,3 +54,4 @@ encodec==0.1.*
# deps for XTTS
unidecode==1.3.*
num2words
spacy[ja]>=3