* Add support for ne_hifigan

* Update model.json

* Update hash

* Fix model loading

* Enhance text_normalization

* Add xtts to zoo test exception

* Add model hash check

* Add get_number_tokens
Julian Weber, 2023-10-20 16:02:08 +02:00, committed by GitHub
parent 747f688dc3
commit cf97116185
7 changed files with 1726 additions and 193 deletions
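
Taken together, the changes make the new decoder opt-in: use_ne_hifigan controls whether its weights are built and loaded, and the new decoder argument selects the branch per call. A minimal usage sketch (not part of the diff), assuming a locally downloaded XTTS v1.1 checkpoint, placeholder paths, and that extra keyword arguments such as decoder are forwarded from synthesize() to the inference call:

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

CHECKPOINT_DIR = "/path/to/xtts_v1.1/"  # placeholder: directory holding model.pth, config.json, vocab.json

config = XttsConfig()
config.load_json(CHECKPOINT_DIR + "config.json")
config.model_args.use_ne_hifigan = True  # build the new decoder alongside the default hifigan one

model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=CHECKPOINT_DIR, eval=True)

# decoder="ne_hifigan" is the new per-call switch; "hifigan" remains the default.
outputs = model.synthesize(
    "Hello world.",
    config,
    speaker_wav="reference.wav",  # placeholder reference clip for voice cloning
    language="en",
    decoder="ne_hifigan",
)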

@@ -15,6 +15,21 @@
"contact": "info@coqui.ai",
"tos_required": true
},
"xtts_v1.1": {
"description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.",
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/hash.md5"
],
"model_hash": "10163afc541dc86801b33d1f3217b456",
"default_vocoder": null,
"commit": "82910a63",
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": true
},
"your_tts": {
"description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.10.1_models/tts_models--multilingual--multi-dataset--your_tts.zip",

@@ -1,206 +1,469 @@
import json
import os
import re
import json
import inflect
import pandas as pd
import pypinyin
import torch
from num2words import num2words
from tokenizers import Tokenizer
from unidecode import unidecode
from TTS.tts.utils.text.cleaners import english_cleaners
import pypinyin
import cutlet
from num2words import num2words
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
def _remove_commas(m):
return m.group(1).replace(",", "")
def _expand_decimal_point(m):
return m.group(1).replace(".", " point ")
def _expand_dollars(m):
match = m.group(1)
parts = match.split(".")
if len(parts) > 2:
return match + " dollars" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return "zero dollars"
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return "two thousand"
elif num > 2000 and num < 2010:
return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword="")
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r"\1 pounds", text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
_abbreviations = {
"en": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
],
"es": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("sra", "señora"),
("sr", "señor"),
("dr", "doctor"),
("dra", "doctora"),
("st", "santo"),
("co", "compañía"),
("jr", "junior"),
("ltd", "limitada"),
]
],
"fr": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mme", "madame"),
("mr", "monsieur"),
("dr", "docteur"),
("st", "saint"),
("co", "compagnie"),
("jr", "junior"),
("ltd", "limitée"),
]
],
"de": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("fr", "frau"),
("dr", "doktor"),
("st", "sankt"),
("co", "firma"),
("jr", "junior"),
]
],
"pt": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("sra", "senhora"),
("sr", "senhor"),
("dr", "doutor"),
("dra", "doutora"),
("st", "santo"),
("co", "companhia"),
("jr", "júnior"),
("ltd", "limitada"),
]
],
"it": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
#("sig.ra", "signora"),
("sig", "signore"),
("dr", "dottore"),
("st", "santo"),
("co", "compagnia"),
("jr", "junior"),
("ltd", "limitata"),
]
],
"pl": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("p", "pani"),
("m", "pan"),
("dr", "doktor"),
("sw", "święty"),
("jr", "junior"),
]
],
"ar": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
# There are not many common abbreviations in Arabic, unlike in English.
]
],
"zh-cn": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
]
],
"cs": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("dr", "doktor"), # doctor
("ing", "inženýr"), # engineer
("p", "pan"), # Could also map to pani for woman but no easy way to do it
# Other abbreviations would be specialized and not as common.
]
],
"ru": [
(re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
for x in [
("г-жа", "госпожа"), # Mrs.
("г", "господин"), # Mr.
("д-р", "доктор"), # doctor
# Other abbreviations are less common or specialized.
]
],
"nl": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("dhr", "de heer"), # Mr.
("mevr", "mevrouw"), # Mrs.
("dr", "dokter"), # doctor
("jhr", "jonkheer"), # young lord or nobleman
# Dutch uses more abbreviations, but these are the most common ones.
]
],
"tr": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("b", "bay"), # Mr.
("byk", "büyük"), # büyük
("dr", "doktor"), # doctor
# Add other Turkish abbreviations here if needed.
]
],
}
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
def expand_abbreviations_multilingual(text, lang='en'):
for regex, replacement in _abbreviations[lang]:
text = re.sub(regex, replacement, text)
return text
_symbols_multilingual = {
'en': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " and "),
("@", " at "),
("%", " percent "),
("#", " hash "),
("$", " dollar "),
("£", " pound "),
("°", " degree ")
]
],
'es': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " y "),
("@", " arroba "),
("%", " por ciento "),
("#", " numeral "),
("$", " dolar "),
("£", " libra "),
("°", " grados ")
]
],
'fr': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " et "),
("@", " arobase "),
("%", " pour cent "),
("#", " dièse "),
("$", " dollar "),
("£", " livre "),
("°", " degrés ")
]
],
'de': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " und "),
("@", " at "),
("%", " prozent "),
("#", " raute "),
("$", " dollar "),
("£", " pfund "),
("°", " grad ")
]
],
'pt': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " e "),
("@", " arroba "),
("%", " por cento "),
("#", " cardinal "),
("$", " dólar "),
("£", " libra "),
("°", " graus ")
]
],
'it': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " e "),
("@", " chiocciola "),
("%", " per cento "),
("#", " cancelletto "),
("$", " dollaro "),
("£", " sterlina "),
("°", " gradi ")
]
],
'pl': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " i "),
("@", " małpa "),
("%", " procent "),
("#", " krzyżyk "),
("$", " dolar "),
("£", " funt "),
("°", " stopnie ")
]
],
"ar": [
# Arabic
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " و "),
("@", " على "),
("%", " في المئة "),
("#", " رقم "),
("$", " دولار "),
("£", " جنيه "),
("°", " درجة ")
]
],
"zh-cn": [
# Chinese
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", ""),
("@", ""),
("%", " 百分之 "),
("#", ""),
("$", " 美元 "),
("£", " 英镑 "),
("°", "")
]
],
"cs": [
# Czech
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " a "),
("@", " na "),
("%", " procento "),
("#", " křížek "),
("$", " dolar "),
("£", " libra "),
("°", " stupně ")
]
],
"ru": [
# Russian
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " и "),
("@", " собака "),
("%", " процентов "),
("#", " номер "),
("$", " доллар "),
("£", " фунт "),
("°", " градус ")
]
],
"nl": [
# Dutch
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " en "),
("@", " bij "),
("%", " procent "),
("#", " hekje "),
("$", " dollar "),
("£", " pond "),
("°", " graden ")
]
],
"tr": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " ve "),
("@", " at "),
("%", " yüzde "),
("#", " diyez "),
("$", " dolar "),
("£", " sterlin "),
("°", " derece ")
]
],
}
def expand_numbers(text):
return normalize_numbers(text)
def expand_symbols_multilingual(text, lang='en'):
for regex, replacement in _symbols_multilingual[lang]:
text = re.sub(regex, replacement, text)
text = text.replace('  ', ' ') # Ensure there are no double spaces
return text.strip()
_ordinal_re = {
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
"es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
"fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
"de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
"pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
"it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
"pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
"ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
"cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
"ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
"nl": re.compile(r"([0-9]+)(de|ste|e)"),
"tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
}
_number_re = re.compile(r"[0-9]+")
_currency_re = {
'USD': re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
'GBP': re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
'EUR': re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))")
}
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
def _remove_commas(m):
text = m.group(0)
if "," in text:
text = text.replace(",", "")
return text
def _remove_dots(m):
text = m.group(0)
if "." in text:
text = text.replace(".", "")
return text
def _expand_decimal_point(m, lang='en'):
amount = m.group(1).replace(",", ".")
return num2words(float(amount), lang=lang if lang != "cs" else "cz")
def _expand_currency(m, lang='en', currency='USD'):
amount = float((re.sub(r'[^\d.]', '', m.group(0).replace(",", "."))))
full_amount = num2words(amount, to='currency', currency=currency, lang=lang if lang != "cs" else "cz")
and_equivalents = {
"en": ", ",
"es": " con ",
"fr": " et ",
"de": " und ",
"pt": " e ",
"it": " e ",
"pl": ", ",
"cs": ", ",
"ru": ", ",
"nl": ", ",
"ar": ", ",
"tr": ", ",
}
if amount.is_integer():
last_and = full_amount.rfind(and_equivalents[lang])
if last_and != -1:
full_amount = full_amount[:last_and]
return full_amount
def _expand_ordinal(m, lang='en'):
return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
def _expand_number(m, lang='en'):
return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
def expand_numbers_multilingual(text, lang='en'):
if lang == "zh-cn":
text = zh_num2words()(text)
else:
if lang in ["en", "ru"]:
text = re.sub(_comma_number_re, _remove_commas, text)
else:
text = re.sub(_dot_number_re, _remove_dots, text)
try:
text = re.sub(_currency_re['GBP'], lambda m: _expand_currency(m, lang, 'GBP'), text)
text = re.sub(_currency_re['USD'], lambda m: _expand_currency(m, lang, 'USD'), text)
text = re.sub(_currency_re['EUR'], lambda m: _expand_currency(m, lang, 'EUR'), text)
except:
pass
if lang != "tr":
text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
return text
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)
def convert_to_ascii(text):
return unidecode(text)
def multilingual_cleaners(text, lang):
text = text.replace('"', '')
if lang=="tr":
text = text.replace("İ", "i")
text = text.replace("Ö", "ö")
text = text.replace("Ü", "ü")
text = lowercase(text)
text = expand_numbers_multilingual(text, lang)
text = expand_abbreviations_multilingual(text, lang)
text = expand_symbols_multilingual(text, lang=lang)
text = collapse_whitespace(text)
return text
def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
text = text.replace('"', "")
return text
def chinese_transliterate(text):
return "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)])
def expand_numbers_multilang(text, lang):
# TODO: Handle text more carefully. Currently, it just converts numbers without any context.
# Find all numbers in the input string
numbers = re.findall(r"\d+", text)
# Transliterate the numbers to text
for num in numbers:
transliterated_num = "".join(num2words(num, lang=lang))
text = text.replace(num, transliterated_num, 1)
return text
def transliteration_cleaners(text):
"""Pipeline for non-English text that transliterates to ASCII."""
text = convert_to_ascii(text)
def japanese_cleaners(text, katsu):
text = katsu.romaji(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def multilingual_cleaners(text, lang):
text = lowercase(text)
text = expand_numbers_multilang(text, lang)
text = collapse_whitespace(text)
text = text.replace('"', "")
if lang == "tr":
text = text.replace("İ", "i")
text = text.replace("Ö", "ö")
text = text.replace("Ü", "ü")
return text
def remove_extraneous_punctuation(word):
replacement_punctuation = {"{": "(", "}": ")", "[": "(", "]": ")", "`": "'", "—": "-", "–": "-", "`": "'", "ʼ": "'"}
replace = re.compile(
"|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL
)
word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word)
# TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners.
extraneous = re.compile(r"^[@#%_=\$\^&\*\+\\]$")
word = extraneous.sub("", word)
return word
def arabic_cleaners(text):
text = lowercase(text)
text = collapse_whitespace(text)
return text
def chinese_cleaners(text):
text = lowercase(text)
text = "".join(
[p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
)
return text
class VoiceBpeTokenizer:
def __init__(self, vocab_file=None, preprocess=None):
self.tokenizer = None
self.katsu = None
if vocab_file is not None:
with open(vocab_file, "r", encoding="utf-8") as f:
@@ -216,24 +479,17 @@ class VoiceBpeTokenizer:
self.tokenizer = Tokenizer.from_file(vocab_file)
def preprocess_text(self, txt, lang):
if lang == "ja":
import pykakasi
kks = pykakasi.kakasi()
results = kks.convert(txt)
txt = " ".join([result["kana"] for result in results])
txt = basic_cleaners(txt)
elif lang == "en":
if txt[:4] == "[en]":
txt = txt[4:]
txt = english_cleaners(txt)
txt = "[en]" + txt
elif lang == "ar":
txt = arabic_cleaners(txt)
elif lang == "zh-cn":
txt = chinese_cleaners(txt)
else:
if lang in ["en", "es", "fr", "de", "pt", "it", "pl", "ar", "cs", "ru", "nl", "tr", "zh-cn"]:
txt = multilingual_cleaners(txt, lang)
if lang == "zh-cn":
txt = chinese_transliterate(txt)
elif lang == "ja":
if self.katsu is None:
import cutlet
self.katsu = cutlet.Cutlet()
txt = japanese_cleaners(txt, self.katsu)
else:
raise NotImplementedError()
return txt
def encode(self, txt, lang):
@@ -250,3 +506,9 @@ class VoiceBpeTokenizer:
txt = txt.replace("[STOP]", "")
txt = txt.replace("[UNK]", "")
return txt
def __len__(self):
return self.tokenizer.get_vocab_size()
def get_number_tokens(self):
return max(self.tokenizer.get_vocab().values()) + 1
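
As a quick illustration of the reworked normalization and the new token-count helper (a hedged sketch; the exact wordings come from num2words and may vary by version, and vocab.json is a placeholder path):

from TTS.tts.layers.xtts.tokenizer import (
    VoiceBpeTokenizer,
    expand_numbers_multilingual,
    multilingual_cleaners,
)

# Numbers, currencies and ordinals are expanded per language:
print(expand_numbers_multilingual("I paid $3.50 on the 2nd try.", lang="en"))
# -> roughly: "I paid three dollars, fifty cents on the second try."
print(multilingual_cleaners("Der Preis ist 20 £.", lang="de"))
# -> roughly: "der preis ist zwanzig pfund ." (lowercased, number and symbol expanded)

# get_number_tokens() reports the embedding-table size needed for the text tokens:
tok = VoiceBpeTokenizer(vocab_file="vocab.json")
print(len(tok), tok.get_number_tokens())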

File diff suppressed because it is too large.

@@ -239,6 +239,7 @@ class XttsArgs(Coqpit):
decoder_checkpoint: str = None
num_chars: int = 255
use_hifigan: bool = True
use_ne_hifigan: bool = False
# XTTS GPT Encoder params
tokenizer_file: str = ""
@@ -311,7 +312,7 @@ class Xtts(BaseTTS):
def init_models(self):
"""Initialize the models. We do it here since we need to load the tokenizer first."""
if self.tokenizer.tokenizer is not None:
self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
self.args.gpt_number_text_tokens = self.tokenizer.get_number_tokens()
self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
@@ -343,7 +344,18 @@ class Xtts(BaseTTS):
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
)
else:
if self.args.use_ne_hifigan:
self.ne_hifigan_decoder = HifiDecoder(
input_sample_rate=self.args.input_sample_rate,
output_sample_rate=self.args.output_sample_rate,
output_hop_length=self.args.output_hop_length,
ar_mel_length_compression=self.args.ar_mel_length_compression,
decoder_input_dim=self.args.decoder_input_dim,
d_vector_dim=self.args.d_vector_dim,
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
)
if not (self.args.use_hifigan or self.args.use_ne_hifigan):
self.diffusion_decoder = DiffusionTts(
model_channels=self.args.diff_model_channels,
num_layers=self.args.diff_num_layers,
@@ -491,6 +503,7 @@ class Xtts(BaseTTS):
cond_free_k=2,
diffusion_temperature=1.0,
decoder_sampler="ddim",
decoder="hifigan",
**hf_generate_kwargs,
):
"""
@@ -539,6 +552,9 @@
Values at 0 are the "mean" prediction of the diffusion network and will sound bland and smeared.
Defaults to 1.0.
decoder: (str) Selects the decoder to use, one of "hifigan", "ne_hifigan", or "diffusion".
Defaults to "hifigan".
hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
here: https://huggingface.co/docs/transformers/internal/generation_utils
@@ -569,6 +585,7 @@
cond_free_k=cond_free_k,
diffusion_temperature=diffusion_temperature,
decoder_sampler=decoder_sampler,
decoder=decoder,
**hf_generate_kwargs,
)
@@ -593,6 +610,7 @@
cond_free_k=2,
diffusion_temperature=1.0,
decoder_sampler="ddim",
decoder="hifigan",
**hf_generate_kwargs,
):
text = f"[{language}]{text.strip().lower()}"
@@ -649,9 +667,14 @@
gpt_latents = gpt_latents[:, :k]
break
if self.args.use_hifigan:
if decoder == "hifigan":
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
elif decoder == "ne_hifigan":
assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
else:
assert hasattr(self, "diffusion_decoder"), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
mel = do_spectrogram_diffusion(
self.diffusion_decoder,
diffuser,
@@ -695,6 +718,7 @@
top_p=0.85,
do_sample=True,
# Decoder inference
decoder="hifigan",
**hf_generate_kwargs,
):
assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
@@ -736,7 +760,14 @@
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
if decoder == "hifigan":
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
elif decoder == "ne_hifigan":
assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
else:
raise NotImplementedError("Diffusion for streaming inference not implemented.")
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
)
@@ -794,7 +825,9 @@
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
for key in list(checkpoint.keys()):
if key.split(".")[0] in ignore_keys:
del checkpoint[key]
@@ -802,6 +835,7 @@
if eval:
if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
if hasattr(self, "ne_hifigan_decoder"): self.hifigan_decoder.eval()
if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
if hasattr(self, "vocoder"): self.vocoder.eval()
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
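
The ignore_keys logic in load_checkpoint above is dense; this standalone sketch (illustration only, not library code) restates the rule it implements:

def keys_to_ignore(use_hifigan, use_ne_hifigan):
    """Checkpoint sub-module prefixes that load_checkpoint skips, given the two flags."""
    ignore = ["diffusion_decoder", "vocoder"] if (use_hifigan or use_ne_hifigan) else []
    if not use_hifigan:
        ignore.append("hifigan_decoder")
    if not use_ne_hifigan:
        ignore.append("ne_hifigan_decoder")
    return ignore

# Default flags (use_hifigan=True, use_ne_hifigan=False) drop the diffusion decoder,
# the old vocoder, and the unused ne_hifigan weights; disabling both HiFi-GAN flags
# keeps the diffusion path and drops both HiFi-GAN decoders instead.
assert keys_to_ignore(True, False) == ["diffusion_decoder", "vocoder", "ne_hifigan_decoder"]
assert keys_to_ignore(False, False) == ["hifigan_decoder", "ne_hifigan_decoder"]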

@@ -294,8 +294,9 @@ class ModelManager(object):
# get model from models.json
model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
model_item = self.set_model_url(model_item)
return model_item, model_full_name, model
return model_item, model_full_name, model, md5hash
def ask_tos(self, model_full_path):
"""Ask the user to agree to the terms of service"""
@@ -358,8 +359,6 @@ class ModelManager(object):
if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
def download_model(self, model_name):
"""Download model files given the full model name.
@@ -375,10 +374,22 @@
Args:
model_name (str): model name as explained above.
"""
model_item, model_full_name, model = self._set_model_item(model_name)
model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
if md5sum is not None:
md5sum_file = os.path.join(output_path, "hash.md5")
if os.path.isfile(md5sum_file):
with open(md5sum_file, mode="r") as f:
if not f.read() == md5sum:
print(f" > {model_name} has been updated, clearing model cache...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
else:
print(f" > {model_name} has been updated, clearing model cache...")
self.create_dir_and_download_model(model_name, model_item, output_path)
# if the configs are different, redownload it
# ToDo: we need a better way to handle it
if "xtts_v1" in model_name:
@@ -395,7 +406,7 @@ class ModelManager(object):
output_model_path = output_path
output_config_path = None
if (
model not in ["tortoise-v2", "bark", "xtts_v1"] and "fairseq" not in model_name
model not in ["tortoise-v2", "bark", "xtts_v1", "xtts_v1.1"] and "fairseq" not in model_name
): # TODO:This is stupid but don't care for now.
output_model_path, output_config_path = self._find_files(output_path)
# update paths in the config.json
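
The cache check above reduces to: keep the md5 recorded at download time (hash.md5) next to the model and redownload whenever it no longer matches the model_hash published in models.json. A minimal standalone sketch of that comparison (illustration only; the cache directory in the comment is an assumption about a typical Linux setup):

import os

def is_model_up_to_date(model_dir, expected_md5):
    """True if the locally stored hash.md5 matches the hash published in models.json."""
    hash_file = os.path.join(model_dir, "hash.md5")
    if not os.path.isfile(hash_file):
        return False  # no recorded hash, treat the cached model as outdated
    with open(hash_file, mode="r") as f:
        return f.read() == expected_md5

# e.g. for xtts_v1.1, expected_md5 is the "model_hash" value added in models.json:
# is_model_up_to_date("~/.local/share/tts/tts_models--multilingual--multi-dataset--xtts_v1.1",
#                     "10163afc541dc86801b33d1f3217b456")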

@@ -2,3 +2,4 @@
# japanese g2p deps
mecab-python3==1.0.6
unidic-lite==1.0.8
cutlet

@@ -15,6 +15,7 @@ MODELS_WITH_SEP_TESTS = [
"tts_models/multilingual/multi-dataset/bark",
"tts_models/en/multi-dataset/tortoise-v2",
"tts_models/multilingual/multi-dataset/xtts_v1",
"tts_models/multilingual/multi-dataset/xtts_v1.1",
]
@@ -93,6 +94,7 @@ def test_xtts():
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
)
def test_xtts_streaming():
"""Testing the new inference_stream method"""
from TTS.tts.configs.xtts_config import XttsConfig
@@ -122,6 +124,7 @@ def test_xtts_streaming():
wav_chuncks.append(chunk)
assert len(wav_chuncks) > 1
def test_tortoise():
output_path = os.path.join(get_tests_output_path(), "output.wav")
use_gpu = torch.cuda.is_available()