Make style

Eren Gölge 2023-11-06 11:13:09 +01:00
parent d045bfce41
commit b094979f1a
10 changed files with 718 additions and 692 deletions

View File

@@ -84,7 +84,24 @@ class XttsConfig(BaseTTSConfig):
     audio: XttsAudioConfig = field(default_factory=XttsAudioConfig)
     model_dir: str = None
     languages: List[str] = field(
-        default_factory=lambda: ["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "hu", "ko", "ja"]
+        default_factory=lambda: [
+            "en",
+            "es",
+            "fr",
+            "de",
+            "it",
+            "pt",
+            "pl",
+            "tr",
+            "ru",
+            "nl",
+            "cs",
+            "ar",
+            "zh-cn",
+            "hu",
+            "ko",
+            "ja",
+        ]
     )
 
     # inference params
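
For context, the languages list goes through default_factory in both versions because dataclasses reject mutable defaults; only the layout changed. A minimal self-contained sketch (Cfg is a made-up class, not the repo's):

from dataclasses import dataclass, field
from typing import List

@dataclass
class Cfg:
    # a bare `languages: List[str] = ["en"]` would raise ValueError at class creation;
    # default_factory builds a fresh list for every instance instead
    languages: List[str] = field(default_factory=lambda: ["en", "es"])

a, b = Cfg(), Cfg()
a.languages.append("fr")
assert b.languages == ["en", "es"]  # instances do not share the list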

View File

@@ -13,6 +13,7 @@ from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
 
+
 def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
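
null_position_embeddings above is unchanged by the restyle; it replaces GPT-2's learned positional embeddings with zeros of the matching shape. A quick standalone check (the parameter is renamed range_ here only to avoid shadowing the builtin):

import torch

def null_position_embeddings(range_, dim):
    return torch.zeros((range_.shape[0], range_.shape[1], dim), device=range_.device)

tokens = torch.zeros(2, 10, dtype=torch.long)  # (batch, seq) token ids
print(null_position_embeddings(tokens, 1024).shape)  # torch.Size([2, 10, 1024])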
@@ -186,7 +187,9 @@ class GPT(nn.Module):
     def get_grad_norm_parameter_groups(self):
         return {
             "conditioning_encoder": list(self.conditioning_encoder.parameters()),
-            "conditioning_perceiver": list(self.conditioning_perceiver.parameters()) if self.use_perceiver_resampler else None,
+            "conditioning_perceiver": list(self.conditioning_perceiver.parameters())
+            if self.use_perceiver_resampler
+            else None,
             "gpt": list(self.gpt.parameters()),
             "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
         }
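
The hanging-conditional layout above (value, then `if`/`else` on their own lines) is what black of this era emits for an over-long conditional expression inside a dict literal; a sketch of reproducing it, assuming the black package is installed (newer black versions may parenthesize the conditional instead):

import black

src = (
    "def f(self):\n"
    "    return {\n"
    '        "conditioning_perceiver": list(self.conditioning_perceiver.parameters()) if self.use_perceiver_resampler else None,\n'
    "    }\n"
)
# the repo formats with a 120-column limit; the long value gets split as in the hunk above
print(black.format_str(src, mode=black.Mode(line_length=120)))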
@@ -355,9 +358,9 @@ class GPT(nn.Module):
         if not return_latent:
             if cond_input.ndim == 4:
                 cond_input = cond_input.squeeze(1)
-            conds = self.conditioning_encoder(cond_input) # (b, d, s)
+            conds = self.conditioning_encoder(cond_input)  # (b, d, s)
             if self.use_perceiver_resampler:
-                conds = self.conditioning_perceiver(conds.permute(0, 2, 1)).transpose(1, 2) # (b, d, 32)
+                conds = self.conditioning_perceiver(conds.permute(0, 2, 1)).transpose(1, 2)  # (b, d, 32)
         else:
             # already computed
             conds = cond_input.unsqueeze(1)
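
Beyond the comment spacing, the shape flow in this hunk is: encoder output (b, d, s), permuted to (b, s, d) for the perceiver, which returns 32 latents, then transposed back to (b, d, 32). A runnable sketch with a pooling stand-in for the resampler (not the real module):

import torch
import torch.nn.functional as F

conds = torch.randn(2, 512, 120)  # (b, d, s) from the conditioning encoder

def fake_resampler(x):  # stand-in mapping (b, s, d) -> (b, 32, d)
    return F.adaptive_avg_pool1d(x.transpose(1, 2), 32).transpose(1, 2)

out = fake_resampler(conds.permute(0, 2, 1)).transpose(1, 2)
print(out.shape)  # torch.Size([2, 512, 32]), i.e. (b, d, 32)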

View File

@@ -16,8 +16,10 @@ from einops.layers.torch import Rearrange
 
+
 def exists(val):
     return val is not None
 
+
 def once(fn):
     called = False
 
     @wraps(fn)
     def inner(x):
         nonlocal called
@@ -25,19 +27,17 @@ def once(fn):
         if called:
             return
         called = True
         return fn(x)
+
     return inner
 
+
 print_once = once(print)
 
 # main class
 
+
 class Attend(nn.Module):
-    def __init__(
-        self,
-        dropout = 0.,
-        causal = False,
-        use_flash = False
-    ):
+    def __init__(self, dropout=0.0, causal=False, use_flash=False):
         super().__init__()
         self.dropout = dropout
         self.attn_dropout = nn.Dropout(dropout)
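
The once wrapper being compacted here is the usual run-once guard and behaves identically after the restyle; a standalone demo:

from functools import wraps

def once(fn):
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)

    return inner

print_once = once(print)
print_once("shown once")  # prints
print_once("shown once")  # second call is silently skipped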
@@ -46,23 +46,25 @@ class Attend(nn.Module):
         self.register_buffer("mask", None, persistent=False)
 
         self.use_flash = use_flash
-        assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+        assert not (
+            use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
+        ), "in order to use flash attention, you must be using pytorch 2.0 or above"
 
         # determine efficient attention configs for cuda and cpu
-        self.config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+        self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
         self.cpu_config = self.config(True, True, True)
         self.cuda_config = None
 
         if not torch.cuda.is_available() or not use_flash:
             return
 
-        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
 
         if device_properties.major == 8 and device_properties.minor == 0:
-            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
+            print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
             self.cuda_config = self.config(True, False, False)
         else:
-            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
+            print_once("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda")
             self.cuda_config = self.config(False, True, True)
 
     def get_mask(self, n, device):
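
The restyled capability checks reduce to two runnable probes, sketched here with the same packaging and torch calls:

import torch
from packaging import version

flash_ok = version.parse(torch.__version__) >= version.parse("2.0.0")
print(flash_ok)
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(torch.device("cuda"))
    print(props.major, props.minor)  # (8, 0) is the sm_80 signature the A100 branch keys on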
@@ -73,23 +75,23 @@ class Attend(nn.Module):
         self.register_buffer("mask", mask, persistent=False)
         return mask
 
-    def flash_attn(self, q, k, v, mask = None):
+    def flash_attn(self, q, k, v, mask=None):
         _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
 
         # Recommended for multi-query single-key-value attention by Tri Dao
         # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
         if k.ndim == 3:
-            k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
+            k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
 
         if v.ndim == 3:
-            v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
+            v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
 
         # Check if mask exists and expand to compatible shape
         # The mask is B L, so it would have to be expanded to B H N L
 
         if exists(mask):
-            mask = rearrange(mask, 'b j -> b 1 1 j')
+            mask = rearrange(mask, "b j -> b 1 1 j")
             mask = mask.expand(-1, heads, q_len, -1)
 
         # Check if there is a compatible device for flash attention
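
The mask expansion quoted above is easy to sanity-check in isolation with toy sizes:

import torch
from einops import rearrange

mask = torch.ones(2, 7, dtype=torch.bool)  # (b, j) key padding mask
mask = rearrange(mask, "b j -> b 1 1 j")   # (b, 1, 1, j)
mask = mask.expand(-1, 8, 5, -1)           # (b, heads, q_len, j)
print(mask.shape)  # torch.Size([2, 8, 5, 7])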
@@ -100,15 +102,12 @@ class Attend(nn.Module):
         with torch.backends.cuda.sdp_kernel(**config._asdict()):
             out = F.scaled_dot_product_attention(
-                q, k, v,
-                attn_mask = mask,
-                dropout_p = self.dropout if self.training else 0.,
-                is_causal = self.causal
+                q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal
             )
 
         return out
 
-    def forward(self, q, k, v, mask = None):
+    def forward(self, q, k, v, mask=None):
         """
         einstein notation
         b - batch
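
The collapsed call is plain torch 2.x scaled dot-product attention; a minimal standalone invocation with dummy tensors:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (b, h, n, d)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 16, 64])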
@@ -122,9 +121,9 @@ class Attend(nn.Module):
         scale = q.shape[-1] ** -0.5
 
         if self.use_flash:
-            return self.flash_attn(q, k, v, mask = mask)
+            return self.flash_attn(q, k, v, mask=mask)
 
-        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
+        kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
 
         # similarity
@@ -133,7 +132,7 @@ class Attend(nn.Module):
         # key padding mask
 
         if exists(mask):
-            mask = rearrange(mask, 'b j -> b 1 1 j')
+            mask = rearrange(mask, "b j -> b 1 1 j")
             sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
 
         # causal mask
@@ -153,6 +152,7 @@ class Attend(nn.Module):
         return out
 
+
 def Sequential(*mods):
     return nn.Sequential(*filter(exists, mods))
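
The Sequential helper filters out None entries, so callers can assemble layer stacks with optional pieces; a quick check:

import torch.nn as nn

def exists(val):
    return val is not None

def Sequential(*mods):
    return nn.Sequential(*filter(exists, mods))

model = Sequential(nn.Linear(4, 4), None, nn.ReLU())
print(len(model))  # 2 - the None slot is dropped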

View File

@@ -161,9 +161,9 @@ _abbreviations = {
     "hu": [
         (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         for x in [
-            ("dr", "doktor"), # doctor
-            ("b", "bácsi"), # Mr.
-            ("nőv", "nővér"), # nurse
+            ("dr", "doktor"),  # doctor
+            ("b", "bácsi"),  # Mr.
+            ("nőv", "nővér"),  # nurse
             # Add other Hungarian abbreviations here if needed.
         ]
     ],
@@ -171,9 +171,8 @@ _abbreviations = {
     "ko": [
         (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         for x in [
             # Korean doesn't typically use abbreviations in the same way as Latin-based scripts.
-
         ]
-    ]
+    ],
 }
@@ -354,7 +353,7 @@ _symbols_multilingual = {
             ("#", " kettőskereszt "),
             ("$", " dollár "),
             ("£", " font "),
-            ("°", " fok ")
+            ("°", " fok "),
         ]
     ],
     "ko": [
@@ -367,9 +366,9 @@ _symbols_multilingual = {
             ("#", " 번호 "),
             ("$", " 달러 "),
             ("£", " 파운드 "),
-            ("°", "")
+            ("°", ""),
         ]
-    ]
+    ],
 }
@@ -463,14 +462,9 @@ def _expand_ordinal(m, lang="en"):
 def _expand_number(m, lang="en"):
     return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
 
-<<<<<<< HEAD
 def expand_numbers_multilingual(text, lang="en"):
-    if lang == "zh-cn":
-=======
-def expand_numbers_multilingual(text, lang='en'):
-    if lang == "zh" or lang == "zh-cn":
->>>>>>> Update model entry
+    if lang == "zh" or lang == "zh-cn":
         text = zh_num2words()(text)
     else:
         if lang in ["en", "ru"]:
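
The resolved branch keeps the cs-to-cz remap in _expand_number because num2words names Czech "cz"; a two-line check, assuming num2words is installed:

from num2words import num2words

print(num2words(12, lang="en"))  # twelve
print(num2words(50, lang="cz"))  # padesát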
@@ -521,24 +515,15 @@ def basic_cleaners(text):
 def chinese_transliterate(text):
     return "".join(
-        p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)
+        [p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
     )
 
 def japanese_cleaners(text, katsu):
     text = katsu.romaji(text)
     text = lowercase(text)
     return text
 
-<<<<<<< HEAD
-class VoiceBpeTokenizer:
-    def __init__(self, vocab_file=None, preprocess=None):
-        self.tokenizer = None
-        self.katsu = None
-=======
->>>>>>> Update model entry
 
 def korean_cleaners(text):
     r = Transliter(academic)
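
The pinyin change only wraps the generator in a list; the call itself is unchanged and easy to try, assuming pypinyin is installed:

import pypinyin

out = "".join(
    [p[0] for p in pypinyin.pinyin("你好", style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
)
print(out)  # ni3hao3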
@@ -559,32 +544,14 @@ def preprocess_text(txt, lang):
     return txt
 
-<<<<<<< HEAD
-    def preprocess_text(self, txt, lang):
-        if lang in ["en", "es", "fr", "de", "pt", "it", "pl", "ar", "cs", "ru", "nl", "tr", "zh-cn"]:
-            txt = multilingual_cleaners(txt, lang)
-            if lang == "zh-cn":
-                txt = chinese_transliterate(txt)
-        elif lang == "ja":
-            if self.katsu is None:
-                import cutlet
-                self.katsu = cutlet.Cutlet()
-            txt = japanese_cleaners(txt, self.katsu)
-        else:
-            raise NotImplementedError()
-        return txt
-=======
-DEFAULT_VOCAB_FILE = os.path.join(
-    os.path.dirname(os.path.realpath(__file__)), "../data/tokenizer.json"
-)
+DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../data/tokenizer.json")
 
 class VoiceBpeTokenizer:
     def __init__(self, vocab_file=None):
         self.tokenizer = None
         if vocab_file is not None:
             self.tokenizer = Tokenizer.from_file(vocab_file)
->>>>>>> Update model entry
 
     def encode(self, txt, lang):
         txt = preprocess_text(txt, lang)
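
Construction of the resolved VoiceBpeTokenizer is unchanged: it only wraps a HuggingFace tokenizers object. A hedged usage sketch (the vocab path here is hypothetical; the real vocab.json ships with the model download):

from tokenizers import Tokenizer

tok = Tokenizer.from_file("vocab.json")
ids = tok.encode("hello world").ids
print(ids[:5])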
@@ -611,139 +578,145 @@ class VoiceBpeTokenizer:
 def test_expand_numbers_multilingual():
     test_cases = [
         # English
-        ("In 12.5 seconds.", 'In twelve point five seconds.', 'en'),
-        ("There were 50 soldiers.", 'There were fifty soldiers.', 'en'),
-        ("This is a 1st test", 'This is a first test', 'en'),
-        ("That will be $20 sir.", 'That will be twenty dollars sir.', 'en'),
-        ("That will be 20€ sir.", 'That will be twenty euro sir.', 'en'),
-        ("That will be 20.15€ sir.", 'That will be twenty euro, fifteen cents sir.', 'en'),
-        ("That's 100,000.5.", 'That\'s one hundred thousand point five.', 'en'),
+        ("In 12.5 seconds.", "In twelve point five seconds.", "en"),
+        ("There were 50 soldiers.", "There were fifty soldiers.", "en"),
+        ("This is a 1st test", "This is a first test", "en"),
+        ("That will be $20 sir.", "That will be twenty dollars sir.", "en"),
+        ("That will be 20€ sir.", "That will be twenty euro sir.", "en"),
+        ("That will be 20.15€ sir.", "That will be twenty euro, fifteen cents sir.", "en"),
+        ("That's 100,000.5.", "That's one hundred thousand point five.", "en"),
         # French
-        ("En 12,5 secondes.", 'En douze virgule cinq secondes.', 'fr'),
-        ("Il y avait 50 soldats.", 'Il y avait cinquante soldats.', 'fr'),
-        ("Ceci est un 1er test", 'Ceci est un premier test', 'fr'),
-        ("Cela vous fera $20 monsieur.", 'Cela vous fera vingt dollars monsieur.', 'fr'),
-        ("Cela vous fera 20€ monsieur.", 'Cela vous fera vingt euros monsieur.', 'fr'),
-        ("Cela vous fera 20,15€ monsieur.", 'Cela vous fera vingt euros et quinze centimes monsieur.', 'fr'),
-        ("Ce sera 100.000,5.", 'Ce sera cent mille virgule cinq.', 'fr'),
+        ("En 12,5 secondes.", "En douze virgule cinq secondes.", "fr"),
+        ("Il y avait 50 soldats.", "Il y avait cinquante soldats.", "fr"),
+        ("Ceci est un 1er test", "Ceci est un premier test", "fr"),
+        ("Cela vous fera $20 monsieur.", "Cela vous fera vingt dollars monsieur.", "fr"),
+        ("Cela vous fera 20€ monsieur.", "Cela vous fera vingt euros monsieur.", "fr"),
+        ("Cela vous fera 20,15€ monsieur.", "Cela vous fera vingt euros et quinze centimes monsieur.", "fr"),
+        ("Ce sera 100.000,5.", "Ce sera cent mille virgule cinq.", "fr"),
         # German
-        ("In 12,5 Sekunden.", 'In zwölf Komma fünf Sekunden.', 'de'),
-        ("Es gab 50 Soldaten.", 'Es gab fünfzig Soldaten.', 'de'),
-        ("Dies ist ein 1. Test", 'Dies ist ein erste Test', 'de'),  # Issue with gender
-        ("Das macht $20 Herr.", 'Das macht zwanzig Dollar Herr.', 'de'),
-        ("Das macht 20€ Herr.", 'Das macht zwanzig Euro Herr.', 'de'),
-        ("Das macht 20,15€ Herr.", 'Das macht zwanzig Euro und fünfzehn Cent Herr.', 'de'),
+        ("In 12,5 Sekunden.", "In zwölf Komma fünf Sekunden.", "de"),
+        ("Es gab 50 Soldaten.", "Es gab fünfzig Soldaten.", "de"),
+        ("Dies ist ein 1. Test", "Dies ist ein erste Test", "de"),  # Issue with gender
+        ("Das macht $20 Herr.", "Das macht zwanzig Dollar Herr.", "de"),
+        ("Das macht 20€ Herr.", "Das macht zwanzig Euro Herr.", "de"),
+        ("Das macht 20,15€ Herr.", "Das macht zwanzig Euro und fünfzehn Cent Herr.", "de"),
         # Spanish
-        ("En 12,5 segundos.", 'En doce punto cinco segundos.', 'es'),
-        ("Había 50 soldados.", 'Había cincuenta soldados.', 'es'),
-        ("Este es un 1er test", 'Este es un primero test', 'es'),
-        ("Eso le costará $20 señor.", 'Eso le costará veinte dólares señor.', 'es'),
-        ("Eso le costará 20€ señor.", 'Eso le costará veinte euros señor.', 'es'),
-        ("Eso le costará 20,15€ señor.", 'Eso le costará veinte euros con quince céntimos señor.', 'es'),
+        ("En 12,5 segundos.", "En doce punto cinco segundos.", "es"),
+        ("Había 50 soldados.", "Había cincuenta soldados.", "es"),
+        ("Este es un 1er test", "Este es un primero test", "es"),
+        ("Eso le costará $20 señor.", "Eso le costará veinte dólares señor.", "es"),
+        ("Eso le costará 20€ señor.", "Eso le costará veinte euros señor.", "es"),
+        ("Eso le costará 20,15€ señor.", "Eso le costará veinte euros con quince céntimos señor.", "es"),
         # Italian
-        ("In 12,5 secondi.", 'In dodici virgola cinque secondi.', 'it'),
-        ("C'erano 50 soldati.", "C'erano cinquanta soldati.", 'it'),
-        ("Questo è un 1° test", 'Questo è un primo test', 'it'),
-        ("Ti costerà $20 signore.", 'Ti costerà venti dollari signore.', 'it'),
-        ("Ti costerà 20€ signore.", 'Ti costerà venti euro signore.', 'it'),
-        ("Ti costerà 20,15€ signore.", 'Ti costerà venti euro e quindici centesimi signore.', 'it'),
+        ("In 12,5 secondi.", "In dodici virgola cinque secondi.", "it"),
+        ("C'erano 50 soldati.", "C'erano cinquanta soldati.", "it"),
+        ("Questo è un 1° test", "Questo è un primo test", "it"),
+        ("Ti costerà $20 signore.", "Ti costerà venti dollari signore.", "it"),
+        ("Ti costerà 20€ signore.", "Ti costerà venti euro signore.", "it"),
+        ("Ti costerà 20,15€ signore.", "Ti costerà venti euro e quindici centesimi signore.", "it"),
         # Portuguese
-        ("Em 12,5 segundos.", 'Em doze vírgula cinco segundos.', 'pt'),
-        ("Havia 50 soldados.", 'Havia cinquenta soldados.', 'pt'),
-        ("Este é um 1º teste", 'Este é um primeiro teste', 'pt'),
-        ("Isso custará $20 senhor.", 'Isso custará vinte dólares senhor.', 'pt'),
-        ("Isso custará 20€ senhor.", 'Isso custará vinte euros senhor.', 'pt'),
-        ("Isso custará 20,15€ senhor.", 'Isso custará vinte euros e quinze cêntimos senhor.', 'pt'),  # "cêntimos" should be "centavos" num2words issue
+        ("Em 12,5 segundos.", "Em doze vírgula cinco segundos.", "pt"),
+        ("Havia 50 soldados.", "Havia cinquenta soldados.", "pt"),
+        ("Este é um 1º teste", "Este é um primeiro teste", "pt"),
+        ("Isso custará $20 senhor.", "Isso custará vinte dólares senhor.", "pt"),
+        ("Isso custará 20€ senhor.", "Isso custará vinte euros senhor.", "pt"),
+        (
+            "Isso custará 20,15€ senhor.",
+            "Isso custará vinte euros e quinze cêntimos senhor.",
+            "pt",
+        ),  # "cêntimos" should be "centavos" num2words issue
         # Polish
-        ("W 12,5 sekundy.", 'W dwanaście przecinek pięć sekundy.', 'pl'),
-        ("Było 50 żołnierzy.", 'Było pięćdziesiąt żołnierzy.', 'pl'),
-        ("To będzie kosztować 20€ panie.", 'To będzie kosztować dwadzieścia euro panie.', 'pl'),
-        ("To będzie kosztować 20,15€ panie.", 'To będzie kosztować dwadzieścia euro, piętnaście centów panie.', 'pl'),
+        ("W 12,5 sekundy.", "W dwanaście przecinek pięć sekundy.", "pl"),
+        ("Było 50 żołnierzy.", "Było pięćdziesiąt żołnierzy.", "pl"),
+        ("To będzie kosztować 20€ panie.", "To będzie kosztować dwadzieścia euro panie.", "pl"),
+        ("To będzie kosztować 20,15€ panie.", "To będzie kosztować dwadzieścia euro, piętnaście centów panie.", "pl"),
         # Arabic
-        ("في الـ 12,5 ثانية.", 'في الـ اثنا عشر , خمسون ثانية.', 'ar'),
-        ("كان هناك 50 جنديًا.", 'كان هناك خمسون جنديًا.', 'ar'),
+        ("في الـ 12,5 ثانية.", "في الـ اثنا عشر , خمسون ثانية.", "ar"),
+        ("كان هناك 50 جنديًا.", "كان هناك خمسون جنديًا.", "ar"),
         # ("ستكون النتيجة $20 يا سيد.", 'ستكون النتيجة عشرون دولار يا سيد.', 'ar'), # $ and € are missing from num2words
         # ("ستكون النتيجة 20€ يا سيد.", 'ستكون النتيجة عشرون يورو يا سيد.', 'ar'),
         # Czech
-        ("Za 12,5 vteřiny.", 'Za dvanáct celá pět vteřiny.', 'cs'),
-        ("Bylo tam 50 vojáků.", 'Bylo tam padesát vojáků.', 'cs'),
-        ("To bude stát 20€ pane.", 'To bude stát dvacet euro pane.', 'cs'),
-        ("To bude 20.15€ pane.", 'To bude dvacet euro, patnáct centů pane.', 'cs'),
+        ("Za 12,5 vteřiny.", "Za dvanáct celá pět vteřiny.", "cs"),
+        ("Bylo tam 50 vojáků.", "Bylo tam padesát vojáků.", "cs"),
+        ("To bude stát 20€ pane.", "To bude stát dvacet euro pane.", "cs"),
+        ("To bude 20.15€ pane.", "To bude dvacet euro, patnáct centů pane.", "cs"),
         # Russian
-        ("Через 12.5 секунды.", 'Через двенадцать запятая пять секунды.', 'ru'),
-        ("Там было 50 солдат.", 'Там было пятьдесят солдат.', 'ru'),
-        ("Это будет 20.15€ сэр.", 'Это будет двадцать евро, пятнадцать центов сэр.', 'ru'),
-        ("Это будет стоить 20€ господин.", 'Это будет стоить двадцать евро господин.', 'ru'),
+        ("Через 12.5 секунды.", "Через двенадцать запятая пять секунды.", "ru"),
+        ("Там было 50 солдат.", "Там было пятьдесят солдат.", "ru"),
+        ("Это будет 20.15€ сэр.", "Это будет двадцать евро, пятнадцать центов сэр.", "ru"),
+        ("Это будет стоить 20€ господин.", "Это будет стоить двадцать евро господин.", "ru"),
         # Dutch
-        ("In 12,5 seconden.", 'In twaalf komma vijf seconden.', 'nl'),
-        ("Er waren 50 soldaten.", 'Er waren vijftig soldaten.', 'nl'),
-        ("Dat wordt dan $20 meneer.", 'Dat wordt dan twintig dollar meneer.', 'nl'),
-        ("Dat wordt dan 20€ meneer.", 'Dat wordt dan twintig euro meneer.', 'nl'),
+        ("In 12,5 seconden.", "In twaalf komma vijf seconden.", "nl"),
+        ("Er waren 50 soldaten.", "Er waren vijftig soldaten.", "nl"),
+        ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
+        ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
         # Chinese (Simplified)
-        ("在12.5秒内", '在十二点五秒内', 'zh'),
-        ("有50名士兵", '有五十名士兵', 'zh'),
+        ("在12.5秒内", "在十二点五秒内", "zh"),
+        ("有50名士兵", "有五十名士兵", "zh"),
         # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
         # ("那将是20€先生", '那将是二十欧元先生', 'zh'),
         # Turkish
         # ("12,5 saniye içinde.", 'On iki virgül beş saniye içinde.', 'tr'), # decimal doesn't work for TR
-        ("50 asker vardı.", 'elli asker vardı.', 'tr'),
-        ("Bu 1. test", 'Bu birinci test', 'tr'),
+        ("50 asker vardı.", "elli asker vardı.", "tr"),
+        ("Bu 1. test", "Bu birinci test", "tr"),
         # ("Bu 100.000,5.", 'Bu yüz bin virgül beş.', 'tr'),
         # Hungarian
-        ("12,5 másodperc alatt.", 'tizenkettő egész öt tized másodperc alatt.', 'hu'),
-        ("50 katona volt.", 'ötven katona volt.', 'hu'),
-        ("Ez az 1. teszt", 'Ez az első teszt', 'hu'),
+        ("12,5 másodperc alatt.", "tizenkettő egész öt tized másodperc alatt.", "hu"),
+        ("50 katona volt.", "ötven katona volt.", "hu"),
+        ("Ez az 1. teszt", "Ez az első teszt", "hu"),
         # Korean
-        ("12.5 초 안에.", '십이 점 다섯 초 안에.', 'ko'),
-        ("50 명의 병사가 있었다.", '오십 명의 병사가 있었다.', 'ko'),
-        ("이것은 1 번째 테스트입니다", '이것은 첫 번째 테스트입니다', 'ko'),
+        ("12.5 초 안에.", "십이 점 다섯 초 안에.", "ko"),
+        ("50 명의 병사가 있었다.", "오십 명의 병사가 있었다.", "ko"),
+        ("이것은 1 번째 테스트입니다", "이것은 첫 번째 테스트입니다", "ko"),
     ]
 
     for a, b, lang in test_cases:
         out = expand_numbers_multilingual(a, lang=lang)
         assert out == b, f"'{out}' vs '{b}'"
 
 
 def test_abbreviations_multilingual():
     test_cases = [
         # English
-        ("Hello Mr. Smith.", 'Hello mister Smith.', 'en'),
-        ("Dr. Jones is here.", 'doctor Jones is here.', 'en'),
+        ("Hello Mr. Smith.", "Hello mister Smith.", "en"),
+        ("Dr. Jones is here.", "doctor Jones is here.", "en"),
         # Spanish
-        ("Hola Sr. Garcia.", 'Hola señor Garcia.', 'es'),
-        ("La Dra. Martinez es muy buena.", 'La doctora Martinez es muy buena.', 'es'),
+        ("Hola Sr. Garcia.", "Hola señor Garcia.", "es"),
+        ("La Dra. Martinez es muy buena.", "La doctora Martinez es muy buena.", "es"),
         # French
-        ("Bonjour Mr. Dupond.", 'Bonjour monsieur Dupond.', 'fr'),
-        ("Mme. Moreau est absente aujourd'hui.", 'madame Moreau est absente aujourd\'hui.', 'fr'),
+        ("Bonjour Mr. Dupond.", "Bonjour monsieur Dupond.", "fr"),
+        ("Mme. Moreau est absente aujourd'hui.", "madame Moreau est absente aujourd'hui.", "fr"),
         # German
-        ("Frau Dr. Müller ist sehr klug.", 'Frau doktor Müller ist sehr klug.', 'de'),
+        ("Frau Dr. Müller ist sehr klug.", "Frau doktor Müller ist sehr klug.", "de"),
         # Portuguese
-        ("Olá Sr. Silva.", 'Olá senhor Silva.', 'pt'),
-        ("Dra. Costa, você está disponível?", 'doutora Costa, você está disponível?', 'pt'),
+        ("Olá Sr. Silva.", "Olá senhor Silva.", "pt"),
+        ("Dra. Costa, você está disponível?", "doutora Costa, você está disponível?", "pt"),
         # Italian
-        ("Buongiorno, Sig. Rossi.", 'Buongiorno, signore Rossi.', 'it'),
-        #("Sig.ra Bianchi, posso aiutarti?", 'signora Bianchi, posso aiutarti?', 'it'), # Issue with matching that pattern
+        ("Buongiorno, Sig. Rossi.", "Buongiorno, signore Rossi.", "it"),
+        # ("Sig.ra Bianchi, posso aiutarti?", 'signora Bianchi, posso aiutarti?', 'it'), # Issue with matching that pattern
         # Polish
-        ("Dzień dobry, P. Kowalski.", 'Dzień dobry, pani Kowalski.', 'pl'),
-        ("M. Nowak, czy mogę zadać pytanie?", 'pan Nowak, czy mogę zadać pytanie?', 'pl'),
+        ("Dzień dobry, P. Kowalski.", "Dzień dobry, pani Kowalski.", "pl"),
+        ("M. Nowak, czy mogę zadać pytanie?", "pan Nowak, czy mogę zadać pytanie?", "pl"),
         # Czech
-        ("P. Novák", "pan Novák", 'cs'),
-        ("Dr. Vojtěch", "doktor Vojtěch", 'cs'),
+        ("P. Novák", "pan Novák", "cs"),
+        ("Dr. Vojtěch", "doktor Vojtěch", "cs"),
         # Dutch
-        ("Dhr. Jansen", "de heer Jansen", 'nl'),
-        ("Mevr. de Vries", "mevrouw de Vries", 'nl'),
+        ("Dhr. Jansen", "de heer Jansen", "nl"),
+        ("Mevr. de Vries", "mevrouw de Vries", "nl"),
         # Russian
-        ("Здравствуйте Г-н Иванов.", "Здравствуйте господин Иванов.", 'ru'),
-        ("Д-р Смирнов здесь, чтобы увидеть вас.", "доктор Смирнов здесь, чтобы увидеть вас.", 'ru'),
+        ("Здравствуйте Г-н Иванов.", "Здравствуйте господин Иванов.", "ru"),
+        ("Д-р Смирнов здесь, чтобы увидеть вас.", "доктор Смирнов здесь, чтобы увидеть вас.", "ru"),
         # Turkish
-        ("Merhaba B. Yılmaz.", "Merhaba bay Yılmaz.", 'tr'),
-        ("Dr. Ayşe burada.", "doktor Ayşe burada.", 'tr'),
+        ("Merhaba B. Yılmaz.", "Merhaba bay Yılmaz.", "tr"),
+        ("Dr. Ayşe burada.", "doktor Ayşe burada.", "tr"),
         # Hungarian
-        ("Dr. Szabó itt van.", "doktor Szabó itt van.", 'hu'),
+        ("Dr. Szabó itt van.", "doktor Szabó itt van.", "hu"),
     ]
 
     for a, b, lang in test_cases:
         out = expand_abbreviations_multilingual(a, lang=lang)
         assert out == b, f"'{out}' vs '{b}'"
 
 
 def test_symbols_multilingual():
     test_cases = [
         ("I have 14% battery", "I have 14 percent battery", "en"),
@@ -763,14 +736,15 @@ def test_symbols_multilingual():
         ("我的电量为 14%", "我的电量为 14 百分之", "zh"),
         ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
         ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
-        ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko")
+        ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
     ]
 
     for a, b, lang in test_cases:
         out = expand_symbols_multilingual(a, lang=lang)
         assert out == b, f"'{out}' vs '{b}'"
 
+
 if __name__ == "__main__":
     test_expand_numbers_multilingual()
     test_abbreviations_multilingual()
     test_symbols_multilingual()

View File

@@ -149,7 +149,11 @@ class XTTSDataset(torch.utils.data.Dataset):
             # if use masking do not use cond_len
             cond_len = torch.nan
         else:
-            ref_sample = sample["reference_path"] if "reference_path" in sample and sample["reference_path"] is not None else audiopath
+            ref_sample = (
+                sample["reference_path"]
+                if "reference_path" in sample and sample["reference_path"] is not None
+                else audiopath
+            )
             cond, cond_len, _ = get_prompt_slice(
                 ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
             )
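
The parenthesized conditional above could also be written with a dict.get fallback; a hypothetical tighter variant (note it would also fall back on an empty-string path, unlike the explicit None check, so it is not a drop-in replacement):

sample = {"reference_path": None}  # made-up record for the sketch
audiopath = "clip_001.wav"
ref_sample = sample.get("reference_path") or audiopath
print(ref_sample)  # clip_001.wav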
@@ -210,7 +214,9 @@ class XTTSDataset(torch.utils.data.Dataset):
             "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
             "filenames": audiopath,
             "conditioning": cond.unsqueeze(1),
-            "cond_lens": torch.tensor(cond_len, dtype=torch.long) if cond_len is not torch.nan else torch.tensor([cond_len]),
+            "cond_lens": torch.tensor(cond_len, dtype=torch.long)
+            if cond_len is not torch.nan
+            else torch.tensor([cond_len]),
             "cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_idxs]),
         }
         return res
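
The `is not torch.nan` test works as a sentinel check because torch.nan is a single module-level float object, while NaN never compares equal to itself; a quick demonstration:

import torch

cond_len = torch.nan
print(cond_len == torch.nan)      # False - NaN never equals anything, even itself
print(cond_len is torch.nan)      # True - same object, so identity works as a flag
print(float("nan") is torch.nan)  # False - a freshly constructed NaN is a different object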

View File

@@ -213,7 +213,13 @@ class GPTTrainer(BaseTTS):
             cond_lens: long tensor, (b,)
         """
         losses = self.xtts.gpt(
-            text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs, cond_lens=cond_lens,
+            text_inputs,
+            text_lengths,
+            audio_codes,
+            wav_lengths,
+            cond_mels=cond_mels,
+            cond_idxs=cond_idxs,
+            cond_lens=cond_lens,
         )
         return losses
@@ -227,7 +233,12 @@ class GPTTrainer(BaseTTS):
             print(" | > Synthesizing test sentences.")
             for idx, s_info in enumerate(self.config.test_sentences):
                 wav = self.xtts.synthesize(
-                    s_info["text"], self.config, s_info["speaker_wav"], s_info["language"], gpt_cond_len=3, decoder="ne_hifigan"
+                    s_info["text"],
+                    self.config,
+                    s_info["speaker_wav"],
+                    s_info["language"],
+                    gpt_cond_len=3,
+                    decoder="ne_hifigan",
                 )["wav"]
                 test_audios["{}-audio".format(idx)] = wav
@@ -295,7 +306,9 @@ class GPTTrainer(BaseTTS):
         cond_idxs = batch["cond_idxs"]
         cond_lens = batch["cond_lens"]
 
-        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens)
+        loss_text, loss_mel, _ = self.forward(
+            text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens
+        )
         loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
         loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
         loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]

File diff suppressed because it is too large.

View File

@@ -381,7 +381,8 @@ class Xtts(BaseTTS):
         audio_22k = torchaudio.functional.resample(audio, sr, 22050)
         audio_22k = audio_22k[:, : 22050 * length]
         if self.args.gpt_use_perceiver_resampler:
-            mel = wav_to_mel_cloning(audio_22k,
+            mel = wav_to_mel_cloning(
+                audio_22k,
                 mel_norms=self.mel_stats.cpu(),
                 n_fft=2048,
                 hop_length=256,
@@ -391,10 +392,11 @@ class Xtts(BaseTTS):
                 sample_rate=22050,
                 f_min=0,
                 f_max=8000,
-                n_mels=80
+                n_mels=80,
             )
         else:
-            mel = wav_to_mel_cloning(audio_22k,
+            mel = wav_to_mel_cloning(
+                audio_22k,
                 mel_norms=self.mel_stats.cpu(),
                 n_fft=4096,
                 hop_length=1024,
@@ -404,7 +406,7 @@ class Xtts(BaseTTS):
                 sample_rate=22050,
                 f_min=0,
                 f_max=8000,
-                n_mels=80
+                n_mels=80,
             )
         cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)
@@ -598,7 +600,10 @@ class Xtts(BaseTTS):
         Sample rate is 24kHz.
         """
         (gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
-            audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len, max_ref_length=max_ref_len, sound_norm_refs=sound_norm_refs
+            audio_path=ref_audio_path,
+            gpt_cond_len=gpt_cond_len,
+            max_ref_length=max_ref_len,
+            sound_norm_refs=sound_norm_refs,
         )
 
         return self.inference(
@ -728,7 +733,12 @@ class Xtts(BaseTTS):
) )
wav = self.vocoder.inference(mel) wav = self.vocoder.inference(mel)
return {"wav": wav.cpu().numpy().squeeze(), "gpt_latents": gpt_latents, "speaker_embedding": speaker_embedding, "diffusion_conditioning": diffusion_conditioning} return {
"wav": wav.cpu().numpy().squeeze(),
"gpt_latents": gpt_latents,
"speaker_embedding": speaker_embedding,
"diffusion_conditioning": diffusion_conditioning,
}
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
"""Handle chunk formatting in streaming mode""" """Handle chunk formatting in streaming mode"""

View File

@@ -61,13 +61,15 @@ TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v2.0/voca
 XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v2.0/model.pth"
 
 # XTTS transfer learning parameters: you need to provide the paths of the XTTS checkpoint you want to fine-tune.
 TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1])  # vocab.json file
 XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1])  # model.pth file
 
 # download XTTS v2.0 files if needed
 if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
     print(" > Downloading XTTS v2.0 files!")
-    ModelManager._download_model_files([TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
+    ModelManager._download_model_files(
+        [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+    )
 
 # Training sentences generations
@@ -92,7 +94,7 @@ def main():
         gpt_num_audio_tokens=8194,
         gpt_start_audio_token=8192,
         gpt_stop_audio_token=8193,
-        use_ne_hifigan=True, # if it is true it will keep the non-enhanced keys on the output checkpoint
+        use_ne_hifigan=True,  # if true, keep the non-enhanced keys in the output checkpoint
         gpt_use_masking_gt_prompt_approach=True,
         gpt_use_perceiver_resampler=True,
     )

View File

@@ -22,4 +22,8 @@ def test_synthesize():
     )
 
     # test pipe_out command
+<<<<<<< HEAD
     run_cli(f'tts --text "test." --pipe_out --out_path "{output_path}" | aplay')
+=======
+    run_cli('tts --text "test." --pipe_out ' f'--out_path "{output_path}" | aplay')
+>>>>>>> Make style