Run `make style` & re-enable it in CI

Aarni Koskela 2023-10-31 17:07:24 +02:00
parent 6fef4f9067
commit 4cfc3e5779
11 changed files with 141 additions and 95 deletions

View File

@@ -42,6 +42,5 @@ jobs:
       run: |
         python3 -m pip install .[all]
         python3 setup.py egg_info
-    # - name: Lint check
-    #   run: |
-    #     make lint
+    - name: Style check
+      run: make style

View File

@@ -264,7 +264,7 @@ class TTS(nn.Module):
         language: str = None,
         emotion: str = None,
         speed: float = 1.0,
-        pipe_out = None,
+        pipe_out=None,
         file_path: str = None,
     ) -> Union[np.ndarray, str]:
         """Convert text to speech using Coqui Studio models. Use `CS_API` class if you are only interested in the API.
@@ -359,7 +359,7 @@ class TTS(nn.Module):
         speaker_wav: str = None,
         emotion: str = None,
         speed: float = 1.0,
-        pipe_out = None,
+        pipe_out=None,
         file_path: str = "output.wav",
         **kwargs,
     ):
@@ -460,7 +460,7 @@ class TTS(nn.Module):
         """
         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
             # Lazy code... save it to a temp file to resample it while reading it for VC
-            self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name,speaker_wav=speaker_wav)
+            self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name, speaker_wav=speaker_wav)
         if self.voice_converter is None:
             self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
         wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav)

View File

@@ -427,7 +427,9 @@ def main():
         tts_path = model_path
         tts_config_path = config_path
         if "default_vocoder" in model_item:
-            args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
+            args.vocoder_name = (
+                model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
+            )
 
     # voice conversion model
     if model_item["model_type"] == "voice_conversion_models":

View File

@@ -1,12 +1,12 @@
+import json
 import os
 import re
-import json
-import torch
-from tokenizers import Tokenizer
+
 import pypinyin
+import torch
 from num2words import num2words
+from tokenizers import Tokenizer
+
 from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
 
 _whitespace_re = re.compile(r"\s+")
@@ -87,7 +87,7 @@ _abbreviations = {
     "it": [
         (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         for x in [
-            #("sig.ra", "signora"),
+            # ("sig.ra", "signora"),
            ("sig", "signore"),
            ("dr", "dottore"),
            ("st", "santo"),
@@ -157,13 +157,15 @@ _abbreviations = {
     ],
 }
 
-def expand_abbreviations_multilingual(text, lang='en'):
+
+def expand_abbreviations_multilingual(text, lang="en"):
     for regex, replacement in _abbreviations[lang]:
         text = re.sub(regex, replacement, text)
     return text
 
+
 _symbols_multilingual = {
-    'en': [
+    "en": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " and "),
@@ -172,10 +174,10 @@ _symbols_multilingual = {
             ("#", " hash "),
             ("$", " dollar "),
             ("£", " pound "),
-            ("°", " degree ")
+            ("°", " degree "),
         ]
     ],
-    'es': [
+    "es": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " y "),
@@ -184,10 +186,10 @@ _symbols_multilingual = {
             ("#", " numeral "),
             ("$", " dolar "),
             ("£", " libra "),
-            ("°", " grados ")
+            ("°", " grados "),
         ]
     ],
-    'fr': [
+    "fr": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " et "),
@@ -196,10 +198,10 @@ _symbols_multilingual = {
             ("#", " dièse "),
             ("$", " dollar "),
             ("£", " livre "),
-            ("°", " degrés ")
+            ("°", " degrés "),
         ]
     ],
-    'de': [
+    "de": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " und "),
@@ -208,10 +210,10 @@ _symbols_multilingual = {
             ("#", " raute "),
             ("$", " dollar "),
             ("£", " pfund "),
-            ("°", " grad ")
+            ("°", " grad "),
         ]
     ],
-    'pt': [
+    "pt": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " e "),
@@ -220,10 +222,10 @@ _symbols_multilingual = {
             ("#", " cardinal "),
             ("$", " dólar "),
             ("£", " libra "),
-            ("°", " graus ")
+            ("°", " graus "),
         ]
     ],
-    'it': [
+    "it": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " e "),
@@ -232,10 +234,10 @@ _symbols_multilingual = {
             ("#", " cancelletto "),
             ("$", " dollaro "),
             ("£", " sterlina "),
-            ("°", " gradi ")
+            ("°", " gradi "),
         ]
     ],
-    'pl': [
+    "pl": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " i "),
@@ -244,7 +246,7 @@ _symbols_multilingual = {
             ("#", " krzyżyk "),
             ("$", " dolar "),
             ("£", " funt "),
-            ("°", " stopnie ")
+            ("°", " stopnie "),
         ]
     ],
     "ar": [
@@ -257,7 +259,7 @@ _symbols_multilingual = {
             ("#", " رقم "),
             ("$", " دولار "),
             ("£", " جنيه "),
-            ("°", " درجة ")
+            ("°", " درجة "),
         ]
     ],
     "zh-cn": [
@@ -270,7 +272,7 @@ _symbols_multilingual = {
             ("#", ""),
             ("$", " 美元 "),
             ("£", " 英镑 "),
-            ("°", "")
+            ("°", ""),
         ]
     ],
     "cs": [
@@ -283,7 +285,7 @@ _symbols_multilingual = {
             ("#", " křížek "),
             ("$", " dolar "),
             ("£", " libra "),
-            ("°", " stupně ")
+            ("°", " stupně "),
         ]
     ],
     "ru": [
@@ -296,7 +298,7 @@ _symbols_multilingual = {
             ("#", " номер "),
             ("$", " доллар "),
             ("£", " фунт "),
-            ("°", " градус ")
+            ("°", " градус "),
         ]
     ],
     "nl": [
@@ -309,7 +311,7 @@ _symbols_multilingual = {
             ("#", " hekje "),
             ("$", " dollar "),
             ("£", " pond "),
-            ("°", " graden ")
+            ("°", " graden "),
         ]
     ],
     "tr": [
@@ -321,15 +323,16 @@ _symbols_multilingual = {
             ("#", " diyez "),
             ("$", " dolar "),
             ("£", " sterlin "),
-            ("°", " derece ")
+            ("°", " derece "),
         ]
     ],
 }
 
-def expand_symbols_multilingual(text, lang='en'):
+
+def expand_symbols_multilingual(text, lang="en"):
     for regex, replacement in _symbols_multilingual[lang]:
         text = re.sub(regex, replacement, text)
-    text = text.replace('  ', ' ')  # Ensure there are no double spaces
+    text = text.replace("  ", " ")  # Ensure there are no double spaces
     return text.strip()
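As a quick illustration of what the tables above do, a sketch using only the English entries visible in this diff (the `replace("  ", " ")` pass then collapses the doubled spaces the replacements introduce):

    >>> expand_symbols_multilingual("Tom & Jerry, 100°", lang="en")
    'Tom and Jerry, 100 degree'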
@@ -349,34 +352,38 @@ _ordinal_re = {
 }
 _number_re = re.compile(r"[0-9]+")
 _currency_re = {
-    'USD': re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
-    'GBP': re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
-    'EUR': re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))")
+    "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
+    "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
+    "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
 }
 _comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
 _dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
 _decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
 
+
 def _remove_commas(m):
     text = m.group(0)
     if "," in text:
         text = text.replace(",", "")
     return text
 
+
 def _remove_dots(m):
     text = m.group(0)
     if "." in text:
         text = text.replace(".", "")
     return text
 
-def _expand_decimal_point(m, lang='en'):
+
+def _expand_decimal_point(m, lang="en"):
     amount = m.group(1).replace(",", ".")
     return num2words(float(amount), lang=lang if lang != "cs" else "cz")
 
-def _expand_currency(m, lang='en', currency='USD'):
-    amount = float((re.sub(r'[^\d.]', '', m.group(0).replace(",", "."))))
-    full_amount = num2words(amount, to='currency', currency=currency, lang=lang if lang != "cs" else "cz")
+
+def _expand_currency(m, lang="en", currency="USD"):
+    amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
+    full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz")
 
     and_equivalents = {
         "en": ", ",
@@ -400,13 +407,16 @@ def _expand_currency(m, lang='en', currency='USD'):
     return full_amount
 
-def _expand_ordinal(m, lang='en'):
+
+def _expand_ordinal(m, lang="en"):
     return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
 
-def _expand_number(m, lang='en'):
+
+def _expand_number(m, lang="en"):
     return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
 
-def expand_numbers_multilingual(text, lang='en'):
+
+def expand_numbers_multilingual(text, lang="en"):
     if lang == "zh-cn":
         text = zh_num2words()(text)
     else:
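All of these helpers delegate to num2words; a small sketch of the calls they make (outputs shown as expected for recent num2words versions, so exact wording may vary; note the `cs` → `cz` language-code remap above):

    >>> from num2words import num2words
    >>> num2words(42, lang="en")
    'forty-two'
    >>> num2words(3, ordinal=True, lang="en")
    'third'
    >>> num2words(14.5, to="currency", currency="GBP", lang="en")
    'fourteen pounds, fifty pence'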
@@ -415,9 +425,9 @@ def expand_numbers_multilingual(text, lang='en'):
         else:
             text = re.sub(_dot_number_re, _remove_dots, text)
         try:
-            text = re.sub(_currency_re['GBP'], lambda m: _expand_currency(m, lang, 'GBP'), text)
-            text = re.sub(_currency_re['USD'], lambda m: _expand_currency(m, lang, 'USD'), text)
-            text = re.sub(_currency_re['EUR'], lambda m: _expand_currency(m, lang, 'EUR'), text)
+            text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
+            text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
+            text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
         except:
             pass
         if lang != "tr":
@@ -426,15 +436,18 @@ def expand_numbers_multilingual(text, lang='en'):
             text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
     return text
 
+
 def lowercase(text):
     return text.lower()
 
+
 def collapse_whitespace(text):
     return re.sub(_whitespace_re, " ", text)
 
+
 def multilingual_cleaners(text, lang):
-    text = text.replace('"', '')
-    if lang=="tr":
+    text = text.replace('"', "")
+    if lang == "tr":
         text = text.replace("İ", "i")
         text = text.replace("Ö", "ö")
         text = text.replace("Ü", "ü")
@@ -445,20 +458,26 @@ def multilingual_cleaners(text, lang):
     text = collapse_whitespace(text)
     return text
 
+
 def basic_cleaners(text):
     """Basic pipeline that lowercases and collapses whitespace without transliteration."""
     text = lowercase(text)
     text = collapse_whitespace(text)
     return text
 
+
 def chinese_transliterate(text):
-    return "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)])
+    return "".join(
+        p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)
+    )
 
+
 def japanese_cleaners(text, katsu):
     text = katsu.romaji(text)
     text = lowercase(text)
     return text
 
+
 class VoiceBpeTokenizer:
     def __init__(self, vocab_file=None, preprocess=None):
         self.tokenizer = None
@@ -485,6 +504,7 @@ class VoiceBpeTokenizer:
         elif lang == "ja":
             if self.katsu is None:
                 import cutlet
+
                 self.katsu = cutlet.Cutlet()
             txt = japanese_cleaners(txt, self.katsu)
         else:

View File

@@ -2,9 +2,14 @@
 # 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
 # 2019.9 - 2022 Jiayu DU
 
-import sys, os, argparse
-import string, re
+import argparse
 import csv
+import os
+import re
+import string
+import sys
 
+# fmt: off
 # ================================================================================ #
 #                                   basic constant

View File

@@ -2,10 +2,10 @@ import os
 from contextlib import contextmanager
 from dataclasses import dataclass
 
+import librosa
 import torch
 import torch.nn.functional as F
 import torchaudio
-import librosa
 from coqpit import Coqpit
 
 from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
@@ -386,9 +386,11 @@ class Xtts(BaseTTS):
     @torch.inference_mode()
     def get_speaker_embedding(self, audio, sr):
         audio_16k = torchaudio.functional.resample(audio, sr, 16000)
-        return self.hifigan_decoder.speaker_encoder.forward(
-            audio_16k.to(self.device), l2_norm=True
-        ).unsqueeze(-1).to(self.device)
+        return (
+            self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True)
+            .unsqueeze(-1)
+            .to(self.device)
+        )
 
     @torch.inference_mode()
     def get_conditioning_latents(
@@ -647,13 +649,19 @@ class Xtts(BaseTTS):
                 break
 
         if decoder == "hifigan":
-            assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
+            assert hasattr(
+                self, "hifigan_decoder"
+            ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
             wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
         elif decoder == "ne_hifigan":
-            assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
+            assert hasattr(
+                self, "ne_hifigan_decoder"
+            ), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
             wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
         else:
-            assert hasattr(self, "diffusion_decoder"), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
+            assert hasattr(
+                self, "diffusion_decoder"
+            ), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
             mel = do_spectrogram_diffusion(
                 self.diffusion_decoder,
                 diffuser,
@@ -742,10 +750,14 @@ class Xtts(BaseTTS):
                 if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                     gpt_latents = torch.cat(all_latents, dim=0)[None, :]
                     if decoder == "hifigan":
-                        assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
+                        assert hasattr(
+                            self, "hifigan_decoder"
+                        ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
                         wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                     elif decoder == "ne_hifigan":
-                        assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
+                        assert hasattr(
+                            self, "ne_hifigan_decoder"
+                        ), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
                         wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                     else:
                         raise NotImplementedError("Diffusion for streaming inference not implemented.")
@@ -756,10 +768,14 @@ class Xtts(BaseTTS):
                     yield wav_chunk
 
     def forward(self):
-        raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
+        raise NotImplementedError(
+            "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
+        )
 
     def eval_step(self):
-        raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
+        raise NotImplementedError(
+            "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
+        )
 
     @staticmethod
     def init_from_config(config: "XttsConfig", **kwargs):  # pylint: disable=unused-argument
@@ -835,12 +851,18 @@ class Xtts(BaseTTS):
         self.load_state_dict(checkpoint, strict=strict)
 
         if eval:
-            if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
-            if hasattr(self, "ne_hifigan_decoder"): self.hifigan_decoder.eval()
-            if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
-            if hasattr(self, "vocoder"): self.vocoder.eval()
+            if hasattr(self, "hifigan_decoder"):
+                self.hifigan_decoder.eval()
+            if hasattr(self, "ne_hifigan_decoder"):
+                self.hifigan_decoder.eval()
+            if hasattr(self, "diffusion_decoder"):
+                self.diffusion_decoder.eval()
+            if hasattr(self, "vocoder"):
+                self.vocoder.eval()
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
             self.gpt.eval()
 
     def train_step(self):
-        raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
+        raise NotImplementedError(
+            "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
+        )

View File

@@ -428,7 +428,7 @@ def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False,
     return x
 
-def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out = None, **kwargs) -> None:
+def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out=None, **kwargs) -> None:
     """Save float waveform to a file using Scipy.
 
     Args:

View File

@@ -694,7 +694,7 @@ class AudioProcessor(object):
             x = self.rms_volume_norm(x, self.db_level)
         return x
 
-    def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out = None) -> None:
+    def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None:
        """Save a waveform to a file using Scipy.
 
         Args:

View File

@@ -235,7 +235,7 @@ class Synthesizer(nn.Module):
         """
         return self.seg.segment(text)
 
-    def save_wav(self, wav: List[int], path: str, pipe_out = None) -> None:
+    def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None:
         """Save the waveform as a file.
 
         Args:

View File

@@ -7,7 +7,6 @@ from TTS.tts.datasets import load_tts_samples
 from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
 from TTS.utils.manage import ModelManager
 
-
 # Logging parameters
 RUN_NAME = "GPT_XTTS_LJSpeech_FT"
 PROJECT_NAME = "XTTS_trainer"
@@ -66,7 +65,9 @@ XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split(
 # download XTTS v1.1 files if needed
 if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
     print(" > Downloading XTTS v1.1 files!")
-    ModelManager._download_model_files([TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
+    ModelManager._download_model_files(
+        [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+    )
 
 # Training sentences generations

View File

@@ -22,7 +22,4 @@ def test_synthesize():
     )
 
     # test pipe_out command
-    run_cli(
-        'tts --text "test." --pipe_out '
-        f'--out_path "{output_path}" | aplay'
-    )
+    run_cli(f'tts --text "test." --pipe_out --out_path "{output_path}" | aplay')