mirror of https://github.com/coqui-ai/TTS.git
Run `make style` & re-enable it in CI
This commit is contained in:
parent
6fef4f9067
commit
4cfc3e5779
|
@ -42,6 +42,5 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install .[all]
|
python3 -m pip install .[all]
|
||||||
python3 setup.py egg_info
|
python3 setup.py egg_info
|
||||||
# - name: Lint check
|
- name: Style check
|
||||||
# run: |
|
run: make style
|
||||||
# make lint
|
|
||||||
|
|
|
@ -427,7 +427,9 @@ def main():
|
||||||
tts_path = model_path
|
tts_path = model_path
|
||||||
tts_config_path = config_path
|
tts_config_path = config_path
|
||||||
if "default_vocoder" in model_item:
|
if "default_vocoder" in model_item:
|
||||||
args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
|
args.vocoder_name = (
|
||||||
|
model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
|
||||||
|
)
|
||||||
|
|
||||||
# voice conversion model
|
# voice conversion model
|
||||||
if model_item["model_type"] == "voice_conversion_models":
|
if model_item["model_type"] == "voice_conversion_models":
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import json
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from tokenizers import Tokenizer
|
|
||||||
|
|
||||||
import pypinyin
|
import pypinyin
|
||||||
|
import torch
|
||||||
from num2words import num2words
|
from num2words import num2words
|
||||||
|
from tokenizers import Tokenizer
|
||||||
|
|
||||||
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
|
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
|
||||||
|
|
||||||
_whitespace_re = re.compile(r"\s+")
|
_whitespace_re = re.compile(r"\s+")
|
||||||
|
@ -157,13 +157,15 @@ _abbreviations = {
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
def expand_abbreviations_multilingual(text, lang='en'):
|
|
||||||
|
def expand_abbreviations_multilingual(text, lang="en"):
|
||||||
for regex, replacement in _abbreviations[lang]:
|
for regex, replacement in _abbreviations[lang]:
|
||||||
text = re.sub(regex, replacement, text)
|
text = re.sub(regex, replacement, text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
_symbols_multilingual = {
|
_symbols_multilingual = {
|
||||||
'en': [
|
"en": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " and "),
|
("&", " and "),
|
||||||
|
@ -172,10 +174,10 @@ _symbols_multilingual = {
|
||||||
("#", " hash "),
|
("#", " hash "),
|
||||||
("$", " dollar "),
|
("$", " dollar "),
|
||||||
("£", " pound "),
|
("£", " pound "),
|
||||||
("°", " degree ")
|
("°", " degree "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
'es': [
|
"es": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " y "),
|
("&", " y "),
|
||||||
|
@ -184,10 +186,10 @@ _symbols_multilingual = {
|
||||||
("#", " numeral "),
|
("#", " numeral "),
|
||||||
("$", " dolar "),
|
("$", " dolar "),
|
||||||
("£", " libra "),
|
("£", " libra "),
|
||||||
("°", " grados ")
|
("°", " grados "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
'fr': [
|
"fr": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " et "),
|
("&", " et "),
|
||||||
|
@ -196,10 +198,10 @@ _symbols_multilingual = {
|
||||||
("#", " dièse "),
|
("#", " dièse "),
|
||||||
("$", " dollar "),
|
("$", " dollar "),
|
||||||
("£", " livre "),
|
("£", " livre "),
|
||||||
("°", " degrés ")
|
("°", " degrés "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
'de': [
|
"de": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " und "),
|
("&", " und "),
|
||||||
|
@ -208,10 +210,10 @@ _symbols_multilingual = {
|
||||||
("#", " raute "),
|
("#", " raute "),
|
||||||
("$", " dollar "),
|
("$", " dollar "),
|
||||||
("£", " pfund "),
|
("£", " pfund "),
|
||||||
("°", " grad ")
|
("°", " grad "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
'pt': [
|
"pt": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " e "),
|
("&", " e "),
|
||||||
|
@ -220,10 +222,10 @@ _symbols_multilingual = {
|
||||||
("#", " cardinal "),
|
("#", " cardinal "),
|
||||||
("$", " dólar "),
|
("$", " dólar "),
|
||||||
("£", " libra "),
|
("£", " libra "),
|
||||||
("°", " graus ")
|
("°", " graus "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
'it': [
|
"it": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " e "),
|
("&", " e "),
|
||||||
|
@ -232,10 +234,10 @@ _symbols_multilingual = {
|
||||||
("#", " cancelletto "),
|
("#", " cancelletto "),
|
||||||
("$", " dollaro "),
|
("$", " dollaro "),
|
||||||
("£", " sterlina "),
|
("£", " sterlina "),
|
||||||
("°", " gradi ")
|
("°", " gradi "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
'pl': [
|
"pl": [
|
||||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||||
for x in [
|
for x in [
|
||||||
("&", " i "),
|
("&", " i "),
|
||||||
|
@ -244,7 +246,7 @@ _symbols_multilingual = {
|
||||||
("#", " krzyżyk "),
|
("#", " krzyżyk "),
|
||||||
("$", " dolar "),
|
("$", " dolar "),
|
||||||
("£", " funt "),
|
("£", " funt "),
|
||||||
("°", " stopnie ")
|
("°", " stopnie "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"ar": [
|
"ar": [
|
||||||
|
@ -257,7 +259,7 @@ _symbols_multilingual = {
|
||||||
("#", " رقم "),
|
("#", " رقم "),
|
||||||
("$", " دولار "),
|
("$", " دولار "),
|
||||||
("£", " جنيه "),
|
("£", " جنيه "),
|
||||||
("°", " درجة ")
|
("°", " درجة "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"zh-cn": [
|
"zh-cn": [
|
||||||
|
@ -270,7 +272,7 @@ _symbols_multilingual = {
|
||||||
("#", " 号 "),
|
("#", " 号 "),
|
||||||
("$", " 美元 "),
|
("$", " 美元 "),
|
||||||
("£", " 英镑 "),
|
("£", " 英镑 "),
|
||||||
("°", " 度 ")
|
("°", " 度 "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"cs": [
|
"cs": [
|
||||||
|
@ -283,7 +285,7 @@ _symbols_multilingual = {
|
||||||
("#", " křížek "),
|
("#", " křížek "),
|
||||||
("$", " dolar "),
|
("$", " dolar "),
|
||||||
("£", " libra "),
|
("£", " libra "),
|
||||||
("°", " stupně ")
|
("°", " stupně "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"ru": [
|
"ru": [
|
||||||
|
@ -296,7 +298,7 @@ _symbols_multilingual = {
|
||||||
("#", " номер "),
|
("#", " номер "),
|
||||||
("$", " доллар "),
|
("$", " доллар "),
|
||||||
("£", " фунт "),
|
("£", " фунт "),
|
||||||
("°", " градус ")
|
("°", " градус "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"nl": [
|
"nl": [
|
||||||
|
@ -309,7 +311,7 @@ _symbols_multilingual = {
|
||||||
("#", " hekje "),
|
("#", " hekje "),
|
||||||
("$", " dollar "),
|
("$", " dollar "),
|
||||||
("£", " pond "),
|
("£", " pond "),
|
||||||
("°", " graden ")
|
("°", " graden "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
"tr": [
|
"tr": [
|
||||||
|
@ -321,15 +323,16 @@ _symbols_multilingual = {
|
||||||
("#", " diyez "),
|
("#", " diyez "),
|
||||||
("$", " dolar "),
|
("$", " dolar "),
|
||||||
("£", " sterlin "),
|
("£", " sterlin "),
|
||||||
("°", " derece ")
|
("°", " derece "),
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
def expand_symbols_multilingual(text, lang='en'):
|
|
||||||
|
def expand_symbols_multilingual(text, lang="en"):
|
||||||
for regex, replacement in _symbols_multilingual[lang]:
|
for regex, replacement in _symbols_multilingual[lang]:
|
||||||
text = re.sub(regex, replacement, text)
|
text = re.sub(regex, replacement, text)
|
||||||
text = text.replace(' ', ' ') # Ensure there are no double spaces
|
text = text.replace(" ", " ") # Ensure there are no double spaces
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
@ -349,34 +352,38 @@ _ordinal_re = {
|
||||||
}
|
}
|
||||||
_number_re = re.compile(r"[0-9]+")
|
_number_re = re.compile(r"[0-9]+")
|
||||||
_currency_re = {
|
_currency_re = {
|
||||||
'USD': re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
|
"USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
|
||||||
'GBP': re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
|
"GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
|
||||||
'EUR': re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))")
|
"EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
|
||||||
}
|
}
|
||||||
|
|
||||||
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
|
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
|
||||||
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
|
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
|
||||||
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
|
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
|
||||||
|
|
||||||
|
|
||||||
def _remove_commas(m):
|
def _remove_commas(m):
|
||||||
text = m.group(0)
|
text = m.group(0)
|
||||||
if "," in text:
|
if "," in text:
|
||||||
text = text.replace(",", "")
|
text = text.replace(",", "")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def _remove_dots(m):
|
def _remove_dots(m):
|
||||||
text = m.group(0)
|
text = m.group(0)
|
||||||
if "." in text:
|
if "." in text:
|
||||||
text = text.replace(".", "")
|
text = text.replace(".", "")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def _expand_decimal_point(m, lang='en'):
|
|
||||||
|
def _expand_decimal_point(m, lang="en"):
|
||||||
amount = m.group(1).replace(",", ".")
|
amount = m.group(1).replace(",", ".")
|
||||||
return num2words(float(amount), lang=lang if lang != "cs" else "cz")
|
return num2words(float(amount), lang=lang if lang != "cs" else "cz")
|
||||||
|
|
||||||
def _expand_currency(m, lang='en', currency='USD'):
|
|
||||||
amount = float((re.sub(r'[^\d.]', '', m.group(0).replace(",", "."))))
|
def _expand_currency(m, lang="en", currency="USD"):
|
||||||
full_amount = num2words(amount, to='currency', currency=currency, lang=lang if lang != "cs" else "cz")
|
amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
|
||||||
|
full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz")
|
||||||
|
|
||||||
and_equivalents = {
|
and_equivalents = {
|
||||||
"en": ", ",
|
"en": ", ",
|
||||||
|
@ -400,13 +407,16 @@ def _expand_currency(m, lang='en', currency='USD'):
|
||||||
|
|
||||||
return full_amount
|
return full_amount
|
||||||
|
|
||||||
def _expand_ordinal(m, lang='en'):
|
|
||||||
|
def _expand_ordinal(m, lang="en"):
|
||||||
return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
|
return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
|
||||||
|
|
||||||
def _expand_number(m, lang='en'):
|
|
||||||
|
def _expand_number(m, lang="en"):
|
||||||
return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
|
return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
|
||||||
|
|
||||||
def expand_numbers_multilingual(text, lang='en'):
|
|
||||||
|
def expand_numbers_multilingual(text, lang="en"):
|
||||||
if lang == "zh-cn":
|
if lang == "zh-cn":
|
||||||
text = zh_num2words()(text)
|
text = zh_num2words()(text)
|
||||||
else:
|
else:
|
||||||
|
@ -415,9 +425,9 @@ def expand_numbers_multilingual(text, lang='en'):
|
||||||
else:
|
else:
|
||||||
text = re.sub(_dot_number_re, _remove_dots, text)
|
text = re.sub(_dot_number_re, _remove_dots, text)
|
||||||
try:
|
try:
|
||||||
text = re.sub(_currency_re['GBP'], lambda m: _expand_currency(m, lang, 'GBP'), text)
|
text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
|
||||||
text = re.sub(_currency_re['USD'], lambda m: _expand_currency(m, lang, 'USD'), text)
|
text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
|
||||||
text = re.sub(_currency_re['EUR'], lambda m: _expand_currency(m, lang, 'EUR'), text)
|
text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
if lang != "tr":
|
if lang != "tr":
|
||||||
|
@ -426,14 +436,17 @@ def expand_numbers_multilingual(text, lang='en'):
|
||||||
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
|
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def lowercase(text):
|
def lowercase(text):
|
||||||
return text.lower()
|
return text.lower()
|
||||||
|
|
||||||
|
|
||||||
def collapse_whitespace(text):
|
def collapse_whitespace(text):
|
||||||
return re.sub(_whitespace_re, " ", text)
|
return re.sub(_whitespace_re, " ", text)
|
||||||
|
|
||||||
|
|
||||||
def multilingual_cleaners(text, lang):
|
def multilingual_cleaners(text, lang):
|
||||||
text = text.replace('"', '')
|
text = text.replace('"', "")
|
||||||
if lang == "tr":
|
if lang == "tr":
|
||||||
text = text.replace("İ", "i")
|
text = text.replace("İ", "i")
|
||||||
text = text.replace("Ö", "ö")
|
text = text.replace("Ö", "ö")
|
||||||
|
@ -445,20 +458,26 @@ def multilingual_cleaners(text, lang):
|
||||||
text = collapse_whitespace(text)
|
text = collapse_whitespace(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def basic_cleaners(text):
|
def basic_cleaners(text):
|
||||||
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
||||||
text = lowercase(text)
|
text = lowercase(text)
|
||||||
text = collapse_whitespace(text)
|
text = collapse_whitespace(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def chinese_transliterate(text):
|
def chinese_transliterate(text):
|
||||||
return "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)])
|
return "".join(
|
||||||
|
p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def japanese_cleaners(text, katsu):
|
def japanese_cleaners(text, katsu):
|
||||||
text = katsu.romaji(text)
|
text = katsu.romaji(text)
|
||||||
text = lowercase(text)
|
text = lowercase(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
class VoiceBpeTokenizer:
|
class VoiceBpeTokenizer:
|
||||||
def __init__(self, vocab_file=None, preprocess=None):
|
def __init__(self, vocab_file=None, preprocess=None):
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
@ -485,6 +504,7 @@ class VoiceBpeTokenizer:
|
||||||
elif lang == "ja":
|
elif lang == "ja":
|
||||||
if self.katsu is None:
|
if self.katsu is None:
|
||||||
import cutlet
|
import cutlet
|
||||||
|
|
||||||
self.katsu = cutlet.Cutlet()
|
self.katsu = cutlet.Cutlet()
|
||||||
txt = japanese_cleaners(txt, self.katsu)
|
txt = japanese_cleaners(txt, self.katsu)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -2,9 +2,14 @@
|
||||||
# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
|
# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
|
||||||
# 2019.9 - 2022 Jiayu DU
|
# 2019.9 - 2022 Jiayu DU
|
||||||
|
|
||||||
import sys, os, argparse
|
import argparse
|
||||||
import string, re
|
|
||||||
import csv
|
import csv
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import string
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
|
||||||
# ================================================================================ #
|
# ================================================================================ #
|
||||||
# basic constant
|
# basic constant
|
||||||
|
|
|
@ -2,10 +2,10 @@ import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import librosa
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import librosa
|
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
|
|
||||||
from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
|
from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
|
||||||
|
@ -386,9 +386,11 @@ class Xtts(BaseTTS):
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def get_speaker_embedding(self, audio, sr):
|
def get_speaker_embedding(self, audio, sr):
|
||||||
audio_16k = torchaudio.functional.resample(audio, sr, 16000)
|
audio_16k = torchaudio.functional.resample(audio, sr, 16000)
|
||||||
return self.hifigan_decoder.speaker_encoder.forward(
|
return (
|
||||||
audio_16k.to(self.device), l2_norm=True
|
self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True)
|
||||||
).unsqueeze(-1).to(self.device)
|
.unsqueeze(-1)
|
||||||
|
.to(self.device)
|
||||||
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def get_conditioning_latents(
|
def get_conditioning_latents(
|
||||||
|
@ -647,13 +649,19 @@ class Xtts(BaseTTS):
|
||||||
break
|
break
|
||||||
|
|
||||||
if decoder == "hifigan":
|
if decoder == "hifigan":
|
||||||
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
assert hasattr(
|
||||||
|
self, "hifigan_decoder"
|
||||||
|
), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
||||||
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
|
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
|
||||||
elif decoder == "ne_hifigan":
|
elif decoder == "ne_hifigan":
|
||||||
assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
assert hasattr(
|
||||||
|
self, "ne_hifigan_decoder"
|
||||||
|
), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
||||||
wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
|
wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
|
||||||
else:
|
else:
|
||||||
assert hasattr(self, "diffusion_decoder"), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
|
assert hasattr(
|
||||||
|
self, "diffusion_decoder"
|
||||||
|
), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
|
||||||
mel = do_spectrogram_diffusion(
|
mel = do_spectrogram_diffusion(
|
||||||
self.diffusion_decoder,
|
self.diffusion_decoder,
|
||||||
diffuser,
|
diffuser,
|
||||||
|
@ -742,10 +750,14 @@ class Xtts(BaseTTS):
|
||||||
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
|
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
|
||||||
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
|
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
|
||||||
if decoder == "hifigan":
|
if decoder == "hifigan":
|
||||||
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
assert hasattr(
|
||||||
|
self, "hifigan_decoder"
|
||||||
|
), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
||||||
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||||
elif decoder == "ne_hifigan":
|
elif decoder == "ne_hifigan":
|
||||||
assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
assert hasattr(
|
||||||
|
self, "ne_hifigan_decoder"
|
||||||
|
), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
||||||
wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Diffusion for streaming inference not implemented.")
|
raise NotImplementedError("Diffusion for streaming inference not implemented.")
|
||||||
|
@ -756,10 +768,14 @@ class Xtts(BaseTTS):
|
||||||
yield wav_chunk
|
yield wav_chunk
|
||||||
|
|
||||||
def forward(self):
|
def forward(self):
|
||||||
raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
|
raise NotImplementedError(
|
||||||
|
"XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
|
||||||
|
)
|
||||||
|
|
||||||
def eval_step(self):
|
def eval_step(self):
|
||||||
raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
|
raise NotImplementedError(
|
||||||
|
"XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_from_config(config: "XttsConfig", **kwargs): # pylint: disable=unused-argument
|
def init_from_config(config: "XttsConfig", **kwargs): # pylint: disable=unused-argument
|
||||||
|
@ -835,12 +851,18 @@ class Xtts(BaseTTS):
|
||||||
self.load_state_dict(checkpoint, strict=strict)
|
self.load_state_dict(checkpoint, strict=strict)
|
||||||
|
|
||||||
if eval:
|
if eval:
|
||||||
if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
|
if hasattr(self, "hifigan_decoder"):
|
||||||
if hasattr(self, "ne_hifigan_decoder"): self.hifigan_decoder.eval()
|
self.hifigan_decoder.eval()
|
||||||
if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
|
if hasattr(self, "ne_hifigan_decoder"):
|
||||||
if hasattr(self, "vocoder"): self.vocoder.eval()
|
self.hifigan_decoder.eval()
|
||||||
|
if hasattr(self, "diffusion_decoder"):
|
||||||
|
self.diffusion_decoder.eval()
|
||||||
|
if hasattr(self, "vocoder"):
|
||||||
|
self.vocoder.eval()
|
||||||
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
|
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
|
||||||
self.gpt.eval()
|
self.gpt.eval()
|
||||||
|
|
||||||
def train_step(self):
|
def train_step(self):
|
||||||
raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
|
raise NotImplementedError(
|
||||||
|
"XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
|
||||||
|
)
|
||||||
|
|
|
@ -7,7 +7,6 @@ from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||||
from TTS.utils.manage import ModelManager
|
from TTS.utils.manage import ModelManager
|
||||||
|
|
||||||
|
|
||||||
# Logging parameters
|
# Logging parameters
|
||||||
RUN_NAME = "GPT_XTTS_LJSpeech_FT"
|
RUN_NAME = "GPT_XTTS_LJSpeech_FT"
|
||||||
PROJECT_NAME = "XTTS_trainer"
|
PROJECT_NAME = "XTTS_trainer"
|
||||||
|
@ -66,7 +65,9 @@ XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split(
|
||||||
# download XTTS v1.1 files if needed
|
# download XTTS v1.1 files if needed
|
||||||
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
|
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
|
||||||
print(" > Downloading XTTS v1.1 files!")
|
print(" > Downloading XTTS v1.1 files!")
|
||||||
ModelManager._download_model_files([TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
|
ModelManager._download_model_files(
|
||||||
|
[TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Training sentences generations
|
# Training sentences generations
|
||||||
|
|
|
@ -22,7 +22,4 @@ def test_synthesize():
|
||||||
)
|
)
|
||||||
|
|
||||||
# test pipe_out command
|
# test pipe_out command
|
||||||
run_cli(
|
run_cli(f'tts --text "test." --pipe_out --out_path "{output_path}" | aplay')
|
||||||
'tts --text "test." --pipe_out '
|
|
||||||
f'--out_path "{output_path}" | aplay'
|
|
||||||
)
|
|
||||||
|
|
Loading…
Reference in New Issue