* Add support for ne_hifigan

* Update model.json

* Update hash

* Fix model loading

* Enhance text_normalization

* Add xtts to zoo test exception

* Add model hash check

* Add get_number_tokens

Julian Weber, 2023-10-20 16:02:08 +02:00, committed by GitHub
parent 747f688dc3
commit cf97116185
7 changed files with 1726 additions and 193 deletions


@@ -15,6 +15,21 @@
             "contact": "info@coqui.ai",
             "tos_required": true
         },
+        "xtts_v1.1": {
+            "description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.",
+            "hf_url": [
+                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/model.pth",
+                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/config.json",
+                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json",
+                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/hash.md5"
+            ],
+            "model_hash": "10163afc541dc86801b33d1f3217b456",
+            "default_vocoder": null,
+            "commit": "82910a63",
+            "license": "CPML",
+            "contact": "info@coqui.ai",
+            "tos_required": true
+        },
         "your_tts": {
             "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418",
            "github_rls_url": "https://coqui.gateway.scarf.sh/v0.10.1_models/tts_models--multilingual--multi-dataset--your_tts.zip",


@@ -1,206 +1,469 @@
+import json
 import os
 import re
-import json
-import inflect
-import pandas as pd
-import pypinyin
 import torch
-from num2words import num2words
 from tokenizers import Tokenizer
-from unidecode import unidecode
-from TTS.tts.utils.text.cleaners import english_cleaners
+import pypinyin
+import cutlet
+from num2words import num2words
+from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
def _remove_commas(m):
return m.group(1).replace(",", "")
def _expand_decimal_point(m):
return m.group(1).replace(".", " point ")
def _expand_dollars(m):
match = m.group(1)
parts = match.split(".")
if len(parts) > 2:
return match + " dollars" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return "zero dollars"
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return "two thousand"
elif num > 2000 and num < 2010:
return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword="")
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r"\1 pounds", text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
# Regular expression matching whitespace:
 _whitespace_re = re.compile(r"\s+")

 # List of (regular expression, replacement) pairs for abbreviations:
-_abbreviations = [
+_abbreviations = {
+    "en": [
         (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         for x in [
             ("mrs", "misess"),
             ("mr", "mister"),
             ("dr", "doctor"),
             ("st", "saint"),
             ("co", "company"),
             ("jr", "junior"),
             ("maj", "major"),
             ("gen", "general"),
             ("drs", "doctors"),
             ("rev", "reverend"),
             ("lt", "lieutenant"),
             ("hon", "honorable"),
             ("sgt", "sergeant"),
             ("capt", "captain"),
             ("esq", "esquire"),
             ("ltd", "limited"),
             ("col", "colonel"),
             ("ft", "fort"),
         ]
-]
+    ],
"es": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("sra", "señora"),
("sr", "señor"),
("dr", "doctor"),
("dra", "doctora"),
("st", "santo"),
("co", "compañía"),
("jr", "junior"),
("ltd", "limitada"),
]
],
"fr": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mme", "madame"),
("mr", "monsieur"),
("dr", "docteur"),
("st", "saint"),
("co", "compagnie"),
("jr", "junior"),
("ltd", "limitée"),
]
],
"de": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("fr", "frau"),
("dr", "doktor"),
("st", "sankt"),
("co", "firma"),
("jr", "junior"),
]
],
"pt": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("sra", "senhora"),
("sr", "senhor"),
("dr", "doutor"),
("dra", "doutora"),
("st", "santo"),
("co", "companhia"),
("jr", "júnior"),
("ltd", "limitada"),
]
],
"it": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
#("sig.ra", "signora"),
("sig", "signore"),
("dr", "dottore"),
("st", "santo"),
("co", "compagnia"),
("jr", "junior"),
("ltd", "limitata"),
]
],
"pl": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("p", "pani"),
("m", "pan"),
("dr", "doktor"),
("sw", "święty"),
("jr", "junior"),
]
],
"ar": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
# There are not many common abbreviations in Arabic as in English.
]
],
"zh-cn": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
]
],
"cs": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("dr", "doktor"), # doctor
("ing", "inženýr"), # engineer
("p", "pan"), # Could also map to pani for woman but no easy way to do it
# Other abbreviations would be specialized and not as common.
]
],
"ru": [
(re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
for x in [
("г-жа", "госпожа"), # Mrs.
("г", "господин"), # Mr.
("д-р", "доктор"), # doctor
# Other abbreviations are less common or specialized.
]
],
"nl": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("dhr", "de heer"), # Mr.
("mevr", "mevrouw"), # Mrs.
("dr", "dokter"), # doctor
("jhr", "jonkheer"), # young lord or nobleman
# Dutch uses more abbreviations, but these are the most common ones.
]
],
"tr": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("b", "bay"), # Mr.
("byk", "büyük"), # büyük
("dr", "doktor"), # doctor
# Add other Turkish abbreviations here if needed.
]
],
}
-def expand_abbreviations(text):
-    for regex, replacement in _abbreviations:
+def expand_abbreviations_multilingual(text, lang='en'):
+    for regex, replacement in _abbreviations[lang]:
         text = re.sub(regex, replacement, text)
     return text
_symbols_multilingual = {
'en': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " and "),
("@", " at "),
("%", " percent "),
("#", " hash "),
("$", " dollar "),
("£", " pound "),
("°", " degree ")
]
],
'es': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " y "),
("@", " arroba "),
("%", " por ciento "),
("#", " numeral "),
("$", " dolar "),
("£", " libra "),
("°", " grados ")
]
],
'fr': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " et "),
("@", " arobase "),
("%", " pour cent "),
("#", " dièse "),
("$", " dollar "),
("£", " livre "),
("°", " degrés ")
]
],
'de': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " und "),
("@", " at "),
("%", " prozent "),
("#", " raute "),
("$", " dollar "),
("£", " pfund "),
("°", " grad ")
]
],
'pt': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " e "),
("@", " arroba "),
("%", " por cento "),
("#", " cardinal "),
("$", " dólar "),
("£", " libra "),
("°", " graus ")
]
],
'it': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " e "),
("@", " chiocciola "),
("%", " per cento "),
("#", " cancelletto "),
("$", " dollaro "),
("£", " sterlina "),
("°", " gradi ")
]
],
'pl': [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " i "),
("@", " małpa "),
("%", " procent "),
("#", " krzyżyk "),
("$", " dolar "),
("£", " funt "),
("°", " stopnie ")
]
],
"ar": [
# Arabic
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " و "),
("@", " على "),
("%", " في المئة "),
("#", " رقم "),
("$", " دولار "),
("£", " جنيه "),
("°", " درجة ")
]
],
"zh-cn": [
# Chinese
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", ""),
("@", ""),
("%", " 百分之 "),
("#", ""),
("$", " 美元 "),
("£", " 英镑 "),
("°", "")
]
],
"cs": [
# Czech
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " a "),
("@", " na "),
("%", " procento "),
("#", " křížek "),
("$", " dolar "),
("£", " libra "),
("°", " stupně ")
]
],
"ru": [
# Russian
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " и "),
("@", " собака "),
("%", " процентов "),
("#", " номер "),
("$", " доллар "),
("£", " фунт "),
("°", " градус ")
]
],
"nl": [
# Dutch
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " en "),
("@", " bij "),
("%", " procent "),
("#", " hekje "),
("$", " dollar "),
("£", " pond "),
("°", " graden ")
]
],
"tr": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " ve "),
("@", " at "),
("%", " yüzde "),
("#", " diyez "),
("$", " dolar "),
("£", " sterlin "),
("°", " derece ")
]
],
}
-def expand_numbers(text):
-    return normalize_numbers(text)
+def expand_symbols_multilingual(text, lang='en'):
+    for regex, replacement in _symbols_multilingual[lang]:
+        text = re.sub(regex, replacement, text)
+    text = text.replace('  ', ' ')  # Ensure there are no double spaces
+    return text.strip()
_ordinal_re = {
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
"es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
"fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
"de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
"pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
"it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
"pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
"ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
"cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
"ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
"nl": re.compile(r"([0-9]+)(de|ste|e)"),
"tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
}
_number_re = re.compile(r"[0-9]+")
_currency_re = {
'USD': re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
'GBP': re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
'EUR': re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))")
}
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
def _remove_commas(m):
text = m.group(0)
if "," in text:
text = text.replace(",", "")
return text
def _remove_dots(m):
text = m.group(0)
if "." in text:
text = text.replace(".", "")
return text
def _expand_decimal_point(m, lang='en'):
amount = m.group(1).replace(",", ".")
return num2words(float(amount), lang=lang if lang != "cs" else "cz")
def _expand_currency(m, lang='en', currency='USD'):
amount = float((re.sub(r'[^\d.]', '', m.group(0).replace(",", "."))))
full_amount = num2words(amount, to='currency', currency=currency, lang=lang if lang != "cs" else "cz")
and_equivalents = {
"en": ", ",
"es": " con ",
"fr": " et ",
"de": " und ",
"pt": " e ",
"it": " e ",
"pl": ", ",
"cs": ", ",
"ru": ", ",
"nl": ", ",
"ar": ", ",
"tr": ", ",
}
if amount.is_integer():
last_and = full_amount.rfind(and_equivalents[lang])
if last_and != -1:
full_amount = full_amount[:last_and]
return full_amount
def _expand_ordinal(m, lang='en'):
return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
def _expand_number(m, lang='en'):
return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
def expand_numbers_multilingual(text, lang='en'):
if lang == "zh-cn":
text = zh_num2words()(text)
else:
if lang in ["en", "ru"]:
text = re.sub(_comma_number_re, _remove_commas, text)
else:
text = re.sub(_dot_number_re, _remove_dots, text)
try:
text = re.sub(_currency_re['GBP'], lambda m: _expand_currency(m, lang, 'GBP'), text)
text = re.sub(_currency_re['USD'], lambda m: _expand_currency(m, lang, 'USD'), text)
text = re.sub(_currency_re['EUR'], lambda m: _expand_currency(m, lang, 'EUR'), text)
except:
pass
if lang != "tr":
text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
return text
def lowercase(text):
    return text.lower()

def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text)
-def convert_to_ascii(text):
-    return unidecode(text)
+def multilingual_cleaners(text, lang):
+    text = text.replace('"', '')
+    if lang=="tr":
+        text = text.replace("İ", "i")
+        text = text.replace("Ö", "ö")
+        text = text.replace("Ü", "ü")
+    text = lowercase(text)
+    text = expand_numbers_multilingual(text, lang)
+    text = expand_abbreviations_multilingual(text, lang)
+    text = expand_symbols_multilingual(text, lang=lang)
+    text = collapse_whitespace(text)
+    return text
 def basic_cleaners(text):
     """Basic pipeline that lowercases and collapses whitespace without transliteration."""
     text = lowercase(text)
     text = collapse_whitespace(text)
-    text = text.replace('"', "")
     return text
def chinese_transliterate(text):
return "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)])
-def expand_numbers_multilang(text, lang):
-    # TODO: Handle text more carefully. Currently, it just converts numbers without any context.
-    # Find all numbers in the input string
-    numbers = re.findall(r"\d+", text)
-    # Transliterate the numbers to text
-    for num in numbers:
-        transliterated_num = "".join(num2words(num, lang=lang))
-        text = text.replace(num, transliterated_num, 1)
-    return text
-def transliteration_cleaners(text):
-    """Pipeline for non-English text that transliterates to ASCII."""
-    text = convert_to_ascii(text)
-    text = lowercase(text)
-    text = collapse_whitespace(text)
-    return text
+def japanese_cleaners(text, katsu):
+    text = katsu.romaji(text)
+    text = lowercase(text)
+    return text
def multilingual_cleaners(text, lang):
text = lowercase(text)
text = expand_numbers_multilang(text, lang)
text = collapse_whitespace(text)
text = text.replace('"', "")
if lang == "tr":
text = text.replace("İ", "i")
text = text.replace("Ö", "ö")
text = text.replace("Ü", "ü")
return text
def remove_extraneous_punctuation(word):
replacement_punctuation = {"{": "(", "}": ")", "[": "(", "]": ")", "`": "'", "": "-", "": "-", "`": "'", "ʼ": "'"}
replace = re.compile(
"|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL
)
word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word)
# TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners.
extraneous = re.compile(r"^[@#%_=\$\^&\*\+\\]$")
word = extraneous.sub("", word)
return word
def arabic_cleaners(text):
text = lowercase(text)
text = collapse_whitespace(text)
return text
def chinese_cleaners(text):
text = lowercase(text)
text = "".join(
[p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
)
return text
 class VoiceBpeTokenizer:
     def __init__(self, vocab_file=None, preprocess=None):
         self.tokenizer = None
+        self.katsu = None
         if vocab_file is not None:
             with open(vocab_file, "r", encoding="utf-8") as f:
@@ -216,24 +479,17 @@ class VoiceBpeTokenizer:
             self.tokenizer = Tokenizer.from_file(vocab_file)

     def preprocess_text(self, txt, lang):
-        if lang == "ja":
-            import pykakasi
-            kks = pykakasi.kakasi()
-            results = kks.convert(txt)
-            txt = " ".join([result["kana"] for result in results])
-            txt = basic_cleaners(txt)
-        elif lang == "en":
-            if txt[:4] == "[en]":
-                txt = txt[4:]
-            txt = english_cleaners(txt)
-            txt = "[en]" + txt
-        elif lang == "ar":
-            txt = arabic_cleaners(txt)
-        elif lang == "zh-cn":
-            txt = chinese_cleaners(txt)
-        else:
+        if lang in ["en", "es", "fr", "de", "pt", "it", "pl", "ar", "cs", "ru", "nl", "tr", "zh-cn"]:
             txt = multilingual_cleaners(txt, lang)
+            if lang == "zh-cn":
+                txt = chinese_transliterate(txt)
+        elif lang == "ja":
+            if self.katsu is None:
+                import cutlet
+                self.katsu = cutlet.Cutlet()
+            txt = japanese_cleaners(txt, self.katsu)
+        else:
+            raise NotImplementedError()
         return txt
     def encode(self, txt, lang):

@@ -250,3 +506,9 @@ class VoiceBpeTokenizer:
         txt = txt.replace("[STOP]", "")
         txt = txt.replace("[UNK]", "")
         return txt
+
+    def __len__(self):
+        return self.tokenizer.get_vocab_size()
+
+    def get_number_tokens(self):
+        return max(self.tokenizer.get_vocab().values()) + 1
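
To make the scope of the tokenizer rewrite concrete, here is a small usage sketch of the new multilingual pipeline (example strings are arbitrary, the vocab path is a placeholder, and the module path `TTS.tts.layers.xtts.tokenizer` is assumed from the imports above). The point of `get_number_tokens()` is that the GPT text-embedding table must cover the highest token id; `max(vocab ids) + 1` can be larger than `get_vocab_size()` when ids are not contiguous.

```python
from TTS.tts.layers.xtts.tokenizer import (
    VoiceBpeTokenizer,
    expand_numbers_multilingual,
    multilingual_cleaners,
)

# Numbers, ordinals and currency are now spelled out per language via num2words
print(expand_numbers_multilingual("I paid $42.50 for the 3rd ticket.", lang="en"))
print(expand_numbers_multilingual("Das kostet 1.000,50€.", lang="de"))

# Full cleaner: lowercasing, abbreviation/symbol expansion, quote stripping, whitespace collapse
print(multilingual_cleaners('Dr. Smith owns 50% of the company.', lang="en"))

tokenizer = VoiceBpeTokenizer(vocab_file="path/to/vocab.json")  # placeholder path
ids = tokenizer.encode("hello world", lang="en")
print(len(tokenizer), tokenizer.get_number_tokens())  # vocab entry count vs. embedding size needed
```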

File diff suppressed because it is too large.


@@ -239,6 +239,7 @@ class XttsArgs(Coqpit):
     decoder_checkpoint: str = None
     num_chars: int = 255
     use_hifigan: bool = True
+    use_ne_hifigan: bool = False

     # XTTS GPT Encoder params
     tokenizer_file: str = ""
@@ -311,7 +312,7 @@ class Xtts(BaseTTS):
     def init_models(self):
         """Initialize the models. We do it here since we need to load the tokenizer first."""
         if self.tokenizer.tokenizer is not None:
-            self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
+            self.args.gpt_number_text_tokens = self.tokenizer.get_number_tokens()
             self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
             self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
@@ -343,7 +344,18 @@ class Xtts(BaseTTS):
                 cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
             )
-        else:
+        if self.args.use_ne_hifigan:
+            self.ne_hifigan_decoder = HifiDecoder(
+                input_sample_rate=self.args.input_sample_rate,
+                output_sample_rate=self.args.output_sample_rate,
+                output_hop_length=self.args.output_hop_length,
+                ar_mel_length_compression=self.args.ar_mel_length_compression,
+                decoder_input_dim=self.args.decoder_input_dim,
+                d_vector_dim=self.args.d_vector_dim,
+                cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
+            )
+        if not (self.args.use_hifigan or self.args.use_ne_hifigan):
             self.diffusion_decoder = DiffusionTts(
                 model_channels=self.args.diff_model_channels,
                 num_layers=self.args.diff_num_layers,
@@ -491,6 +503,7 @@ class Xtts(BaseTTS):
         cond_free_k=2,
         diffusion_temperature=1.0,
         decoder_sampler="ddim",
+        decoder="hifigan",
         **hf_generate_kwargs,
     ):
         """
@@ -539,6 +552,9 @@ class Xtts(BaseTTS):
                 Values at 0 are the "mean" prediction of the diffusion network and will sound bland and smeared.
                 Defaults to 1.0.

+            decoder: (str) Selects the decoder to use between ("hifigan", "ne_hifigan" and "diffusion")
+                Defaults to hifigan
+
             hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
                 transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
                 here: https://huggingface.co/docs/transformers/internal/generation_utils
@@ -569,6 +585,7 @@ class Xtts(BaseTTS):
             cond_free_k=cond_free_k,
             diffusion_temperature=diffusion_temperature,
             decoder_sampler=decoder_sampler,
+            decoder=decoder,
             **hf_generate_kwargs,
         )
@@ -593,6 +610,7 @@ class Xtts(BaseTTS):
         cond_free_k=2,
         diffusion_temperature=1.0,
         decoder_sampler="ddim",
+        decoder="hifigan",
         **hf_generate_kwargs,
     ):
         text = f"[{language}]{text.strip().lower()}"
@@ -649,9 +667,14 @@ class Xtts(BaseTTS):
                     gpt_latents = gpt_latents[:, :k]
                     break

-        if self.args.use_hifigan:
+        if decoder == "hifigan":
+            assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
             wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
+        elif decoder == "ne_hifigan":
+            assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
+            wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
         else:
+            assert hasattr(self, "diffusion_decoder"), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
             mel = do_spectrogram_diffusion(
                 self.diffusion_decoder,
                 diffuser,
@@ -695,6 +718,7 @@ class Xtts(BaseTTS):
         top_p=0.85,
         do_sample=True,
         # Decoder inference
+        decoder="hifigan",
         **hf_generate_kwargs,
     ):
         assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
@@ -736,7 +760,14 @@ class Xtts(BaseTTS):
             if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                 gpt_latents = torch.cat(all_latents, dim=0)[None, :]
-                wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                if decoder == "hifigan":
+                    assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
+                    wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                elif decoder == "ne_hifigan":
+                    assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
+                    wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                else:
+                    raise NotImplementedError("Diffusion for streaming inference not implemented.")
                 wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
                     wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
                 )
@@ -794,7 +825,9 @@ class Xtts(BaseTTS):
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
         checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
-        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
+        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
+        ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
+        ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
         for key in list(checkpoint.keys()):
             if key.split(".")[0] in ignore_keys:
                 del checkpoint[key]
@@ -802,6 +835,7 @@ class Xtts(BaseTTS):
         if eval:
             if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
+            if hasattr(self, "ne_hifigan_decoder"): self.hifigan_decoder.eval()
             if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
             if hasattr(self, "vocoder"): self.vocoder.eval()
         self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
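
With these changes the decoder choice becomes a per-call argument rather than a config-only switch. A rough sketch of using it (paths are placeholders; it assumes the loaded config/checkpoint actually enable the extra head via `use_ne_hifigan: true`, and that the outer entry point is `full_inference` with its existing signature):

```python
import torch

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")  # placeholder
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)  # placeholder

with torch.no_grad():
    out = model.full_inference(
        "Hello world.",
        ref_audio_path="reference.wav",  # placeholder voice reference
        language="en",
        decoder="ne_hifigan",            # "hifigan" (default), "ne_hifigan", or "diffusion"
    )
```

If the corresponding `use_*` flag is off, the new asserts above fail with a message pointing at the config switch to enable.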


@@ -294,8 +294,9 @@ class ModelManager(object):
         # get model from models.json
         model_item = self.models_dict[model_type][lang][dataset][model]
         model_item["model_type"] = model_type
+        md5hash = model_item["model_hash"] if "model_hash" in model_item else None
         model_item = self.set_model_url(model_item)
-        return model_item, model_full_name, model
+        return model_item, model_full_name, model, md5hash

     def ask_tos(self, model_full_path):
         """Ask the user to agree to the terms of service"""
@@ -358,8 +359,6 @@ class ModelManager(object):
         if not config_local == config_remote:
             print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
             self.create_dir_and_download_model(model_name, model_item, output_path)
-        else:
-            print(f" > {model_name} is already downloaded.")

     def download_model(self, model_name):
         """Download model files given the full model name.
@@ -375,10 +374,22 @@ class ModelManager(object):
         Args:
             model_name (str): model name as explained above.
         """
-        model_item, model_full_name, model = self._set_model_item(model_name)
+        model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
         # set the model specific output path
         output_path = os.path.join(self.output_prefix, model_full_name)
         if os.path.exists(output_path):
+            if md5sum is not None:
+                md5sum_file = os.path.join(output_path, "hash.md5")
+                if os.path.isfile(md5sum_file):
+                    with open(md5sum_file, mode="r") as f:
+                        if not f.read() == md5sum:
+                            print(f" > {model_name} has been updated, clearing model cache...")
+                            self.create_dir_and_download_model(model_name, model_item, output_path)
+                        else:
+                            print(f" > {model_name} is already downloaded.")
+                else:
+                    print(f" > {model_name} has been updated, clearing model cache...")
+                    self.create_dir_and_download_model(model_name, model_item, output_path)
             # if the configs are different, redownload it
             # ToDo: we need a better way to handle it
             if "xtts_v1" in model_name:
@@ -395,7 +406,7 @@ class ModelManager(object):
             output_model_path = output_path
             output_config_path = None
             if (
-                model not in ["tortoise-v2", "bark", "xtts_v1"] and "fairseq" not in model_name
+                model not in ["tortoise-v2", "bark", "xtts_v1", "xtts_v1.1"] and "fairseq" not in model_name
             ):  # TODO:This is stupid but don't care for now.
                 output_model_path, output_config_path = self._find_files(output_path)
             # update paths in the config.json
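
The check above compares the cached `hash.md5` (downloaded as one of the model files) against the `model_hash` field published in `.models.json`; a mismatch or a missing hash file clears the cache and re-downloads. A standalone illustration of the same rule, outside the manager (the cache directory is an assumption based on the default user data dir on Linux; adjust for your platform):

```python
# Compare the cached hash.md5 with the published model_hash for xtts_v1.1.
import json
import os

MODELS_JSON = "TTS/.models.json"  # as shipped in the repo
CACHE_DIR = os.path.expanduser(
    "~/.local/share/tts/tts_models--multilingual--multi-dataset--xtts_v1.1"
)  # assumption: default ModelManager cache location

with open(MODELS_JSON) as f:
    item = json.load(f)["tts_models"]["multilingual"]["multi-dataset"]["xtts_v1.1"]

hash_file = os.path.join(CACHE_DIR, "hash.md5")
if os.path.isfile(hash_file):
    with open(hash_file) as f:
        print("cached copy up to date:", f.read() == item["model_hash"])
else:
    print("no hash.md5 in cache -> the manager will clear and re-download")
```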


@@ -2,3 +2,4 @@
 # japanese g2p deps
 mecab-python3==1.0.6
 unidic-lite==1.0.8
+cutlet
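
`cutlet` backs the new lazy Japanese romanization path in the tokenizer and relies on the MeCab packages listed above. A quick sanity check with an arbitrary sample sentence:

```python
import cutlet

katsu = cutlet.Cutlet()
print(katsu.romaji("こんにちは、世界。"))  # same call japanese_cleaners() relies on
```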


@@ -15,6 +15,7 @@ MODELS_WITH_SEP_TESTS = [
     "tts_models/multilingual/multi-dataset/bark",
     "tts_models/en/multi-dataset/tortoise-v2",
     "tts_models/multilingual/multi-dataset/xtts_v1",
+    "tts_models/multilingual/multi-dataset/xtts_v1.1",
 ]
@@ -93,6 +94,7 @@ def test_xtts():
         f'--speaker_wav "{speaker_wav}" --language_idx "en"'
     )

+
 def test_xtts_streaming():
     """Testing the new inference_stream method"""
     from TTS.tts.configs.xtts_config import XttsConfig
@@ -122,6 +124,7 @@ def test_xtts_streaming():
         wav_chuncks.append(chunk)
     assert len(wav_chuncks) > 1

+
 def test_tortoise():
     output_path = os.path.join(get_tests_output_path(), "output.wav")
     use_gpu = torch.cuda.is_available()
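
For reference, a condensed version of what the streaming test exercises, extended with the new `decoder` argument (the model directory is a placeholder; `get_conditioning_latents` is assumed to return the GPT latent, diffusion conditioning and speaker embedding, as in the existing `test_xtts_streaming`):

```python
import torch

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts_v1.1/config.json")  # placeholder
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts_v1.1/", eval=True)

gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
chunks = model.inference_stream(
    "Streaming synthesis with the secondary HiFi-GAN head.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    decoder="ne_hifigan",  # "diffusion" raises NotImplementedError in the streaming path
)
wav_chunks = list(chunks)
assert len(wav_chunks) > 1
```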