Merge pull request #3251 from coqui-ai/dev

v0.20.6
Eren Gölge 2023-11-21 13:22:47 +01:00 committed by GitHub
commit 2211ba267a
19 changed files with 416 additions and 503 deletions

View File

@ -10,7 +10,7 @@ jobs:
build-sdist: build-sdist:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Verify tag matches version - name: Verify tag matches version
run: | run: |
set -ex set -ex
@ -38,7 +38,7 @@ jobs:
matrix: matrix:
python-version: ["3.9", "3.10", "3.11"] python-version: ["3.9", "3.10", "3.11"]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- uses: actions/setup-python@v2 - uses: actions/setup-python@v2
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@ -3,14 +3,14 @@
"multilingual": { "multilingual": {
"multi-dataset": { "multi-dataset": {
"xtts_v2": { "xtts_v2": {
"description": "XTTS-v2 by Coqui with 16 languages.", "description": "XTTS-v2.0.2 by Coqui with 16 languages.",
"hf_url": [ "hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth", "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json", "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json", "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5" "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5"
], ],
"model_hash": "6a09d1ad43896f06041ed8195956c9698f13b6189dc80f1c74bdc2b8e8d15324", "model_hash": "5ce0502bfe3bc88dc8d9312b12a7558c",
"default_vocoder": null, "default_vocoder": null,
"commit": "480a6cdf7", "commit": "480a6cdf7",
"license": "CPML", "license": "CPML",

View File

@ -1 +1 @@
0.20.5 0.20.6

View File

@ -15,6 +15,7 @@ from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import quantize
from TTS.utils.generic_utils import count_parameters from TTS.utils.generic_utils import count_parameters
use_cuda = torch.cuda.is_available() use_cuda = torch.cuda.is_available()
@ -159,7 +160,7 @@ def inference(
def extract_spectrograms( def extract_spectrograms(
data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt" data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt"
): ):
model.eval() model.eval()
export_metadata = [] export_metadata = []
@ -196,8 +197,8 @@ def extract_spectrograms(
_, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path) _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)
# quantize and save wav # quantize and save wav
if quantized_wav: if quantize_bits > 0:
wavq = ap.quantize(wav) wavq = quantize(wav, quantize_bits)
np.save(wavq_path, wavq) np.save(wavq_path, wavq)
# save TTS mel # save TTS mel
@ -263,7 +264,7 @@ def main(args): # pylint: disable=redefined-outer-name
model, model,
ap, ap,
args.output_path, args.output_path,
quantized_wav=args.quantized, quantize_bits=args.quantize_bits,
save_audio=args.save_audio, save_audio=args.save_audio,
debug=args.debug, debug=args.debug,
metada_name="metada.txt", metada_name="metada.txt",
@ -277,7 +278,7 @@ if __name__ == "__main__":
parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True) parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug") parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files") parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
parser.add_argument("--quantized", action="store_true", help="Save quantized audio files") parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero")
parser.add_argument("--eval", type=bool, help="compute eval.", default=True) parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
args = parser.parse_args() args = parser.parse_args()
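
The hunk above replaces the boolean --quantized flag and ap.quantize() with an explicit --quantize_bits count and the standalone quantize() from TTS.utils.audio.numpy_transforms. A minimal sketch of the new call, with an arbitrary bit depth and a synthetic waveform standing in for a loaded file:

    import numpy as np
    from TTS.utils.audio.numpy_transforms import quantize

    # Stand-in for a waveform loaded by AudioProcessor, normalized to [-1, 1].
    wav = np.sin(np.linspace(0, 2 * np.pi * 440, 22050)).astype(np.float32)
    wavq = quantize(x=wav, quantize_bits=10)  # mapped from [-1, 1] to [0, 2**10 - 1]
    np.save("example_quant.npy", wavq)        # placeholder path, mirroring np.save(wavq_path, wavq)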

View File

@ -13,12 +13,18 @@ import math
import numpy as np import numpy as np
import torch import torch
import torch as th import torch as th
from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
from tqdm import tqdm from tqdm import tqdm
from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m} try:
from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m}
except ImportError:
K_DIFFUSION_SAMPLERS = None
SAMPLERS = ["dpm++2m", "p", "ddim"] SAMPLERS = ["dpm++2m", "p", "ddim"]
@ -531,6 +537,8 @@ class GaussianDiffusion:
if self.conditioning_free is not True: if self.conditioning_free is not True:
raise RuntimeError("cond_free must be true") raise RuntimeError("cond_free must be true")
with tqdm(total=self.num_timesteps) as pbar: with tqdm(total=self.num_timesteps) as pbar:
if K_DIFFUSION_SAMPLERS is None:
raise ModuleNotFoundError("Install k_diffusion for using k_diffusion samplers")
return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs) return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs)
else: else:
raise RuntimeError("sampler not impl") raise RuntimeError("sampler not impl")

View File

@ -441,7 +441,9 @@ class GPT(nn.Module):
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token) audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)
# Pad mel codes with stop_audio_token # Pad mel codes with stop_audio_token
audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3) # -3 to get the real code lengths without consider start and stop tokens that was not added yet audio_codes = self.set_mel_padding(
audio_codes, code_lengths - 3
) # -3 to get the real code lengths without consider start and stop tokens that was not added yet
# Build input and target tensors # Build input and target tensors
# Prepend start token to inputs and append stop token to targets # Prepend start token to inputs and append stop token to targets

View File

@ -1,6 +1,6 @@
import json
import os import os
import re import re
import textwrap
from functools import cached_property from functools import cached_property
import pypinyin import pypinyin
@ -8,10 +8,66 @@ import torch
from hangul_romanize import Transliter from hangul_romanize import Transliter
from hangul_romanize.rule import academic from hangul_romanize.rule import academic
from num2words import num2words from num2words import num2words
from spacy.lang.ar import Arabic
from spacy.lang.en import English
from spacy.lang.es import Spanish
from spacy.lang.ja import Japanese
from spacy.lang.zh import Chinese
from tokenizers import Tokenizer from tokenizers import Tokenizer
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
def get_spacy_lang(lang):
if lang == "zh":
return Chinese()
elif lang == "ja":
return Japanese()
elif lang == "ar":
return Arabic()
elif lang == "es":
return Spanish()
else:
# For most languages, Enlish does the job
return English()
def split_sentence(text, lang, text_split_length=250):
"""Preprocess the input text"""
text_splits = []
if text_split_length is not None and len(text) >= text_split_length:
text_splits.append("")
nlp = get_spacy_lang(lang)
nlp.add_pipe("sentencizer")
doc = nlp(text)
for sentence in doc.sents:
if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
# if the last sentence + the current sentence is less than the text_split_length
# then add the current sentence to the last sentence
text_splits[-1] += " " + str(sentence)
text_splits[-1] = text_splits[-1].lstrip()
elif len(str(sentence)) > text_split_length:
# if the current sentence is greater than the text_split_length
for line in textwrap.wrap(
str(sentence),
width=text_split_length,
drop_whitespace=True,
break_on_hyphens=False,
tabsize=1,
):
text_splits.append(str(line))
else:
text_splits.append(str(sentence))
if len(text_splits) > 1:
if text_splits[0] == "":
del text_splits[0]
else:
text_splits = [text.lstrip()]
return text_splits
_whitespace_re = re.compile(r"\s+") _whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations: # List of (regular expression, replacement) pairs for abbreviations:
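
The new get_spacy_lang/split_sentence helpers above break long inputs into chunks of at most text_split_length characters, using spaCy's rule-based sentencizer (pulled in via the new spacy[ja] requirement) and falling back to textwrap for single sentences that exceed the limit. A minimal usage sketch with arbitrary text and limit:

    from TTS.tts.layers.xtts.tokenizer import split_sentence

    text = (
        "XTTS streams audio sentence by sentence. Long inputs are split before tokenization. "
        "Anything longer than the limit is wrapped with textwrap."
    )
    for chunk in split_sentence(text, "en", text_split_length=80):
        print(len(chunk), "->", chunk)
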
@ -115,7 +171,7 @@ _abbreviations = {
# There are not many common abbreviations in Arabic as in English. # There are not many common abbreviations in Arabic as in English.
] ]
], ],
"zh-cn": [ "zh": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [ for x in [
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts. # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
@ -280,7 +336,7 @@ _symbols_multilingual = {
("°", " درجة "), ("°", " درجة "),
] ]
], ],
"zh-cn": [ "zh": [
# Chinese # Chinese
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [ for x in [
@ -464,7 +520,7 @@ def _expand_number(m, lang="en"):
def expand_numbers_multilingual(text, lang="en"): def expand_numbers_multilingual(text, lang="en"):
if lang == "zh" or lang == "zh-cn": if lang == "zh":
text = zh_num2words()(text) text = zh_num2words()(text)
else: else:
if lang in ["en", "ru"]: if lang in ["en", "ru"]:
@ -525,7 +581,7 @@ def japanese_cleaners(text, katsu):
return text return text
def korean_cleaners(text): def korean_transliterate(text):
r = Transliter(academic) r = Transliter(academic)
return r.translit(text) return r.translit(text)
@ -546,7 +602,7 @@ class VoiceBpeTokenizer:
"it": 213, "it": 213,
"pt": 203, "pt": 203,
"pl": 224, "pl": 224,
"zh-cn": 82, "zh": 82,
"ar": 166, "ar": 166,
"cs": 186, "cs": 186,
"ru": 182, "ru": 182,
@ -564,6 +620,7 @@ class VoiceBpeTokenizer:
return cutlet.Cutlet() return cutlet.Cutlet()
def check_input_length(self, txt, lang): def check_input_length(self, txt, lang):
lang = lang.split("-")[0] # remove the region
limit = self.char_limits.get(lang, 250) limit = self.char_limits.get(lang, 250)
if len(txt) > limit: if len(txt) > limit:
print( print(
@ -571,21 +628,23 @@ class VoiceBpeTokenizer:
) )
def preprocess_text(self, txt, lang): def preprocess_text(self, txt, lang):
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}: if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
txt = multilingual_cleaners(txt, lang) txt = multilingual_cleaners(txt, lang)
if lang in {"zh", "zh-cn"}: if lang == "zh":
txt = chinese_transliterate(txt) txt = chinese_transliterate(txt)
if lang == "ko":
txt = korean_transliterate(txt)
elif lang == "ja": elif lang == "ja":
txt = japanese_cleaners(txt, self.katsu) txt = japanese_cleaners(txt, self.katsu)
elif lang == "ko":
txt = korean_cleaners(txt)
else: else:
raise NotImplementedError(f"Language '{lang}' is not supported.") raise NotImplementedError(f"Language '{lang}' is not supported.")
return txt return txt
def encode(self, txt, lang): def encode(self, txt, lang):
lang = lang.split("-")[0] # remove the region
self.check_input_length(txt, lang) self.check_input_length(txt, lang)
txt = self.preprocess_text(txt, lang) txt = self.preprocess_text(txt, lang)
lang = "zh-cn" if lang == "zh" else lang
txt = f"[{lang}]{txt}" txt = f"[{lang}]{txt}"
txt = txt.replace(" ", "[SPACE]") txt = txt.replace(" ", "[SPACE]")
return self.tokenizer.encode(txt).ids return self.tokenizer.encode(txt).ids
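
With the hunks above, region suffixes are stripped up front ("zh-cn" becomes "zh" for cleaning, character limits, and transliteration), Korean is handled by korean_transliterate inside the main cleaning branch, and encode() maps "zh" back to the "zh-cn" prefix expected by the trained vocabulary. A short sketch, assuming a local copy of the XTTS vocab.json (the path is a placeholder):

    from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer

    tok = VoiceBpeTokenizer(vocab_file="XTTS-v2/vocab.json")  # placeholder path to the downloaded vocab
    # "zh" and "zh-cn" are now equivalent: the region code is dropped before cleaning,
    # and the "[zh-cn]" prefix is restored when the token string is built.
    ids = tok.encode("今天是2023年11月21日。", lang="zh-cn")
    print(len(ids))
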
@ -682,8 +741,8 @@ def test_expand_numbers_multilingual():
("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"), ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"), ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
# Chinese (Simplified) # Chinese (Simplified)
("在12.5秒内", "在十二点五秒内", "zh-cn"), ("在12.5秒内", "在十二点五秒内", "zh"),
("有50名士兵", "有五十名士兵", "zh-cn"), ("有50名士兵", "有五十名士兵", "zh"),
# ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
# ("那将是20€先生", '那将是二十欧元先生', 'zh'), # ("那将是20€先生", '那将是二十欧元先生', 'zh'),
# Turkish # Turkish
@ -764,7 +823,7 @@ def test_symbols_multilingual():
("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"), ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"), ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"), ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
("我的电量为 14%", "我的电量为 14 百分之", "zh-cn"), ("我的电量为 14%", "我的电量为 14 百分之", "zh"),
("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"), ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"), ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"), ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),

View File

@ -318,9 +318,10 @@ class GPTTrainer(BaseTTS):
batch["cond_idxs"] = None batch["cond_idxs"] = None
return self.train_step(batch, criterion) return self.train_step(batch, criterion)
def on_epoch_start(self, trainer): # pylint: disable=W0613 def on_train_epoch_start(self, trainer):
# guarante that dvae will be in eval mode after .train() on evaluation end trainer.model.eval() # the whole model to eval
self.dvae = self.dvae.eval() # put gpt model in training mode
trainer.model.xtts.gpt.train()
def on_init_end(self, trainer): # pylint: disable=W0613 def on_init_end(self, trainer): # pylint: disable=W0613
# ignore similarities.pth on clearml save/upload # ignore similarities.pth on clearml save/upload

View File

@ -10,7 +10,7 @@ from coqpit import Coqpit
from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
@ -396,7 +396,7 @@ class Xtts(BaseTTS):
inference with config inference with config
""" """
assert ( assert (
language in self.config.languages "zh-cn" if language == "zh" else language in self.config.languages
), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}" ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}"
# Use generally found best tuning knobs for generation. # Use generally found best tuning knobs for generation.
settings = { settings = {
@ -420,9 +420,9 @@ class Xtts(BaseTTS):
ref_audio_path, ref_audio_path,
language, language,
# GPT inference # GPT inference
temperature=0.65, temperature=0.75,
length_penalty=1, length_penalty=1.0,
repetition_penalty=2.0, repetition_penalty=10.0,
top_k=50, top_k=50,
top_p=0.85, top_p=0.85,
do_sample=True, do_sample=True,
@ -502,71 +502,76 @@ class Xtts(BaseTTS):
gpt_cond_latent, gpt_cond_latent,
speaker_embedding, speaker_embedding,
# GPT inference # GPT inference
temperature=0.65, temperature=0.75,
length_penalty=1, length_penalty=1.0,
repetition_penalty=2.0, repetition_penalty=10.0,
top_k=50, top_k=50,
top_p=0.85, top_p=0.85,
do_sample=True, do_sample=True,
num_beams=1, num_beams=1,
speed=1.0, speed=1.0,
enable_text_splitting=False,
**hf_generate_kwargs, **hf_generate_kwargs,
): ):
language = language.split("-")[0] # remove the country code
length_scale = 1.0 / max(speed, 0.05) length_scale = 1.0 / max(speed, 0.05)
text = text.strip().lower() if enable_text_splitting:
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) text = split_sentence(text, language, self.tokenizer.char_limits[language])
else:
text = [text]
# print(" > Input text: ", text) wavs = []
# print(" > Input text preprocessed: ",self.tokenizer.preprocess_text(text, language)) gpt_latents_list = []
# print(" > Input tokens: ", text_tokens) for sent in text:
# print(" > Decoded text: ", self.tokenizer.decode(text_tokens[0].cpu().numpy())) sent = sent.strip().lower()
assert ( text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
with torch.no_grad(): assert (
gpt_codes = self.gpt.generate( text_tokens.shape[-1] < self.args.gpt_max_text_tokens
cond_latents=gpt_cond_latent, ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
text_inputs=text_tokens,
input_tokens=None,
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=self.gpt_batch_size,
num_beams=num_beams,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
output_attentions=False,
**hf_generate_kwargs,
)
expected_output_len = torch.tensor(
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
)
text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) with torch.no_grad():
gpt_latents = self.gpt( gpt_codes = self.gpt.generate(
text_tokens, cond_latents=gpt_cond_latent,
text_len, text_inputs=text_tokens,
gpt_codes, input_tokens=None,
expected_output_len, do_sample=do_sample,
cond_latents=gpt_cond_latent, top_p=top_p,
return_attentions=False, top_k=top_k,
return_latent=True, temperature=temperature,
) num_return_sequences=self.gpt_batch_size,
num_beams=num_beams,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
output_attentions=False,
**hf_generate_kwargs,
)
expected_output_len = torch.tensor(
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
)
if length_scale != 1.0: text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
gpt_latents = F.interpolate( gpt_latents = self.gpt(
gpt_latents.transpose(1, 2), text_tokens,
scale_factor=length_scale, text_len,
mode="linear" gpt_codes,
).transpose(1, 2) expected_output_len,
cond_latents=gpt_cond_latent,
return_attentions=False,
return_latent=True,
)
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
).transpose(1, 2)
gpt_latents_list.append(gpt_latents.cpu())
wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())
return { return {
"wav": wav.cpu().numpy().squeeze(), "wav": torch.cat(wavs, dim=0).numpy(),
"gpt_latents": gpt_latents, "gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(),
"speaker_embedding": speaker_embedding, "speaker_embedding": speaker_embedding,
} }
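
The refactor above changes the generation defaults (temperature 0.75, length_penalty 1.0, repetition_penalty 10.0) and wraps GPT inference in a per-sentence loop: with enable_text_splitting the input is chunked by split_sentence, each chunk is synthesized on its own, and the waveforms and GPT latents are concatenated before returning. A usage sketch following the documented direct-model workflow; the checkpoint directory and reference clip are placeholders:

    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

    config = XttsConfig()
    config.load_json("XTTS-v2/config.json")  # placeholder checkpoint directory
    model = Xtts.init_from_config(config)
    model.load_checkpoint(config, checkpoint_dir="XTTS-v2/", eval=True)
    model.cuda()  # optional; CPU works too, just slower

    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
    out = model.inference(
        "This is a long paragraph. It is split into sentences and synthesized chunk by chunk.",
        "en",
        gpt_cond_latent,
        speaker_embedding,
        enable_text_splitting=True,  # new in this release
    )
    # out["wav"] is the concatenated waveform; out["gpt_latents"] holds the concatenated latents.
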
@ -606,66 +611,76 @@ class Xtts(BaseTTS):
stream_chunk_size=20, stream_chunk_size=20,
overlap_wav_len=1024, overlap_wav_len=1024,
# GPT inference # GPT inference
temperature=0.65, temperature=0.75,
length_penalty=1, length_penalty=1.0,
repetition_penalty=2.0, repetition_penalty=10.0,
top_k=50, top_k=50,
top_p=0.85, top_p=0.85,
do_sample=True, do_sample=True,
speed=1.0, speed=1.0,
enable_text_splitting=False,
**hf_generate_kwargs, **hf_generate_kwargs,
): ):
language = language.split("-")[0] # remove the country code
length_scale = 1.0 / max(speed, 0.05) length_scale = 1.0 / max(speed, 0.05)
text = text.strip().lower() if enable_text_splitting:
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) text = split_sentence(text, language, self.tokenizer.char_limits[language])
else:
text = [text]
fake_inputs = self.gpt.compute_embeddings( for sent in text:
gpt_cond_latent.to(self.device), sent = sent.strip().lower()
text_tokens, text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
)
gpt_generator = self.gpt.get_generator(
fake_inputs=fake_inputs,
top_k=top_k,
top_p=top_p,
temperature=temperature,
do_sample=do_sample,
num_beams=1,
num_return_sequences=1,
length_penalty=float(length_penalty),
repetition_penalty=float(repetition_penalty),
output_attentions=False,
output_hidden_states=True,
**hf_generate_kwargs,
)
last_tokens = [] assert (
all_latents = [] text_tokens.shape[-1] < self.args.gpt_max_text_tokens
wav_gen_prev = None ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
wav_overlap = None
is_end = False
while not is_end: fake_inputs = self.gpt.compute_embeddings(
try: gpt_cond_latent.to(self.device),
x, latent = next(gpt_generator) text_tokens,
last_tokens += [x] )
all_latents += [latent] gpt_generator = self.gpt.get_generator(
except StopIteration: fake_inputs=fake_inputs,
is_end = True top_k=top_k,
top_p=top_p,
temperature=temperature,
do_sample=do_sample,
num_beams=1,
num_return_sequences=1,
length_penalty=float(length_penalty),
repetition_penalty=float(repetition_penalty),
output_attentions=False,
output_hidden_states=True,
**hf_generate_kwargs,
)
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): last_tokens = []
gpt_latents = torch.cat(all_latents, dim=0)[None, :] all_latents = []
if length_scale != 1.0: wav_gen_prev = None
gpt_latents = F.interpolate( wav_overlap = None
gpt_latents.transpose(1, 2), is_end = False
scale_factor=length_scale,
mode="linear" while not is_end:
).transpose(1, 2) try:
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) x, latent = next(gpt_generator)
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( last_tokens += [x]
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len all_latents += [latent]
) except StopIteration:
last_tokens = [] is_end = True
yield wav_chunk
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
).transpose(1, 2)
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
)
last_tokens = []
yield wav_chunk
def forward(self): def forward(self):
raise NotImplementedError( raise NotImplementedError(
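
inference_stream gets the same treatment: updated defaults, optional sentence splitting, and the token/latent accumulation loop now runs once per sentence, yielding decoded chunks as soon as stream_chunk_size tokens have accumulated. Continuing the sketch above (same loaded model and conditioning latents):

    import torch

    # `model`, `gpt_cond_latent` and `speaker_embedding` as prepared in the previous sketch.
    chunks = model.inference_stream(
        "Streaming also benefits from sentence splitting. Each sentence is generated and decoded in turn.",
        "en",
        gpt_cond_latent,
        speaker_embedding,
        enable_text_splitting=True,
    )
    wav_chunks = [chunk for chunk in chunks]  # consume the generator as chunks arrive
    wav = torch.cat(wav_chunks, dim=0).cpu()  # full waveform for saving or playback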

View File

@ -201,7 +201,6 @@ def stft(
def istft( def istft(
*, *,
y: np.ndarray = None, y: np.ndarray = None,
fft_size: int = None,
hop_length: int = None, hop_length: int = None,
win_length: int = None, win_length: int = None,
window: str = "hann", window: str = "hann",

View File

@ -5,10 +5,26 @@ import librosa
import numpy as np import numpy as np
import scipy.io.wavfile import scipy.io.wavfile
import scipy.signal import scipy.signal
import soundfile as sf
from TTS.tts.utils.helpers import StandardScaler from TTS.tts.utils.helpers import StandardScaler
from TTS.utils.audio.numpy_transforms import compute_f0 from TTS.utils.audio.numpy_transforms import (
amp_to_db,
build_mel_basis,
compute_f0,
db_to_amp,
deemphasis,
find_endpoint,
griffin_lim,
load_wav,
mel_to_spec,
millisec_to_length,
preemphasis,
rms_volume_norm,
spec_to_mel,
stft,
trim_silence,
volume_norm,
)
# pylint: disable=too-many-public-methods # pylint: disable=too-many-public-methods
@ -200,7 +216,9 @@ class AudioProcessor(object):
# setup stft parameters # setup stft parameters
if hop_length is None: if hop_length is None:
# compute stft parameters from given time values # compute stft parameters from given time values
self.hop_length, self.win_length = self._stft_parameters() self.win_length, self.hop_length = millisec_to_length(
frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate
)
else: else:
# use stft parameters from config file # use stft parameters from config file
self.hop_length = hop_length self.hop_length = hop_length
@ -215,8 +233,13 @@ class AudioProcessor(object):
for key, value in members.items(): for key, value in members.items():
print(" | > {}:{}".format(key, value)) print(" | > {}:{}".format(key, value))
# create spectrogram utils # create spectrogram utils
self.mel_basis = self._build_mel_basis() self.mel_basis = build_mel_basis(
self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis()) sample_rate=self.sample_rate,
fft_size=self.fft_size,
num_mels=self.num_mels,
mel_fmax=self.mel_fmax,
mel_fmin=self.mel_fmin,
)
# setup scaler # setup scaler
if stats_path and signal_norm: if stats_path and signal_norm:
mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
@ -232,35 +255,6 @@ class AudioProcessor(object):
return AudioProcessor(verbose=verbose, **config.audio) return AudioProcessor(verbose=verbose, **config.audio)
return AudioProcessor(verbose=verbose, **config) return AudioProcessor(verbose=verbose, **config)
### setting up the parameters ###
def _build_mel_basis(
self,
) -> np.ndarray:
"""Build melspectrogram basis.
Returns:
np.ndarray: melspectrogram basis.
"""
if self.mel_fmax is not None:
assert self.mel_fmax <= self.sample_rate // 2
return librosa.filters.mel(
sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
)
def _stft_parameters(
self,
) -> Tuple[int, int]:
"""Compute the real STFT parameters from the time values.
Returns:
Tuple[int, int]: hop length and window length for STFT.
"""
factor = self.frame_length_ms / self.frame_shift_ms
assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
win_length = int(hop_length * factor)
return hop_length, win_length
### normalization ### ### normalization ###
def normalize(self, S: np.ndarray) -> np.ndarray: def normalize(self, S: np.ndarray) -> np.ndarray:
"""Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]` """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`
@ -386,31 +380,6 @@ class AudioProcessor(object):
self.linear_scaler = StandardScaler() self.linear_scaler = StandardScaler()
self.linear_scaler.set_stats(linear_mean, linear_std) self.linear_scaler.set_stats(linear_mean, linear_std)
### DB and AMP conversion ###
# pylint: disable=no-self-use
def _amp_to_db(self, x: np.ndarray) -> np.ndarray:
"""Convert amplitude values to decibels.
Args:
x (np.ndarray): Amplitude spectrogram.
Returns:
np.ndarray: Decibels spectrogram.
"""
return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
# pylint: disable=no-self-use
def _db_to_amp(self, x: np.ndarray) -> np.ndarray:
"""Convert decibels spectrogram to amplitude spectrogram.
Args:
x (np.ndarray): Decibels spectrogram.
Returns:
np.ndarray: Amplitude spectrogram.
"""
return _exp(x / self.spec_gain, self.base)
### Preemphasis ### ### Preemphasis ###
def apply_preemphasis(self, x: np.ndarray) -> np.ndarray: def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:
"""Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
@ -424,32 +393,13 @@ class AudioProcessor(object):
Returns: Returns:
np.ndarray: Decorrelated audio signal. np.ndarray: Decorrelated audio signal.
""" """
if self.preemphasis == 0: return preemphasis(x=x, coef=self.preemphasis)
raise RuntimeError(" [!] Preemphasis is set 0.0.")
return scipy.signal.lfilter([1, -self.preemphasis], [1], x)
def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray: def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:
"""Reverse pre-emphasis.""" """Reverse pre-emphasis."""
if self.preemphasis == 0: return deemphasis(x=x, coef=self.preemphasis)
raise RuntimeError(" [!] Preemphasis is set 0.0.")
return scipy.signal.lfilter([1], [1, -self.preemphasis], x)
### SPECTROGRAMs ### ### SPECTROGRAMs ###
def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray:
"""Project a full scale spectrogram to a melspectrogram.
Args:
spectrogram (np.ndarray): Full scale spectrogram.
Returns:
np.ndarray: Melspectrogram
"""
return np.dot(self.mel_basis, spectrogram)
def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray:
"""Convert a melspectrogram to full scale spectrogram."""
return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec))
def spectrogram(self, y: np.ndarray) -> np.ndarray: def spectrogram(self, y: np.ndarray) -> np.ndarray:
"""Compute a spectrogram from a waveform. """Compute a spectrogram from a waveform.
@ -460,11 +410,16 @@ class AudioProcessor(object):
np.ndarray: Spectrogram. np.ndarray: Spectrogram.
""" """
if self.preemphasis != 0: if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y)) y = self.apply_preemphasis(y)
else: D = stft(
D = self._stft(y) y=y,
fft_size=self.fft_size,
hop_length=self.hop_length,
win_length=self.win_length,
pad_mode=self.stft_pad_mode,
)
if self.do_amp_to_db_linear: if self.do_amp_to_db_linear:
S = self._amp_to_db(np.abs(D)) S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base)
else: else:
S = np.abs(D) S = np.abs(D)
return self.normalize(S).astype(np.float32) return self.normalize(S).astype(np.float32)
@ -472,32 +427,35 @@ class AudioProcessor(object):
def melspectrogram(self, y: np.ndarray) -> np.ndarray: def melspectrogram(self, y: np.ndarray) -> np.ndarray:
"""Compute a melspectrogram from a waveform.""" """Compute a melspectrogram from a waveform."""
if self.preemphasis != 0: if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y)) y = self.apply_preemphasis(y)
else: D = stft(
D = self._stft(y) y=y,
fft_size=self.fft_size,
hop_length=self.hop_length,
win_length=self.win_length,
pad_mode=self.stft_pad_mode,
)
S = spec_to_mel(spec=np.abs(D), mel_basis=self.mel_basis)
if self.do_amp_to_db_mel: if self.do_amp_to_db_mel:
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
else:
S = self._linear_to_mel(np.abs(D))
return self.normalize(S).astype(np.float32) return self.normalize(S).astype(np.float32)
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
"""Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" """Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
S = self.denormalize(spectrogram) S = self.denormalize(spectrogram)
S = self._db_to_amp(S) S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
# Reconstruct phase # Reconstruct phase
if self.preemphasis != 0: W = self._griffin_lim(S**self.power)
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
return self._griffin_lim(S**self.power)
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" """Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
D = self.denormalize(mel_spectrogram) D = self.denormalize(mel_spectrogram)
S = self._db_to_amp(D) S = db_to_amp(x=D, gain=self.spec_gain, base=self.base)
S = self._mel_to_linear(S) # Convert back to linear S = mel_to_spec(mel=S, mel_basis=self.mel_basis) # Convert back to linear
if self.preemphasis != 0: W = self._griffin_lim(S**self.power)
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
return self._griffin_lim(S**self.power)
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
"""Convert a full scale linear spectrogram output of a network to a melspectrogram. """Convert a full scale linear spectrogram output of a network to a melspectrogram.
@ -509,60 +467,22 @@ class AudioProcessor(object):
np.ndarray: Normalized melspectrogram. np.ndarray: Normalized melspectrogram.
""" """
S = self.denormalize(linear_spec) S = self.denormalize(linear_spec)
S = self._db_to_amp(S) S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
S = self._linear_to_mel(np.abs(S)) S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis)
S = self._amp_to_db(S) S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
mel = self.normalize(S) mel = self.normalize(S)
return mel return mel
### STFT and ISTFT ### def _griffin_lim(self, S):
def _stft(self, y: np.ndarray) -> np.ndarray: return griffin_lim(
"""Librosa STFT wrapper. spec=S,
num_iter=self.griffin_lim_iters,
Args:
y (np.ndarray): Audio signal.
Returns:
np.ndarray: Complex number array.
"""
return librosa.stft(
y=y,
n_fft=self.fft_size,
hop_length=self.hop_length, hop_length=self.hop_length,
win_length=self.win_length, win_length=self.win_length,
fft_size=self.fft_size,
pad_mode=self.stft_pad_mode, pad_mode=self.stft_pad_mode,
window="hann",
center=True,
) )
def _istft(self, y: np.ndarray) -> np.ndarray:
"""Librosa iSTFT wrapper."""
return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
def _griffin_lim(self, S):
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
try:
S_complex = np.abs(S).astype(np.complex)
except AttributeError: # np.complex is deprecated since numpy 1.20.0
S_complex = np.abs(S).astype(complex)
y = self._istft(S_complex * angles)
if not np.isfinite(y).all():
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
return np.array([0.0])
for _ in range(self.griffin_lim_iters):
angles = np.exp(1j * np.angle(self._stft(y)))
y = self._istft(S_complex * angles)
return y
def compute_stft_paddings(self, x, pad_sides=1):
"""Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
(first and final frames)"""
assert pad_sides in (1, 2)
pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
if pad_sides == 1:
return 0, pad
return pad // 2, pad // 2 + pad % 2
def compute_f0(self, x: np.ndarray) -> np.ndarray: def compute_f0(self, x: np.ndarray) -> np.ndarray:
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
@ -581,8 +501,6 @@ class AudioProcessor(object):
>>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
>>> pitch = ap.compute_f0(wav) >>> pitch = ap.compute_f0(wav)
""" """
assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
assert self.pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`."
# align F0 length to the spectrogram length # align F0 length to the spectrogram length
if len(x) % self.hop_length == 0: if len(x) % self.hop_length == 0:
x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode) x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode)
@ -612,21 +530,24 @@ class AudioProcessor(object):
Returns: Returns:
int: Last point without silence. int: Last point without silence.
""" """
window_length = int(self.sample_rate * min_silence_sec) return find_endpoint(
hop_length = int(window_length / 4) wav=wav,
threshold = self._db_to_amp(-self.trim_db) trim_db=self.trim_db,
for x in range(hop_length, len(wav) - window_length, hop_length): sample_rate=self.sample_rate,
if np.max(wav[x : x + window_length]) < threshold: min_silence_sec=min_silence_sec,
return x + hop_length gain=self.spec_gain,
return len(wav) base=self.base,
)
def trim_silence(self, wav): def trim_silence(self, wav):
"""Trim silent parts with a threshold and 0.01 sec margin""" """Trim silent parts with a threshold and 0.01 sec margin"""
margin = int(self.sample_rate * 0.01) return trim_silence(
wav = wav[margin:-margin] wav=wav,
return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[ sample_rate=self.sample_rate,
0 trim_db=self.trim_db,
] win_length=self.win_length,
hop_length=self.hop_length,
)
@staticmethod @staticmethod
def sound_norm(x: np.ndarray) -> np.ndarray: def sound_norm(x: np.ndarray) -> np.ndarray:
@ -638,13 +559,7 @@ class AudioProcessor(object):
Returns: Returns:
np.ndarray: Volume normalized waveform. np.ndarray: Volume normalized waveform.
""" """
return x / abs(x).max() * 0.95 return volume_norm(x=x)
@staticmethod
def _rms_norm(wav, db_level=-27):
r = 10 ** (db_level / 20)
a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
return wav * a
def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray: def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
"""Normalize the volume based on RMS of the signal. """Normalize the volume based on RMS of the signal.
@ -657,9 +572,7 @@ class AudioProcessor(object):
""" """
if db_level is None: if db_level is None:
db_level = self.db_level db_level = self.db_level
assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0" return rms_volume_norm(x=x, db_level=db_level)
wav = self._rms_norm(x, db_level)
return wav
### save and load ### ### save and load ###
def load_wav(self, filename: str, sr: int = None) -> np.ndarray: def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
@ -674,15 +587,10 @@ class AudioProcessor(object):
Returns: Returns:
np.ndarray: Loaded waveform. np.ndarray: Loaded waveform.
""" """
if self.resample: if sr is not None:
# loading with resampling. It is significantly slower. x = load_wav(filename=filename, sample_rate=sr, resample=True)
x, sr = librosa.load(filename, sr=self.sample_rate)
elif sr is None:
# SF is faster than librosa for loading files
x, sr = sf.read(filename)
assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
else: else:
x, sr = librosa.load(filename, sr=sr) x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample)
if self.do_trim_silence: if self.do_trim_silence:
try: try:
x = self.trim_silence(x) x = self.trim_silence(x)
@ -723,55 +631,3 @@ class AudioProcessor(object):
filename (str): Path to the wav file. filename (str): Path to the wav file.
""" """
return librosa.get_duration(filename=filename) return librosa.get_duration(filename=filename)
@staticmethod
def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
mu = 2**qc - 1
# wav_abs = np.minimum(np.abs(wav), 1.0)
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
# Quantize signal to the specified number of levels.
signal = (signal + 1) / 2 * mu + 0.5
return np.floor(
signal,
)
@staticmethod
def mulaw_decode(wav, qc):
"""Recovers waveform from quantized values."""
mu = 2**qc - 1
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return x
@staticmethod
def encode_16bits(x):
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
@staticmethod
def quantize(x: np.ndarray, bits: int) -> np.ndarray:
"""Quantize a waveform to a given number of bits.
Args:
x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
bits (int): Number of quantization bits.
Returns:
np.ndarray: Quantized waveform.
"""
return (x + 1.0) * (2**bits - 1) / 2
@staticmethod
def dequantize(x, bits):
"""Dequantize a waveform from the given number of bits."""
return 2 * x / (2**bits - 1) - 1
def _log(x, base):
if base == 10:
return np.log10(x)
return np.log(x)
def _exp(x, base):
if base == 10:
return np.power(10, x)
return np.exp(x)
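
The AudioProcessor rewrite above delegates its private helpers (_stft, _amp_to_db, _linear_to_mel, the mu-law and quantize statics, and so on) to standalone functions in TTS.utils.audio.numpy_transforms. A minimal sketch of the equivalent functional mel pipeline, with typical 22.05 kHz parameters as example values and a placeholder input file:

    import numpy as np
    from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, load_wav, spec_to_mel, stft

    y = load_wav(filename="reference.wav", sample_rate=22050, resample=True)  # placeholder file
    mel_basis = build_mel_basis(sample_rate=22050, fft_size=1024, num_mels=80, mel_fmin=0, mel_fmax=8000)
    D = stft(y=y, fft_size=1024, hop_length=256, win_length=1024, pad_mode="reflect")
    S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis)
    mel_db = amp_to_db(x=S, gain=20, base=10)  # typical spec_gain/base values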

View File

@ -94,6 +94,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
use_noise_augment: bool = False use_noise_augment: bool = False
use_cache: bool = True use_cache: bool = True
steps_to_start_discriminator: int = 200000 steps_to_start_discriminator: int = 200000
target_loss: str = "loss_1"
# LOSS PARAMETERS - overrides # LOSS PARAMETERS - overrides
use_stft_loss: bool = True use_stft_loss: bool = True

View File

@ -7,6 +7,7 @@ from coqpit import Coqpit
from tqdm import tqdm from tqdm import tqdm
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor): def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
@ -29,7 +30,11 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
mel = ap.melspectrogram(y) mel = ap.melspectrogram(y)
np.save(mel_path, mel) np.save(mel_path, mel)
if isinstance(config.mode, int): if isinstance(config.mode, int):
quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode) quant = (
mulaw_encode(wav=y, mulaw_qc=config.mode)
if config.model_args.mulaw
else quantize(x=y, quantize_bits=config.mode)
)
np.save(quant_path, quant) np.save(quant_path, quant)
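
The WaveRNN preprocessing above now chooses between functional mu-law companding and plain linear quantization based on config.mode and config.model_args.mulaw. The same selection in isolation, with an arbitrary 10-bit setting and a synthetic waveform:

    import numpy as np
    from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize

    mode, use_mulaw = 10, True  # stand-ins for config.mode / config.model_args.mulaw
    y = np.sin(np.linspace(0, 2 * np.pi * 100, 22050)).astype(np.float32)  # synthetic waveform in [-1, 1]
    quant = (
        mulaw_encode(wav=y, mulaw_qc=mode)  # mu-law codes in [0, 2**mode - 1]
        if use_mulaw
        else quantize(x=y, quantize_bits=mode)  # linear mapping to [0, 2**mode - 1]
    )
    print(quant.min(), quant.max())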

View File

@ -2,6 +2,8 @@ import numpy as np
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
class WaveRNNDataset(Dataset): class WaveRNNDataset(Dataset):
""" """
@ -66,7 +68,9 @@ class WaveRNNDataset(Dataset):
x_input = audio x_input = audio
elif isinstance(self.mode, int): elif isinstance(self.mode, int):
x_input = ( x_input = (
self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode) mulaw_encode(wav=audio, mulaw_qc=self.mode)
if self.mulaw
else quantize(x=audio, quantize_bits=self.mode)
) )
else: else:
raise RuntimeError("Unknown dataset mode - ", self.mode) raise RuntimeError("Unknown dataset mode - ", self.mode)

View File

@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.tts.utils.visual import plot_spectrogram from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_decode
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss from TTS.vocoder.layers.losses import WaveRNNLoss
@ -399,7 +400,7 @@ class Wavernn(BaseVocoder):
output = output[0] output = output[0]
if self.args.mulaw and isinstance(self.args.mode, int): if self.args.mulaw and isinstance(self.args.mode, int):
output = AudioProcessor.mulaw_decode(output, self.args.mode) output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)
# Fade-out at the end to avoid signal cutting out suddenly # Fade-out at the end to avoid signal cutting out suddenly
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length) fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)

View File

@ -13,23 +13,28 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n",
"import sys\n",
"import torch\n",
"import importlib\n", "import importlib\n",
"import numpy as np\n", "import os\n",
"from tqdm import tqdm\n",
"from torch.utils.data import DataLoader\n",
"import soundfile as sf\n",
"import pickle\n", "import pickle\n",
"\n",
"import numpy as np\n",
"import soundfile as sf\n",
"import torch\n",
"from matplotlib import pylab as plt\n",
"from torch.utils.data import DataLoader\n",
"from tqdm import tqdm\n",
"\n",
"from TTS.config import load_config\n",
"from TTS.tts.configs.shared_configs import BaseDatasetConfig\n",
"from TTS.tts.datasets import load_tts_samples\n",
"from TTS.tts.datasets.dataset import TTSDataset\n", "from TTS.tts.datasets.dataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n", "from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.config import load_config\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.tts.utils.helpers import sequence_mask\n",
"from TTS.tts.models import setup_model\n", "from TTS.tts.models import setup_model\n",
"from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n", "from TTS.tts.utils.helpers import sequence_mask\n",
"from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.audio.numpy_transforms import quantize\n",
"\n", "\n",
"%matplotlib inline\n", "%matplotlib inline\n",
"\n", "\n",
@ -49,11 +54,9 @@
" file_name = wav_file.split('.')[0]\n", " file_name = wav_file.split('.')[0]\n",
" os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n", " os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
" os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n", " os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
" os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n",
" wavq_path = os.path.join(out_path, \"quant\", file_name)\n", " wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
" mel_path = os.path.join(out_path, \"mel\", file_name)\n", " mel_path = os.path.join(out_path, \"mel\", file_name)\n",
" wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n", " return file_name, wavq_path, mel_path"
" return file_name, wavq_path, mel_path, wav_path"
] ]
}, },
{ {
@ -65,14 +68,14 @@
"# Paths and configurations\n", "# Paths and configurations\n",
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n", "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n", "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
"PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n",
"DATASET = \"ljspeech\"\n", "DATASET = \"ljspeech\"\n",
"METADATA_FILE = \"metadata.csv\"\n", "METADATA_FILE = \"metadata.csv\"\n",
"CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n", "CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
"MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n", "MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
"BATCH_SIZE = 32\n", "BATCH_SIZE = 32\n",
"\n", "\n",
"QUANTIZED_WAV = False\n", "QUANTIZE_BITS = 0 # if non-zero, quantize wav files with the given number of bits\n",
"QUANTIZE_BIT = None\n",
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n", "DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
"\n", "\n",
"# Check CUDA availability\n", "# Check CUDA availability\n",
@ -80,10 +83,10 @@
"print(\" > CUDA enabled: \", use_cuda)\n", "print(\" > CUDA enabled: \", use_cuda)\n",
"\n", "\n",
"# Load the configuration\n", "# Load the configuration\n",
"dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n",
"C = load_config(CONFIG_PATH)\n", "C = load_config(CONFIG_PATH)\n",
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n", "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n", "ap = AudioProcessor(**C.audio)"
"print(C['r'])"
] ]
}, },
{ {
@ -92,12 +95,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# If the vocabulary was passed, replace the default\n", "# Initialize the tokenizer\n",
"if 'characters' in C and C['characters']:\n", "tokenizer, C = TTSTokenizer.init_from_config(C)\n",
" symbols, phonemes = make_symbols(**C.characters)\n",
"\n", "\n",
"# Load the model\n", "# Load the model\n",
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
"# TODO: multiple speakers\n", "# TODO: multiple speakers\n",
"model = setup_model(C)\n", "model = setup_model(C)\n",
"model.load_checkpoint(C, MODEL_FILE, eval=True)" "model.load_checkpoint(C, MODEL_FILE, eval=True)"
@ -109,42 +110,21 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Load the preprocessor based on the dataset\n", "# Load data instances\n",
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n", "meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n", "meta_data = meta_data_train + meta_data_eval\n",
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n", "\n",
"dataset = TTSDataset(\n", "dataset = TTSDataset(\n",
" C,\n", " outputs_per_step=C[\"r\"],\n",
" C.text_cleaner,\n", " compute_linear_spec=False,\n",
" False,\n", " ap=ap,\n",
" ap,\n", " samples=meta_data,\n",
" meta_data,\n", " tokenizer=tokenizer,\n",
" characters=C.get('characters', None),\n", " phoneme_cache_path=PHONEME_CACHE_PATH,\n",
" use_phonemes=C.use_phonemes,\n",
" phoneme_cache_path=C.phoneme_cache_path,\n",
" enable_eos_bos=C.enable_eos_bos_chars,\n",
")\n", ")\n",
"loader = DataLoader(\n", "loader = DataLoader(\n",
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n", " dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
")\n" ")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize lists for storing results\n",
"file_idxs = []\n",
"metadata = []\n",
"losses = []\n",
"postnet_losses = []\n",
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
"\n",
"# Create log file\n",
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
"log_file = open(log_file_path, \"w\")"
] ]
}, },
{ {
@ -160,26 +140,33 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Initialize lists for storing results\n",
"file_idxs = []\n",
"metadata = []\n",
"losses = []\n",
"postnet_losses = []\n",
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
"\n",
"# Start processing with a progress bar\n", "# Start processing with a progress bar\n",
"with torch.no_grad():\n", "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
"with torch.no_grad() and open(log_file_path, \"w\") as log_file:\n",
" for data in tqdm(loader, desc=\"Processing\"):\n", " for data in tqdm(loader, desc=\"Processing\"):\n",
" try:\n", " try:\n",
" # setup input data\n",
" text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
"\n",
" # dispatch data to GPU\n", " # dispatch data to GPU\n",
" if use_cuda:\n", " if use_cuda:\n",
" text_input = text_input.cuda()\n", " data[\"token_id\"] = data[\"token_id\"].cuda()\n",
" text_lengths = text_lengths.cuda()\n", " data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n",
" mel_input = mel_input.cuda()\n", " data[\"mel\"] = data[\"mel\"].cuda()\n",
" mel_lengths = mel_lengths.cuda()\n", " data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n",
"\n", "\n",
" mask = sequence_mask(text_lengths)\n", " mask = sequence_mask(data[\"token_id_lengths\"])\n",
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n", " outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n",
" mel_outputs = outputs[\"decoder_outputs\"]\n",
" postnet_outputs = outputs[\"model_outputs\"]\n",
"\n", "\n",
" # compute loss\n", " # compute loss\n",
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n", " loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n", " loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
" losses.append(loss.item())\n", " losses.append(loss.item())\n",
" postnet_losses.append(loss_postnet.item())\n", " postnet_losses.append(loss_postnet.item())\n",
"\n", "\n",
@ -193,28 +180,27 @@
" postnet_outputs = torch.stack(mel_specs)\n", " postnet_outputs = torch.stack(mel_specs)\n",
" elif C.model == \"Tacotron2\":\n", " elif C.model == \"Tacotron2\":\n",
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n", " postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
" alignments = alignments.detach().cpu().numpy()\n", " alignments = outputs[\"alignments\"].detach().cpu().numpy()\n",
"\n", "\n",
" if not DRY_RUN:\n", " if not DRY_RUN:\n",
" for idx in range(text_input.shape[0]):\n", " for idx in range(data[\"token_id\"].shape[0]):\n",
" wav_file_path = item_idx[idx]\n", " wav_file_path = data[\"item_idxs\"][idx]\n",
" wav = ap.load_wav(wav_file_path)\n", " wav = ap.load_wav(wav_file_path)\n",
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n", " file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n",
" file_idxs.append(file_name)\n", " file_idxs.append(file_name)\n",
"\n", "\n",
" # quantize and save wav\n", " # quantize and save wav\n",
" if QUANTIZED_WAV:\n", " if QUANTIZE_BITS > 0:\n",
" wavq = ap.quantize(wav)\n", " wavq = quantize(wav, QUANTIZE_BITS)\n",
" np.save(wavq_path, wavq)\n", " np.save(wavq_path, wavq)\n",
"\n", "\n",
" # save TTS mel\n", " # save TTS mel\n",
" mel = postnet_outputs[idx]\n", " mel = postnet_outputs[idx]\n",
" mel_length = mel_lengths[idx]\n", " mel_length = data[\"mel_lengths\"][idx]\n",
" mel = mel[:mel_length, :].T\n", " mel = mel[:mel_length, :].T\n",
" np.save(mel_path, mel)\n", " np.save(mel_path, mel)\n",
"\n", "\n",
" metadata.append([wav_file_path, mel_path])\n", " metadata.append([wav_file_path, mel_path])\n",
"\n",
" except Exception as e:\n", " except Exception as e:\n",
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n", " log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
"\n", "\n",
@ -224,35 +210,20 @@
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n", " log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n", " log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
"\n", "\n",
"# Close the log file\n",
"log_file.close()\n",
"\n",
"# For wavernn\n", "# For wavernn\n",
"if not DRY_RUN:\n", "if not DRY_RUN:\n",
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n", " pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
"\n", "\n",
"# For pwgan\n", "# For pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n", " for wav_file_path, mel_path in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n", " f.write(f\"{wav_file_path[0]}|{mel_path[1]+'.npy'}\\n\")\n",
"\n", "\n",
"# Print mean losses\n", "# Print mean losses\n",
"print(f\"Mean Loss: {mean_loss}\")\n", "print(f\"Mean Loss: {mean_loss}\")\n",
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")" "print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# for pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
@ -267,7 +238,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"idx = 1\n", "idx = 1\n",
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape" "ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape"
] ]
}, },
{ {
@ -276,10 +247,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import soundfile as sf\n", "wav, sr = sf.read(data[\"item_idxs\"][idx])\n",
"wav, sr = sf.read(item_idx[idx])\n", "mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n",
"mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n", "mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n",
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
"mel_truth = ap.melspectrogram(wav)\n", "mel_truth = ap.melspectrogram(wav)\n",
"print(mel_truth.shape)" "print(mel_truth.shape)"
] ]
@ -291,7 +261,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# plot posnet output\n", "# plot posnet output\n",
"print(mel_postnet[:mel_lengths[idx], :].shape)\n", "print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n",
"plot_spectrogram(mel_postnet, ap)" "plot_spectrogram(mel_postnet, ap)"
] ]
}, },
@ -324,10 +294,9 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# postnet, decoder diff\n", "# postnet, decoder diff\n",
"from matplotlib import pylab as plt\n",
"mel_diff = mel_decoder - mel_postnet\n", "mel_diff = mel_decoder - mel_postnet\n",
"plt.figure(figsize=(16, 10))\n", "plt.figure(figsize=(16, 10))\n",
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n", "plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T,aspect=\"auto\", origin=\"lower\")\n",
"plt.colorbar()\n", "plt.colorbar()\n",
"plt.tight_layout()" "plt.tight_layout()"
] ]
@ -339,10 +308,9 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# PLOT GT SPECTROGRAM diff\n", "# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
"mel_diff2 = mel_truth.T - mel_decoder\n", "mel_diff2 = mel_truth.T - mel_decoder\n",
"plt.figure(figsize=(16, 10))\n", "plt.figure(figsize=(16, 10))\n",
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
"plt.colorbar()\n", "plt.colorbar()\n",
"plt.tight_layout()" "plt.tight_layout()"
] ]
@ -354,21 +322,13 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# PLOT GT SPECTROGRAM diff\n", "# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
"mel = postnet_outputs[idx]\n", "mel = postnet_outputs[idx]\n",
"mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n", "mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
"plt.figure(figsize=(16, 10))\n", "plt.figure(figsize=(16, 10))\n",
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
"plt.colorbar()\n", "plt.colorbar()\n",
"plt.tight_layout()" "plt.tight_layout()"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {

View File

@ -1,33 +1,33 @@
# core deps # core deps
numpy==1.22.0;python_version<="3.10" numpy==1.22.0;python_version<="3.10"
numpy==1.24.3;python_version>"3.10" numpy>=1.24.3;python_version>"3.10"
cython==0.29.30 cython>=0.29.30
scipy>=1.11.2 scipy>=1.11.2
torch>=2.1 torch>=2.1
torchaudio torchaudio
soundfile==0.12.* soundfile>=0.12.0
librosa==0.10.* librosa>=0.10.0
scikit-learn==1.3.0 scikit-learn>=1.3.0
numba==0.55.1;python_version<"3.9" numba==0.55.1;python_version<"3.9"
numba==0.57.0;python_version>="3.9" numba>=0.57.0;python_version>="3.9"
inflect==5.6.* inflect>=5.6.0
tqdm==4.64.* tqdm>=4.64.1
anyascii==0.3.* anyascii>=0.3.0
pyyaml==6.* pyyaml>=6.0
fsspec==2023.6.0 # <= 2023.9.1 makes aux tests fail fsspec>=2023.6.0 # <= 2023.9.1 makes aux tests fail
aiohttp==3.8.* aiohttp>=3.8.1
packaging==23.1 packaging>=23.1
# deps for examples # deps for examples
flask==2.* flask>=2.0.1
# deps for inference # deps for inference
pysbd==0.3.4 pysbd>=0.3.4
# deps for notebooks # deps for notebooks
umap-learn==0.5.* umap-learn>=0.5.1
pandas>=1.4,<2.0 pandas>=1.4,<2.0
# deps for training # deps for training
matplotlib==3.7.* matplotlib>=3.7.0
# coqui stack # coqui stack
trainer trainer>=0.0.32
# config management # config management
coqpit>=0.0.16 coqpit>=0.0.16
# chinese g2p deps # chinese g2p deps
@ -46,11 +46,11 @@ bangla
bnnumerizer bnnumerizer
bnunicodenormalizer bnunicodenormalizer
#deps for tortoise #deps for tortoise
k_diffusion einops>=0.6.0
einops==0.6.* transformers>=4.33.0
transformers==4.33.*
#deps for bark #deps for bark
encodec==0.1.* encodec>=0.1.1
# deps for XTTS # deps for XTTS
unidecode==1.3.* unidecode>=1.3.2
num2words num2words
spacy[ja]>=3

View File

@ -5,6 +5,7 @@ import torch
from tests import get_tests_input_path, get_tests_output_path, get_tests_path from tests import get_tests_input_path, get_tests_output_path, get_tests_path
from TTS.config import BaseAudioConfig from TTS.config import BaseAudioConfig
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import stft
from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT
TESTS_PATH = get_tests_path() TESTS_PATH = get_tests_path()
@ -21,7 +22,7 @@ def test_torch_stft():
torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length) torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
# librosa stft # librosa stft
wav = ap.load_wav(WAV_FILE) wav = ap.load_wav(WAV_FILE)
M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access M_librosa = abs(stft(y=wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length))
# torch stft # torch stft
wav = torch.from_numpy(wav[None, :]).float() wav = torch.from_numpy(wav[None, :]).float()
M_torch = torch_stft(wav) M_torch = torch_stft(wav)

View File

@ -186,7 +186,7 @@ def test_xtts_v2_streaming():
"en", "en",
gpt_cond_latent, gpt_cond_latent,
speaker_embedding, speaker_embedding,
speed=1.5 speed=1.5,
) )
wav_chuncks = [] wav_chuncks = []
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
@ -198,7 +198,7 @@ def test_xtts_v2_streaming():
"en", "en",
gpt_cond_latent, gpt_cond_latent,
speaker_embedding, speaker_embedding,
speed=0.66 speed=0.66,
) )
wav_chuncks = [] wav_chuncks = []
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):