mirror of https://github.com/coqui-ai/TTS.git
commit 2211ba267a
@@ -10,7 +10,7 @@ jobs:
   build-sdist:
     runs-on: ubuntu-20.04
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Verify tag matches version
         run: |
           set -ex
@@ -38,7 +38,7 @@ jobs:
       matrix:
         python-version: ["3.9", "3.10", "3.11"]
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - uses: actions/setup-python@v2
         with:
          python-version: ${{ matrix.python-version }}
@@ -3,14 +3,14 @@
     "multilingual": {
         "multi-dataset": {
             "xtts_v2": {
-                "description": "XTTS-v2 by Coqui with 16 languages.",
+                "description": "XTTS-v2.0.2 by Coqui with 16 languages.",
                 "hf_url": [
                     "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth",
                     "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json",
                     "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
                     "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5"
                 ],
-                "model_hash": "6a09d1ad43896f06041ed8195956c9698f13b6189dc80f1c74bdc2b8e8d15324",
+                "model_hash": "5ce0502bfe3bc88dc8d9312b12a7558c",
                 "default_vocoder": null,
                 "commit": "480a6cdf7",
                 "license": "CPML",
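Note: the new 32-character model_hash is an md5-style digest (the hf_url list above also ships a hash.md5 file). A minimal verification sketch, with a placeholder local path:

    import hashlib

    with open("model.pth", "rb") as f:  # placeholder path to the downloaded checkpoint
        digest = hashlib.md5(f.read()).hexdigest()
    assert digest == "5ce0502bfe3bc88dc8d9312b12a7558c", "checkpoint does not match .models.json"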
@@ -1 +1 @@
-0.20.5
+0.20.6
@@ -15,6 +15,7 @@ from TTS.tts.models import setup_model
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import quantize
 from TTS.utils.generic_utils import count_parameters

 use_cuda = torch.cuda.is_available()
@@ -159,7 +160,7 @@ def inference(


 def extract_spectrograms(
-    data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"
+    data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt"
 ):
     model.eval()
     export_metadata = []
@@ -196,8 +197,8 @@ def extract_spectrograms(
         _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)

         # quantize and save wav
-        if quantized_wav:
-            wavq = ap.quantize(wav)
+        if quantize_bits > 0:
+            wavq = quantize(wav, quantize_bits)
             np.save(wavq_path, wavq)

         # save TTS mel
@@ -263,7 +264,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         model,
         ap,
         args.output_path,
-        quantized_wav=args.quantized,
+        quantize_bits=args.quantize_bits,
         save_audio=args.save_audio,
         debug=args.debug,
         metada_name="metada.txt",
@@ -277,7 +278,7 @@ if __name__ == "__main__":
    parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
    parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
-    parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
+    parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero")
    parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
    args = parser.parse_args()
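The boolean --quantized flag is replaced by an explicit bit depth. A minimal usage sketch of the numpy_transforms.quantize helper imported above, assuming a waveform already normalized to [-1, 1]:

    import numpy as np
    from TTS.utils.audio.numpy_transforms import quantize

    wav = np.sin(np.linspace(0, 2 * np.pi, 16000)).astype(np.float32)  # dummy 1 s signal
    wavq = quantize(x=wav, quantize_bits=10)  # maps [-1, 1] onto [0, 2**10 - 1]
    np.save("sample_quant.npy", wavq)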
@@ -13,12 +13,18 @@ import math
 import numpy as np
 import torch
 import torch as th
-from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
 from tqdm import tqdm

 from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper

-K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m}
+try:
+    from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
+
+    K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m}
+except ImportError:
+    K_DIFFUSION_SAMPLERS = None
+

 SAMPLERS = ["dpm++2m", "p", "ddim"]

@@ -531,6 +537,8 @@ class GaussianDiffusion:
             if self.conditioning_free is not True:
                 raise RuntimeError("cond_free must be true")
             with tqdm(total=self.num_timesteps) as pbar:
+                if K_DIFFUSION_SAMPLERS is None:
+                    raise ModuleNotFoundError("Install k_diffusion for using k_diffusion samplers")
                 return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs)
         else:
             raise RuntimeError("sampler not impl")
@@ -441,7 +441,9 @@ class GPT(nn.Module):
         audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)

         # Pad mel codes with stop_audio_token
-        audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3)  # -3 to get the real code lengths without consider start and stop tokens that was not added yet
+        audio_codes = self.set_mel_padding(
+            audio_codes, code_lengths - 3
+        )  # -3 to get the real code lengths without consider start and stop tokens that was not added yet

         # Build input and target tensors
         # Prepend start token to inputs and append stop token to targets
@@ -1,6 +1,6 @@
 import json
 import os
 import re
+import textwrap
 from functools import cached_property

 import pypinyin
@@ -8,10 +8,66 @@ import torch
 from hangul_romanize import Transliter
 from hangul_romanize.rule import academic
 from num2words import num2words
+from spacy.lang.ar import Arabic
+from spacy.lang.en import English
+from spacy.lang.es import Spanish
+from spacy.lang.ja import Japanese
+from spacy.lang.zh import Chinese
 from tokenizers import Tokenizer

 from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words


+def get_spacy_lang(lang):
+    if lang == "zh":
+        return Chinese()
+    elif lang == "ja":
+        return Japanese()
+    elif lang == "ar":
+        return Arabic()
+    elif lang == "es":
+        return Spanish()
+    else:
+        # For most languages, Enlish does the job
+        return English()
+
+
+def split_sentence(text, lang, text_split_length=250):
+    """Preprocess the input text"""
+    text_splits = []
+    if text_split_length is not None and len(text) >= text_split_length:
+        text_splits.append("")
+        nlp = get_spacy_lang(lang)
+        nlp.add_pipe("sentencizer")
+        doc = nlp(text)
+        for sentence in doc.sents:
+            if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
+                # if the last sentence + the current sentence is less than the text_split_length
+                # then add the current sentence to the last sentence
+                text_splits[-1] += " " + str(sentence)
+                text_splits[-1] = text_splits[-1].lstrip()
+            elif len(str(sentence)) > text_split_length:
+                # if the current sentence is greater than the text_split_length
+                for line in textwrap.wrap(
+                    str(sentence),
+                    width=text_split_length,
+                    drop_whitespace=True,
+                    break_on_hyphens=False,
+                    tabsize=1,
+                ):
+                    text_splits.append(str(line))
+            else:
+                text_splits.append(str(sentence))
+
+        if len(text_splits) > 1:
+            if text_splits[0] == "":
+                del text_splits[0]
+    else:
+        text_splits = [text.lstrip()]
+
+    return text_splits
+
+
 _whitespace_re = re.compile(r"\s+")

 # List of (regular expression, replacement) pairs for abbreviations:
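A short sketch of how the new split_sentence helper behaves (it needs spaCy, per the imports above); the sample text is made up:

    from TTS.tts.layers.xtts.tokenizer import split_sentence

    text = "This is the first sentence. " * 20
    for chunk in split_sentence(text, lang="en", text_split_length=250):
        print(len(chunk), repr(chunk[:40]))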
@@ -115,7 +171,7 @@ _abbreviations = {
            # There are not many common abbreviations in Arabic as in English.
        ]
    ],
-    "zh-cn": [
+    "zh": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
@@ -280,7 +336,7 @@ _symbols_multilingual = {
            ("°", " درجة "),
        ]
    ],
-    "zh-cn": [
+    "zh": [
        # Chinese
        (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
        for x in [
@@ -464,7 +520,7 @@ def _expand_number(m, lang="en"):


 def expand_numbers_multilingual(text, lang="en"):
-    if lang == "zh" or lang == "zh-cn":
+    if lang == "zh":
         text = zh_num2words()(text)
     else:
         if lang in ["en", "ru"]:
@@ -525,7 +581,7 @@ def japanese_cleaners(text, katsu):
     return text


-def korean_cleaners(text):
+def korean_transliterate(text):
     r = Transliter(academic)
     return r.translit(text)
@@ -546,7 +602,7 @@ class VoiceBpeTokenizer:
        "it": 213,
        "pt": 203,
        "pl": 224,
-        "zh-cn": 82,
+        "zh": 82,
        "ar": 166,
        "cs": 186,
        "ru": 182,
@@ -564,6 +620,7 @@ class VoiceBpeTokenizer:
         return cutlet.Cutlet()

     def check_input_length(self, txt, lang):
+        lang = lang.split("-")[0]  # remove the region
         limit = self.char_limits.get(lang, 250)
         if len(txt) > limit:
             print(
@@ -571,21 +628,23 @@ class VoiceBpeTokenizer:
             )

     def preprocess_text(self, txt, lang):
-        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}:
+        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
             txt = multilingual_cleaners(txt, lang)
-            if lang in {"zh", "zh-cn"}:
+            if lang == "zh":
                 txt = chinese_transliterate(txt)
+            if lang == "ko":
+                txt = korean_transliterate(txt)
         elif lang == "ja":
             txt = japanese_cleaners(txt, self.katsu)
-        elif lang == "ko":
-            txt = korean_cleaners(txt)
         else:
             raise NotImplementedError(f"Language '{lang}' is not supported.")
         return txt

     def encode(self, txt, lang):
+        lang = lang.split("-")[0]  # remove the region
         self.check_input_length(txt, lang)
         txt = self.preprocess_text(txt, lang)
+        lang = "zh-cn" if lang == "zh" else lang
         txt = f"[{lang}]{txt}"
         txt = txt.replace(" ", "[SPACE]")
         return self.tokenizer.encode(txt).ids
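Usage sketch for the reworked tokenizer: region tags are stripped before cleaning, and bare "zh" is mapped back to the "[zh-cn]" prefix the BPE vocabulary was trained with. The vocab path is a placeholder:

    from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer

    tok = VoiceBpeTokenizer(vocab_file="vocab.json")  # placeholder path
    ids = tok.encode("你好，世界。", lang="zh-TW")  # handled as "zh", prefixed as "[zh-cn]"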
@@ -682,8 +741,8 @@ def test_expand_numbers_multilingual():
         ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
         ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
         # Chinese (Simplified)
-        ("在12.5秒内", "在十二点五秒内", "zh-cn"),
-        ("有50名士兵", "有五十名士兵", "zh-cn"),
+        ("在12.5秒内", "在十二点五秒内", "zh"),
+        ("有50名士兵", "有五十名士兵", "zh"),
         # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
         # ("那将是20€先生", '那将是二十欧元先生', 'zh'),
         # Turkish
@@ -764,7 +823,7 @@ def test_symbols_multilingual():
         ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
         ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
         ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
-        ("我的电量为 14%", "我的电量为 14 百分之", "zh-cn"),
+        ("我的电量为 14%", "我的电量为 14 百分之", "zh"),
         ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
         ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
         ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
@@ -318,9 +318,10 @@ class GPTTrainer(BaseTTS):
             batch["cond_idxs"] = None
         return self.train_step(batch, criterion)

-    def on_epoch_start(self, trainer):  # pylint: disable=W0613
-        # guarante that dvae will be in eval mode after .train() on evaluation end
-        self.dvae = self.dvae.eval()
+    def on_train_epoch_start(self, trainer):
+        trainer.model.eval()  # the whole model to eval
+        # put gpt model in training mode
+        trainer.model.xtts.gpt.train()

     def on_init_end(self, trainer):  # pylint: disable=W0613
         # ignore similarities.pth on clearml save/upload
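The effect of the renamed hook, in isolation (model here stands for the trainer's wrapped GPTTrainer; an assumption for illustration only):

    model.eval()            # freeze everything, DVAE included, in eval mode
    model.xtts.gpt.train()  # re-enable training behavior only for the GPT being fine-tuned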
@@ -10,7 +10,7 @@ from coqpit import Coqpit
 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
-from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
+from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec

@@ -396,7 +396,7 @@ class Xtts(BaseTTS):
        inference with config
        """
        assert (
-            language in self.config.languages
+            "zh-cn" if language == "zh" else language in self.config.languages
        ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}"
        # Use generally found best tuning knobs for generation.
        settings = {
@@ -420,9 +420,9 @@ class Xtts(BaseTTS):
         ref_audio_path,
         language,
         # GPT inference
-        temperature=0.65,
-        length_penalty=1,
-        repetition_penalty=2.0,
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
         top_k=50,
         top_p=0.85,
         do_sample=True,
@@ -502,71 +502,76 @@ class Xtts(BaseTTS):
         gpt_cond_latent,
         speaker_embedding,
         # GPT inference
-        temperature=0.65,
-        length_penalty=1,
-        repetition_penalty=2.0,
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
         top_k=50,
         top_p=0.85,
         do_sample=True,
         num_beams=1,
         speed=1.0,
+        enable_text_splitting=False,
         **hf_generate_kwargs,
     ):
         language = language.split("-")[0]  # remove the country code
         length_scale = 1.0 / max(speed, 0.05)
-        text = text.strip().lower()
-        text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
+        if enable_text_splitting:
+            text = split_sentence(text, language, self.tokenizer.char_limits[language])
+        else:
+            text = [text]

-        # print(" > Input text: ", text)
-        # print(" > Input text preprocessed: ",self.tokenizer.preprocess_text(text, language))
-        # print(" > Input tokens: ", text_tokens)
-        # print(" > Decoded text: ", self.tokenizer.decode(text_tokens[0].cpu().numpy()))
-        assert (
-            text_tokens.shape[-1] < self.args.gpt_max_text_tokens
-        ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
+        wavs = []
+        gpt_latents_list = []
+        for sent in text:
+            sent = sent.strip().lower()
+            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)

-        with torch.no_grad():
-            gpt_codes = self.gpt.generate(
-                cond_latents=gpt_cond_latent,
-                text_inputs=text_tokens,
-                input_tokens=None,
-                do_sample=do_sample,
-                top_p=top_p,
-                top_k=top_k,
-                temperature=temperature,
-                num_return_sequences=self.gpt_batch_size,
-                num_beams=num_beams,
-                length_penalty=length_penalty,
-                repetition_penalty=repetition_penalty,
-                output_attentions=False,
-                **hf_generate_kwargs,
-            )
-            expected_output_len = torch.tensor(
-                [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
-            )
+            assert (
+                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
+            ), " ❗ XTTS can only generate text with a maximum of 400 tokens."

-            text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
-            gpt_latents = self.gpt(
-                text_tokens,
-                text_len,
-                gpt_codes,
-                expected_output_len,
-                cond_latents=gpt_cond_latent,
-                return_attentions=False,
-                return_latent=True,
-            )
+            with torch.no_grad():
+                gpt_codes = self.gpt.generate(
+                    cond_latents=gpt_cond_latent,
+                    text_inputs=text_tokens,
+                    input_tokens=None,
+                    do_sample=do_sample,
+                    top_p=top_p,
+                    top_k=top_k,
+                    temperature=temperature,
+                    num_return_sequences=self.gpt_batch_size,
+                    num_beams=num_beams,
+                    length_penalty=length_penalty,
+                    repetition_penalty=repetition_penalty,
+                    output_attentions=False,
+                    **hf_generate_kwargs,
+                )
+                expected_output_len = torch.tensor(
+                    [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
+                )

-            if length_scale != 1.0:
-                gpt_latents = F.interpolate(
-                    gpt_latents.transpose(1, 2),
-                    scale_factor=length_scale,
-                    mode="linear"
-                ).transpose(1, 2)
+                text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
+                gpt_latents = self.gpt(
+                    text_tokens,
+                    text_len,
+                    gpt_codes,
+                    expected_output_len,
+                    cond_latents=gpt_cond_latent,
+                    return_attentions=False,
+                    return_latent=True,
+                )

-            wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
+                if length_scale != 1.0:
+                    gpt_latents = F.interpolate(
+                        gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
+                    ).transpose(1, 2)
+
+                gpt_latents_list.append(gpt_latents.cpu())
+                wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())

         return {
-            "wav": wav.cpu().numpy().squeeze(),
-            "gpt_latents": gpt_latents,
+            "wav": torch.cat(wavs, dim=0).numpy(),
+            "gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(),
             "speaker_embedding": speaker_embedding,
         }
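A usage sketch of the new per-sentence inference path (model loading omitted; the conditioning latents come from get_conditioning_latents, as elsewhere in this file):

    out = model.inference(
        "A very long paragraph that exceeds the per-language character limit...",
        language="en",
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        enable_text_splitting=True,  # split into sentences, synthesize each, concatenate
    )
    wav = out["wav"]  # single concatenated waveform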
@@ -606,66 +611,76 @@ class Xtts(BaseTTS):
         stream_chunk_size=20,
         overlap_wav_len=1024,
         # GPT inference
-        temperature=0.65,
-        length_penalty=1,
-        repetition_penalty=2.0,
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
         top_k=50,
         top_p=0.85,
         do_sample=True,
         speed=1.0,
+        enable_text_splitting=False,
         **hf_generate_kwargs,
     ):
         language = language.split("-")[0]  # remove the country code
         length_scale = 1.0 / max(speed, 0.05)
-        text = text.strip().lower()
-        text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
+        if enable_text_splitting:
+            text = split_sentence(text, language, self.tokenizer.char_limits[language])
+        else:
+            text = [text]

-        fake_inputs = self.gpt.compute_embeddings(
-            gpt_cond_latent.to(self.device),
-            text_tokens,
-        )
-        gpt_generator = self.gpt.get_generator(
-            fake_inputs=fake_inputs,
-            top_k=top_k,
-            top_p=top_p,
-            temperature=temperature,
-            do_sample=do_sample,
-            num_beams=1,
-            num_return_sequences=1,
-            length_penalty=float(length_penalty),
-            repetition_penalty=float(repetition_penalty),
-            output_attentions=False,
-            output_hidden_states=True,
-            **hf_generate_kwargs,
-        )
+        for sent in text:
+            sent = sent.strip().lower()
+            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)

-        last_tokens = []
-        all_latents = []
-        wav_gen_prev = None
-        wav_overlap = None
-        is_end = False
+            assert (
+                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
+            ), " ❗ XTTS can only generate text with a maximum of 400 tokens."

-        while not is_end:
-            try:
-                x, latent = next(gpt_generator)
-                last_tokens += [x]
-                all_latents += [latent]
-            except StopIteration:
-                is_end = True
+            fake_inputs = self.gpt.compute_embeddings(
+                gpt_cond_latent.to(self.device),
+                text_tokens,
+            )
+            gpt_generator = self.gpt.get_generator(
+                fake_inputs=fake_inputs,
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
+                do_sample=do_sample,
+                num_beams=1,
+                num_return_sequences=1,
+                length_penalty=float(length_penalty),
+                repetition_penalty=float(repetition_penalty),
+                output_attentions=False,
+                output_hidden_states=True,
+                **hf_generate_kwargs,
+            )

-            if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
-                gpt_latents = torch.cat(all_latents, dim=0)[None, :]
-                if length_scale != 1.0:
-                    gpt_latents = F.interpolate(
-                        gpt_latents.transpose(1, 2),
-                        scale_factor=length_scale,
-                        mode="linear"
-                    ).transpose(1, 2)
-                wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
-                wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
-                    wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
-                )
-                last_tokens = []
-                yield wav_chunk
+            last_tokens = []
+            all_latents = []
+            wav_gen_prev = None
+            wav_overlap = None
+            is_end = False
+
+            while not is_end:
+                try:
+                    x, latent = next(gpt_generator)
+                    last_tokens += [x]
+                    all_latents += [latent]
+                except StopIteration:
+                    is_end = True
+
+                if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
+                    gpt_latents = torch.cat(all_latents, dim=0)[None, :]
+                    if length_scale != 1.0:
+                        gpt_latents = F.interpolate(
+                            gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
+                        ).transpose(1, 2)
+                    wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                    wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
+                        wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
+                    )
+                    last_tokens = []
+                    yield wav_chunk

     def forward(self):
         raise NotImplementedError(
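And the streaming counterpart, consuming chunks as they are produced (same assumptions as the sketch above):

    import torch

    chunks = model.inference_stream(
        "Hello world, this is a streaming test.",
        language="en",
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        stream_chunk_size=20,
    )
    wav = torch.cat(list(chunks), dim=0)  # or play each chunk as it arrives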
@@ -201,7 +201,6 @@ def stft(
 def istft(
     *,
     y: np.ndarray = None,
-    fft_size: int = None,
     hop_length: int = None,
     win_length: int = None,
     window: str = "hann",
@@ -5,10 +5,26 @@ import librosa
 import numpy as np
 import scipy.io.wavfile
 import scipy.signal
 import soundfile as sf

 from TTS.tts.utils.helpers import StandardScaler
-from TTS.utils.audio.numpy_transforms import compute_f0
+from TTS.utils.audio.numpy_transforms import (
+    amp_to_db,
+    build_mel_basis,
+    compute_f0,
+    db_to_amp,
+    deemphasis,
+    find_endpoint,
+    griffin_lim,
+    load_wav,
+    mel_to_spec,
+    millisec_to_length,
+    preemphasis,
+    rms_volume_norm,
+    spec_to_mel,
+    stft,
+    trim_silence,
+    volume_norm,
+)

 # pylint: disable=too-many-public-methods
@@ -200,7 +216,9 @@ class AudioProcessor(object):
         # setup stft parameters
         if hop_length is None:
             # compute stft parameters from given time values
-            self.hop_length, self.win_length = self._stft_parameters()
+            self.win_length, self.hop_length = millisec_to_length(
+                frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate
+            )
         else:
             # use stft parameters from config file
             self.hop_length = hop_length
@@ -215,8 +233,13 @@ class AudioProcessor(object):
         for key, value in members.items():
             print(" | > {}:{}".format(key, value))
         # create spectrogram utils
-        self.mel_basis = self._build_mel_basis()
-        self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis())
+        self.mel_basis = build_mel_basis(
+            sample_rate=self.sample_rate,
+            fft_size=self.fft_size,
+            num_mels=self.num_mels,
+            mel_fmax=self.mel_fmax,
+            mel_fmin=self.mel_fmin,
+        )
         # setup scaler
         if stats_path and signal_norm:
             mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
@@ -232,35 +255,6 @@ class AudioProcessor(object):
             return AudioProcessor(verbose=verbose, **config.audio)
         return AudioProcessor(verbose=verbose, **config)

-    ### setting up the parameters ###
-    def _build_mel_basis(
-        self,
-    ) -> np.ndarray:
-        """Build melspectrogram basis.
-
-        Returns:
-            np.ndarray: melspectrogram basis.
-        """
-        if self.mel_fmax is not None:
-            assert self.mel_fmax <= self.sample_rate // 2
-        return librosa.filters.mel(
-            sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
-        )
-
-    def _stft_parameters(
-        self,
-    ) -> Tuple[int, int]:
-        """Compute the real STFT parameters from the time values.
-
-        Returns:
-            Tuple[int, int]: hop length and window length for STFT.
-        """
-        factor = self.frame_length_ms / self.frame_shift_ms
-        assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
-        hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
-        win_length = int(hop_length * factor)
-        return hop_length, win_length
-
     ### normalization ###
     def normalize(self, S: np.ndarray) -> np.ndarray:
         """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`
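The removed _stft_parameters shows exactly what millisec_to_length now computes, in plain numbers (sample values assumed):

    sample_rate = 22050
    frame_shift_ms, frame_length_ms = 12.5, 50.0
    hop_length = int(frame_shift_ms / 1000.0 * sample_rate)            # 275 samples
    win_length = int(hop_length * (frame_length_ms / frame_shift_ms))  # 1100 samples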
@@ -386,31 +380,6 @@ class AudioProcessor(object):
         self.linear_scaler = StandardScaler()
         self.linear_scaler.set_stats(linear_mean, linear_std)

-    ### DB and AMP conversion ###
-    # pylint: disable=no-self-use
-    def _amp_to_db(self, x: np.ndarray) -> np.ndarray:
-        """Convert amplitude values to decibels.
-
-        Args:
-            x (np.ndarray): Amplitude spectrogram.
-
-        Returns:
-            np.ndarray: Decibels spectrogram.
-        """
-        return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
-
-    # pylint: disable=no-self-use
-    def _db_to_amp(self, x: np.ndarray) -> np.ndarray:
-        """Convert decibels spectrogram to amplitude spectrogram.
-
-        Args:
-            x (np.ndarray): Decibels spectrogram.
-
-        Returns:
-            np.ndarray: Amplitude spectrogram.
-        """
-        return _exp(x / self.spec_gain, self.base)
-
     ### Preemphasis ###
     def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:
         """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
@@ -424,32 +393,13 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Decorrelated audio signal.
         """
-        if self.preemphasis == 0:
-            raise RuntimeError(" [!] Preemphasis is set 0.0.")
-        return scipy.signal.lfilter([1, -self.preemphasis], [1], x)
+        return preemphasis(x=x, coef=self.preemphasis)

     def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:
         """Reverse pre-emphasis."""
-        if self.preemphasis == 0:
-            raise RuntimeError(" [!] Preemphasis is set 0.0.")
-        return scipy.signal.lfilter([1], [1, -self.preemphasis], x)
+        return deemphasis(x=x, coef=self.preemphasis)

     ### SPECTROGRAMs ###
-    def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray:
-        """Project a full scale spectrogram to a melspectrogram.
-
-        Args:
-            spectrogram (np.ndarray): Full scale spectrogram.
-
-        Returns:
-            np.ndarray: Melspectrogram
-        """
-        return np.dot(self.mel_basis, spectrogram)
-
-    def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray:
-        """Convert a melspectrogram to full scale spectrogram."""
-        return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec))
-
     def spectrogram(self, y: np.ndarray) -> np.ndarray:
         """Compute a spectrogram from a waveform.

@@ -460,11 +410,16 @@ class AudioProcessor(object):
             np.ndarray: Spectrogram.
         """
         if self.preemphasis != 0:
-            D = self._stft(self.apply_preemphasis(y))
-        else:
-            D = self._stft(y)
+            y = self.apply_preemphasis(y)
+        D = stft(
+            y=y,
+            fft_size=self.fft_size,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            pad_mode=self.stft_pad_mode,
+        )
         if self.do_amp_to_db_linear:
-            S = self._amp_to_db(np.abs(D))
+            S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base)
         else:
             S = np.abs(D)
         return self.normalize(S).astype(np.float32)
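The same pipeline, called directly through the functional API (parameter values are assumptions):

    import numpy as np
    from TTS.utils.audio.numpy_transforms import amp_to_db, stft

    y = np.random.randn(22050).astype(np.float32)  # dummy audio
    D = stft(y=y, fft_size=1024, hop_length=256, win_length=1024)
    S = amp_to_db(x=np.abs(D), gain=20, base=10)   # linear spectrogram in dB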
@@ -472,32 +427,35 @@ class AudioProcessor(object):
     def melspectrogram(self, y: np.ndarray) -> np.ndarray:
         """Compute a melspectrogram from a waveform."""
         if self.preemphasis != 0:
-            D = self._stft(self.apply_preemphasis(y))
-        else:
-            D = self._stft(y)
+            y = self.apply_preemphasis(y)
+        D = stft(
+            y=y,
+            fft_size=self.fft_size,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            pad_mode=self.stft_pad_mode,
+        )
+        S = spec_to_mel(spec=np.abs(D), mel_basis=self.mel_basis)
         if self.do_amp_to_db_mel:
-            S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
-        else:
-            S = self._linear_to_mel(np.abs(D))
+            S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)

         return self.normalize(S).astype(np.float32)

     def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
         """Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
         S = self.denormalize(spectrogram)
-        S = self._db_to_amp(S)
+        S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
         # Reconstruct phase
-        if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
-        return self._griffin_lim(S**self.power)
+        W = self._griffin_lim(S**self.power)
+        return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W

     def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
         """Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
         D = self.denormalize(mel_spectrogram)
-        S = self._db_to_amp(D)
-        S = self._mel_to_linear(S)  # Convert back to linear
-        if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
-        return self._griffin_lim(S**self.power)
+        S = db_to_amp(x=D, gain=self.spec_gain, base=self.base)
+        S = mel_to_spec(mel=S, mel_basis=self.mel_basis)  # Convert back to linear
+        W = self._griffin_lim(S**self.power)
+        return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W

     def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
         """Convert a full scale linear spectrogram output of a network to a melspectrogram.
@@ -509,60 +467,22 @@ class AudioProcessor(object):
             np.ndarray: Normalized melspectrogram.
         """
         S = self.denormalize(linear_spec)
-        S = self._db_to_amp(S)
-        S = self._linear_to_mel(np.abs(S))
-        S = self._amp_to_db(S)
+        S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
+        S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis)
+        S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
         mel = self.normalize(S)
         return mel

-    ### STFT and ISTFT ###
-    def _stft(self, y: np.ndarray) -> np.ndarray:
-        """Librosa STFT wrapper.
-
-        Args:
-            y (np.ndarray): Audio signal.
-
-        Returns:
-            np.ndarray: Complex number array.
-        """
-        return librosa.stft(
-            y=y,
-            n_fft=self.fft_size,
+    def _griffin_lim(self, S):
+        return griffin_lim(
+            spec=S,
+            num_iter=self.griffin_lim_iters,
             hop_length=self.hop_length,
             win_length=self.win_length,
+            fft_size=self.fft_size,
             pad_mode=self.stft_pad_mode,
-            window="hann",
-            center=True,
         )

-    def _istft(self, y: np.ndarray) -> np.ndarray:
-        """Librosa iSTFT wrapper."""
-        return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
-
-    def _griffin_lim(self, S):
-        angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
-        try:
-            S_complex = np.abs(S).astype(np.complex)
-        except AttributeError:  # np.complex is deprecated since numpy 1.20.0
-            S_complex = np.abs(S).astype(complex)
-        y = self._istft(S_complex * angles)
-        if not np.isfinite(y).all():
-            print(" [!] Waveform is not finite everywhere. Skipping the GL.")
-            return np.array([0.0])
-        for _ in range(self.griffin_lim_iters):
-            angles = np.exp(1j * np.angle(self._stft(y)))
-            y = self._istft(S_complex * angles)
-        return y
-
-    def compute_stft_paddings(self, x, pad_sides=1):
-        """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
-        (first and final frames)"""
-        assert pad_sides in (1, 2)
-        pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
-        if pad_sides == 1:
-            return 0, pad
-        return pad // 2, pad // 2 + pad % 2
-
     def compute_f0(self, x: np.ndarray) -> np.ndarray:
         """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.

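For reference, the wrapper above amounts to a direct call like this, with S a linear magnitude spectrogram and parameter values assumed:

    from TTS.utils.audio.numpy_transforms import griffin_lim

    wav = griffin_lim(spec=S, num_iter=60, hop_length=256, win_length=1024, fft_size=1024)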
@@ -581,8 +501,6 @@ class AudioProcessor(object):
         >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
         >>> pitch = ap.compute_f0(wav)
         """
-        assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
-        assert self.pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`."
         # align F0 length to the spectrogram length
         if len(x) % self.hop_length == 0:
             x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode)
@@ -612,21 +530,24 @@ class AudioProcessor(object):
         Returns:
             int: Last point without silence.
         """
-        window_length = int(self.sample_rate * min_silence_sec)
-        hop_length = int(window_length / 4)
-        threshold = self._db_to_amp(-self.trim_db)
-        for x in range(hop_length, len(wav) - window_length, hop_length):
-            if np.max(wav[x : x + window_length]) < threshold:
-                return x + hop_length
-        return len(wav)
+        return find_endpoint(
+            wav=wav,
+            trim_db=self.trim_db,
+            sample_rate=self.sample_rate,
+            min_silence_sec=min_silence_sec,
+            gain=self.spec_gain,
+            base=self.base,
+        )

     def trim_silence(self, wav):
         """Trim silent parts with a threshold and 0.01 sec margin"""
         margin = int(self.sample_rate * 0.01)
         wav = wav[margin:-margin]
-        return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[
-            0
-        ]
+        return trim_silence(
+            wav=wav,
+            sample_rate=self.sample_rate,
+            trim_db=self.trim_db,
+            win_length=self.win_length,
+            hop_length=self.hop_length,
+        )

     @staticmethod
     def sound_norm(x: np.ndarray) -> np.ndarray:
@@ -638,13 +559,7 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Volume normalized waveform.
         """
-        return x / abs(x).max() * 0.95
-
-    @staticmethod
-    def _rms_norm(wav, db_level=-27):
-        r = 10 ** (db_level / 20)
-        a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
-        return wav * a
+        return volume_norm(x=x)

     def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
         """Normalize the volume based on RMS of the signal.
@@ -657,9 +572,7 @@ class AudioProcessor(object):
         """
         if db_level is None:
             db_level = self.db_level
-        assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
-        wav = self._rms_norm(x, db_level)
-        return wav
+        return rms_volume_norm(x=x, db_level=db_level)

     ### save and load ###
     def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
@@ -674,15 +587,10 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Loaded waveform.
         """
-        if self.resample:
-            # loading with resampling. It is significantly slower.
-            x, sr = librosa.load(filename, sr=self.sample_rate)
-        elif sr is None:
-            # SF is faster than librosa for loading files
-            x, sr = sf.read(filename)
-            assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
+        if sr is not None:
+            x = load_wav(filename=filename, sample_rate=sr, resample=True)
         else:
-            x, sr = librosa.load(filename, sr=sr)
+            x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample)
         if self.do_trim_silence:
             try:
                 x = self.trim_silence(x)
@@ -723,55 +631,3 @@ class AudioProcessor(object):
             filename (str): Path to the wav file.
         """
         return librosa.get_duration(filename=filename)
-
-    @staticmethod
-    def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
-        mu = 2**qc - 1
-        # wav_abs = np.minimum(np.abs(wav), 1.0)
-        signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
-        # Quantize signal to the specified number of levels.
-        signal = (signal + 1) / 2 * mu + 0.5
-        return np.floor(
-            signal,
-        )
-
-    @staticmethod
-    def mulaw_decode(wav, qc):
-        """Recovers waveform from quantized values."""
-        mu = 2**qc - 1
-        x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
-        return x
-
-    @staticmethod
-    def encode_16bits(x):
-        return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
-
-    @staticmethod
-    def quantize(x: np.ndarray, bits: int) -> np.ndarray:
-        """Quantize a waveform to a given number of bits.
-
-        Args:
-            x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
-            bits (int): Number of quantization bits.
-
-        Returns:
-            np.ndarray: Quantized waveform.
-        """
-        return (x + 1.0) * (2**bits - 1) / 2
-
-    @staticmethod
-    def dequantize(x, bits):
-        """Dequantize a waveform from the given number of bits."""
-        return 2 * x / (2**bits - 1) - 1
-
-
-def _log(x, base):
-    if base == 10:
-        return np.log10(x)
-    return np.log(x)
-
-
-def _exp(x, base):
-    if base == 10:
-        return np.power(10, x)
-    return np.exp(x)
@@ -94,6 +94,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
     use_noise_augment: bool = False
     use_cache: bool = True
     steps_to_start_discriminator: int = 200000
+    target_loss: str = "loss_1"

     # LOSS PARAMETERS - overrides
     use_stft_loss: bool = True
@@ -7,6 +7,7 @@ from coqpit import Coqpit
 from tqdm import tqdm

 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize


 def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
@@ -29,7 +30,11 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
         mel = ap.melspectrogram(y)
         np.save(mel_path, mel)
         if isinstance(config.mode, int):
-            quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode)
+            quant = (
+                mulaw_encode(wav=y, mulaw_qc=config.mode)
+                if config.model_args.mulaw
+                else quantize(x=y, quantize_bits=config.mode)
+            )
             np.save(quant_path, quant)

@@ -2,6 +2,8 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset

+from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
+

 class WaveRNNDataset(Dataset):
     """
@@ -66,7 +68,9 @@ class WaveRNNDataset(Dataset):
             x_input = audio
         elif isinstance(self.mode, int):
             x_input = (
-                self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode)
+                mulaw_encode(wav=audio, mulaw_qc=self.mode)
+                if self.mulaw
+                else quantize(x=audio, quantize_bits=self.mode)
             )
         else:
             raise RuntimeError("Unknown dataset mode - ", self.mode)
@@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler

 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import mulaw_decode
 from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
 from TTS.vocoder.layers.losses import WaveRNNLoss
@@ -399,7 +400,7 @@ class Wavernn(BaseVocoder):
             output = output[0]

         if self.args.mulaw and isinstance(self.args.mode, int):
-            output = AudioProcessor.mulaw_decode(output, self.args.mode)
+            output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)

         # Fade-out at the end to avoid signal cutting out suddenly
         fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
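Round-trip sketch for the module-level mu-law helpers used above. mulaw_encode() returns integer levels in [0, 2**qc - 1], while mulaw_decode() expects values scaled back to [-1, 1], which is what WaveRNN's sampling produces:

    import numpy as np
    from TTS.utils.audio.numpy_transforms import mulaw_decode, mulaw_encode

    qc = 10
    mu = 2**qc - 1
    wav = np.sin(np.linspace(0, 2 * np.pi, 1000)).astype(np.float32)
    levels = mulaw_encode(wav=wav, mulaw_qc=qc)                      # integers in [0, mu]
    wav_hat = mulaw_decode(wav=2 * levels / mu - 1.0, mulaw_qc=qc)   # approx. wav (lossy)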
@@ -13,23 +13,28 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import os\n",
-    "import sys\n",
-    "import torch\n",
-    "import importlib\n",
-    "import numpy as np\n",
-    "from tqdm import tqdm\n",
-    "from torch.utils.data import DataLoader\n",
-    "import soundfile as sf\n",
+    "import os\n",
+    "import pickle\n",
     "\n",
+    "import numpy as np\n",
+    "import soundfile as sf\n",
+    "import torch\n",
+    "from matplotlib import pylab as plt\n",
+    "from torch.utils.data import DataLoader\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "from TTS.config import load_config\n",
+    "from TTS.tts.configs.shared_configs import BaseDatasetConfig\n",
+    "from TTS.tts.datasets import load_tts_samples\n",
     "from TTS.tts.datasets.dataset import TTSDataset\n",
     "from TTS.tts.layers.losses import L1LossMasked\n",
-    "from TTS.utils.audio import AudioProcessor\n",
-    "from TTS.config import load_config\n",
-    "from TTS.tts.utils.visual import plot_spectrogram\n",
-    "from TTS.tts.utils.helpers import sequence_mask\n",
     "from TTS.tts.models import setup_model\n",
-    "from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
+    "from TTS.tts.utils.helpers import sequence_mask\n",
+    "from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
+    "from TTS.tts.utils.visual import plot_spectrogram\n",
+    "from TTS.utils.audio import AudioProcessor\n",
+    "from TTS.utils.audio.numpy_transforms import quantize\n",
     "\n",
     "%matplotlib inline\n",
     "\n",
@@ -49,11 +54,9 @@
     "    file_name = wav_file.split('.')[0]\n",
     "    os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
     "    os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
-    "    os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n",
     "    wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
     "    mel_path = os.path.join(out_path, \"mel\", file_name)\n",
-    "    wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
-    "    return file_name, wavq_path, mel_path, wav_path"
+    "    return file_name, wavq_path, mel_path"
    ]
   },
   {
@@ -65,14 +68,14 @@
     "# Paths and configurations\n",
     "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
     "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
+    "PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n",
     "DATASET = \"ljspeech\"\n",
     "METADATA_FILE = \"metadata.csv\"\n",
     "CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
     "MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
     "BATCH_SIZE = 32\n",
     "\n",
-    "QUANTIZED_WAV = False\n",
-    "QUANTIZE_BIT = None\n",
+    "QUANTIZE_BITS = 0  # if non-zero, quantize wav files with the given number of bits\n",
     "DRY_RUN = False  # if False, does not generate output files, only computes loss and visuals.\n",
     "\n",
     "# Check CUDA availability\n",
@@ -80,10 +83,10 @@
     "print(\" > CUDA enabled: \", use_cuda)\n",
     "\n",
     "# Load the configuration\n",
+    "dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n",
     "C = load_config(CONFIG_PATH)\n",
     "C.audio['do_trim_silence'] = False  # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
-    "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
-    "print(C['r'])"
+    "ap = AudioProcessor(**C.audio)"
    ]
   },
   {
@@ -92,12 +95,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# If the vocabulary was passed, replace the default\n",
-    "if 'characters' in C and C['characters']:\n",
-    "    symbols, phonemes = make_symbols(**C.characters)\n",
+    "# Initialize the tokenizer\n",
+    "tokenizer, C = TTSTokenizer.init_from_config(C)\n",
     "\n",
     "# Load the model\n",
-    "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
     "# TODO: multiple speakers\n",
     "model = setup_model(C)\n",
     "model.load_checkpoint(C, MODEL_FILE, eval=True)"
@@ -109,42 +110,21 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Load the preprocessor based on the dataset\n",
-    "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
-    "preprocessor = getattr(preprocessor, DATASET.lower())\n",
-    "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
+    "# Load data instances\n",
+    "meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n",
+    "meta_data = meta_data_train + meta_data_eval\n",
     "\n",
     "dataset = TTSDataset(\n",
-    "    C,\n",
-    "    C.text_cleaner,\n",
-    "    False,\n",
-    "    ap,\n",
-    "    meta_data,\n",
-    "    characters=C.get('characters', None),\n",
-    "    use_phonemes=C.use_phonemes,\n",
-    "    phoneme_cache_path=C.phoneme_cache_path,\n",
-    "    enable_eos_bos=C.enable_eos_bos_chars,\n",
+    "    outputs_per_step=C[\"r\"],\n",
+    "    compute_linear_spec=False,\n",
+    "    ap=ap,\n",
+    "    samples=meta_data,\n",
+    "    tokenizer=tokenizer,\n",
+    "    phoneme_cache_path=PHONEME_CACHE_PATH,\n",
     ")\n",
     "loader = DataLoader(\n",
     "    dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
-    ")\n"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-    "# Initialize lists for storing results\n",
-    "file_idxs = []\n",
-    "metadata = []\n",
-    "losses = []\n",
-    "postnet_losses = []\n",
-    "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
-    "\n",
-    "# Create log file\n",
-    "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
-    "log_file = open(log_file_path, \"w\")"
+    ")"
    ]
   },
   {
@@ -160,26 +140,33 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Initialize lists for storing results\n",
+    "file_idxs = []\n",
+    "metadata = []\n",
+    "losses = []\n",
+    "postnet_losses = []\n",
+    "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
+    "\n",
     "# Start processing with a progress bar\n",
-    "with torch.no_grad():\n",
+    "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
+    "with torch.no_grad() and open(log_file_path, \"w\") as log_file:\n",
     "    for data in tqdm(loader, desc=\"Processing\"):\n",
     "        try:\n",
-    "            # setup input data\n",
-    "            text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
     "\n",
     "            # dispatch data to GPU\n",
     "            if use_cuda:\n",
-    "                text_input = text_input.cuda()\n",
-    "                text_lengths = text_lengths.cuda()\n",
-    "                mel_input = mel_input.cuda()\n",
-    "                mel_lengths = mel_lengths.cuda()\n",
+    "                data[\"token_id\"] = data[\"token_id\"].cuda()\n",
+    "                data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n",
+    "                data[\"mel\"] = data[\"mel\"].cuda()\n",
+    "                data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n",
     "\n",
-    "            mask = sequence_mask(text_lengths)\n",
-    "            mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
+    "            mask = sequence_mask(data[\"token_id_lengths\"])\n",
+    "            outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n",
+    "            mel_outputs = outputs[\"decoder_outputs\"]\n",
+    "            postnet_outputs = outputs[\"model_outputs\"]\n",
     "\n",
     "            # compute loss\n",
-    "            loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
-    "            loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
+    "            loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
+    "            loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
     "            losses.append(loss.item())\n",
    "            postnet_losses.append(loss_postnet.item())\n",
     "\n",
@@ -193,28 +180,27 @@
     "                postnet_outputs = torch.stack(mel_specs)\n",
     "            elif C.model == \"Tacotron2\":\n",
     "                postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
-    "                alignments = alignments.detach().cpu().numpy()\n",
+    "                alignments = outputs[\"alignments\"].detach().cpu().numpy()\n",
     "\n",
     "            if not DRY_RUN:\n",
-    "                for idx in range(text_input.shape[0]):\n",
-    "                    wav_file_path = item_idx[idx]\n",
+    "                for idx in range(data[\"token_id\"].shape[0]):\n",
+    "                    wav_file_path = data[\"item_idxs\"][idx]\n",
     "                    wav = ap.load_wav(wav_file_path)\n",
-    "                    file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
+    "                    file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n",
     "                    file_idxs.append(file_name)\n",
     "\n",
     "                    # quantize and save wav\n",
-    "                    if QUANTIZED_WAV:\n",
-    "                        wavq = ap.quantize(wav)\n",
+    "                    if QUANTIZE_BITS > 0:\n",
+    "                        wavq = quantize(wav, QUANTIZE_BITS)\n",
     "                    np.save(wavq_path, wavq)\n",
     "\n",
     "                    # save TTS mel\n",
     "                    mel = postnet_outputs[idx]\n",
-    "                    mel_length = mel_lengths[idx]\n",
+    "                    mel_length = data[\"mel_lengths\"][idx]\n",
     "                    mel = mel[:mel_length, :].T\n",
     "                    np.save(mel_path, mel)\n",
     "\n",
     "                    metadata.append([wav_file_path, mel_path])\n",
     "\n",
     "        except Exception as e:\n",
     "            log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
     "\n",
@@ -224,35 +210,20 @@
     "    log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
     "    log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
     "\n",
-    "# Close the log file\n",
-    "log_file.close()\n",
-    "\n",
     "# For wavernn\n",
     "if not DRY_RUN:\n",
     "    pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
     "\n",
     "# For pwgan\n",
     "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
-    "    for data in metadata:\n",
-    "        f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
+    "    for wav_file_path, mel_path in metadata:\n",
+    "        f.write(f\"{wav_file_path[0]}|{mel_path[1]+'.npy'}\\n\")\n",
     "\n",
     "# Print mean losses\n",
     "print(f\"Mean Loss: {mean_loss}\")\n",
     "print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
    ]
   },
   {
-    "cell_type": "code",
-    "execution_count": null,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-    "# for pwgan\n",
-    "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
-    "    for data in metadata:\n",
-    "        f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
-    ]
-   },
-   {
    "cell_type": "markdown",
    "metadata": {},
@@ -267,7 +238,7 @@
    "outputs": [],
    "source": [
     "idx = 1\n",
-    "ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
+    "ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape"
    ]
   },
   {
@@ -276,10 +247,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import soundfile as sf\n",
-    "wav, sr = sf.read(item_idx[idx])\n",
-    "mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n",
-    "mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
+    "wav, sr = sf.read(data[\"item_idxs\"][idx])\n",
+    "mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n",
+    "mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n",
     "mel_truth = ap.melspectrogram(wav)\n",
     "print(mel_truth.shape)"
@@ -291,7 +261,7 @@
    "outputs": [],
    "source": [
     "# plot posnet output\n",
-    "print(mel_postnet[:mel_lengths[idx], :].shape)\n",
+    "print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n",
     "plot_spectrogram(mel_postnet, ap)"
    ]
   },
@@ -324,10 +294,9 @@
    "outputs": [],
    "source": [
     "# postnet, decoder diff\n",
-    "from matplotlib import pylab as plt\n",
     "mel_diff = mel_decoder - mel_postnet\n",
     "plt.figure(figsize=(16, 10))\n",
-    "plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
+    "plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T,aspect=\"auto\", origin=\"lower\")\n",
     "plt.colorbar()\n",
     "plt.tight_layout()"
    ]
@@ -339,10 +308,9 @@
    "outputs": [],
    "source": [
     "# PLOT GT SPECTROGRAM diff\n",
-    "from matplotlib import pylab as plt\n",
     "mel_diff2 = mel_truth.T - mel_decoder\n",
     "plt.figure(figsize=(16, 10))\n",
-    "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
+    "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
     "plt.colorbar()\n",
     "plt.tight_layout()"
    ]
@@ -354,21 +322,13 @@
    "outputs": [],
    "source": [
     "# PLOT GT SPECTROGRAM diff\n",
-    "from matplotlib import pylab as plt\n",
     "mel = postnet_outputs[idx]\n",
     "mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
     "plt.figure(figsize=(16, 10))\n",
-    "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
+    "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
     "plt.colorbar()\n",
     "plt.tight_layout()"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
 "metadata": {
@@ -1,33 +1,33 @@
 # core deps
 numpy==1.22.0;python_version<="3.10"
-numpy==1.24.3;python_version>"3.10"
-cython==0.29.30
+numpy>=1.24.3;python_version>"3.10"
+cython>=0.29.30
 scipy>=1.11.2
 torch>=2.1
 torchaudio
-soundfile==0.12.*
-librosa==0.10.*
-scikit-learn==1.3.0
+soundfile>=0.12.0
+librosa>=0.10.0
+scikit-learn>=1.3.0
 numba==0.55.1;python_version<"3.9"
-numba==0.57.0;python_version>="3.9"
-inflect==5.6.*
-tqdm==4.64.*
-anyascii==0.3.*
-pyyaml==6.*
-fsspec==2023.6.0 # <= 2023.9.1 makes aux tests fail
-aiohttp==3.8.*
-packaging==23.1
+numba>=0.57.0;python_version>="3.9"
+inflect>=5.6.0
+tqdm>=4.64.1
+anyascii>=0.3.0
+pyyaml>=6.0
+fsspec>=2023.6.0 # <= 2023.9.1 makes aux tests fail
+aiohttp>=3.8.1
+packaging>=23.1
 # deps for examples
-flask==2.*
+flask>=2.0.1
 # deps for inference
-pysbd==0.3.4
+pysbd>=0.3.4
 # deps for notebooks
-umap-learn==0.5.*
+umap-learn>=0.5.1
 pandas>=1.4,<2.0
 # deps for training
-matplotlib==3.7.*
+matplotlib>=3.7.0
 # coqui stack
-trainer
+trainer>=0.0.32
 # config management
 coqpit>=0.0.16
 # chinese g2p deps
@@ -46,11 +46,11 @@ bangla
 bnnumerizer
 bnunicodenormalizer
 #deps for tortoise
 k_diffusion
-einops==0.6.*
-transformers==4.33.*
+einops>=0.6.0
+transformers>=4.33.0
 #deps for bark
-encodec==0.1.*
+encodec>=0.1.1
 # deps for XTTS
-unidecode==1.3.*
+unidecode>=1.3.2
 num2words
 spacy[ja]>=3
@@ -5,6 +5,7 @@ import torch
 from tests import get_tests_input_path, get_tests_output_path, get_tests_path
 from TTS.config import BaseAudioConfig
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import stft
 from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT

 TESTS_PATH = get_tests_path()
@@ -21,7 +22,7 @@ def test_torch_stft():
     torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
     # librosa stft
     wav = ap.load_wav(WAV_FILE)
-    M_librosa = abs(ap._stft(wav))  # pylint: disable=protected-access
+    M_librosa = abs(stft(y=wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length))
     # torch stft
     wav = torch.from_numpy(wav[None, :]).float()
     M_torch = torch_stft(wav)
@@ -186,7 +186,7 @@ def test_xtts_v2_streaming():
         "en",
         gpt_cond_latent,
         speaker_embedding,
-        speed=1.5
+        speed=1.5,
     )
     wav_chuncks = []
     for i, chunk in enumerate(chunks):
@@ -198,7 +198,7 @@ def test_xtts_v2_streaming():
         "en",
         gpt_cond_latent,
         speaker_embedding,
-        speed=0.66
+        speed=0.66,
    )
    wav_chuncks = []
    for i, chunk in enumerate(chunks):