mirror of https://github.com/coqui-ai/TTS.git
commit 2211ba267a
@@ -10,7 +10,7 @@ jobs:
   build-sdist:
     runs-on: ubuntu-20.04
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Verify tag matches version
         run: |
           set -ex
@@ -38,7 +38,7 @@ jobs:
       matrix:
         python-version: ["3.9", "3.10", "3.11"]
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - uses: actions/setup-python@v2
         with:
           python-version: ${{ matrix.python-version }}
@@ -3,14 +3,14 @@
     "multilingual": {
         "multi-dataset": {
             "xtts_v2": {
-                "description": "XTTS-v2 by Coqui with 16 languages.",
+                "description": "XTTS-v2.0.2 by Coqui with 16 languages.",
                 "hf_url": [
                     "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth",
                     "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json",
                     "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
                     "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5"
                 ],
-                "model_hash": "6a09d1ad43896f06041ed8195956c9698f13b6189dc80f1c74bdc2b8e8d15324",
+                "model_hash": "5ce0502bfe3bc88dc8d9312b12a7558c",
                 "default_vocoder": null,
                 "commit": "480a6cdf7",
                 "license": "CPML",
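This hunk points the registered XTTS entry at the v2.0.2 release (updated description and an MD5-style model hash). As a rough illustration of how that catalogue entry is consumed, the sketch below uses the documented TTS Python API; the reference clip and output path are hypothetical placeholders.

# Sketch only: resolves the xtts_v2 entry updated above through the model manager.
from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
tts.tts_to_file(
    text="Hello world.",
    speaker_wav="reference.wav",  # hypothetical reference clip for voice cloning
    language="en",
    file_path="out.wav",          # hypothetical output path
)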
@@ -1 +1 @@
-0.20.5
+0.20.6
@@ -15,6 +15,7 @@ from TTS.tts.models import setup_model
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import quantize
 from TTS.utils.generic_utils import count_parameters

 use_cuda = torch.cuda.is_available()
@@ -159,7 +160,7 @@ def inference(


 def extract_spectrograms(
-    data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"
+    data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt"
 ):
     model.eval()
     export_metadata = []
@@ -196,8 +197,8 @@ def extract_spectrograms(
         _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)

         # quantize and save wav
-        if quantized_wav:
-            wavq = ap.quantize(wav)
+        if quantize_bits > 0:
+            wavq = quantize(wav, quantize_bits)
             np.save(wavq_path, wavq)

         # save TTS mel
@@ -263,7 +264,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         model,
         ap,
         args.output_path,
-        quantized_wav=args.quantized,
+        quantize_bits=args.quantize_bits,
         save_audio=args.save_audio,
         debug=args.debug,
         metada_name="metada.txt",
@@ -277,7 +278,7 @@ if __name__ == "__main__":
     parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
     parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
     parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
-    parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
+    parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero")
     parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
     args = parser.parse_args()
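The extraction script switches from a boolean --quantized flag to an explicit --quantize_bits count and calls the functional quantize helper instead of the AudioProcessor method. A minimal sketch of the assumed call, using keyword arguments as the other call sites in this commit do; the array and bit depth are arbitrary.

# Sketch: quantizing a waveform the way the updated script does.
import numpy as np
from TTS.utils.audio.numpy_transforms import quantize

wav = np.random.uniform(-1.0, 1.0, size=22050).astype(np.float32)  # stand-in for a loaded waveform
quantize_bits = 10  # arbitrary example; 0 would skip quantization
if quantize_bits > 0:
    wavq = quantize(x=wav, quantize_bits=quantize_bits)
    np.save("sample_quant.npy", wavq)  # hypothetical output path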
@@ -13,12 +13,18 @@ import math
 import numpy as np
 import torch
 import torch as th
-from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
 from tqdm import tqdm

 from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper

-K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m}
+try:
+    from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
+
+    K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m}
+except ImportError:
+    K_DIFFUSION_SAMPLERS = None
+

 SAMPLERS = ["dpm++2m", "p", "ddim"]


@@ -531,6 +537,8 @@ class GaussianDiffusion:
             if self.conditioning_free is not True:
                 raise RuntimeError("cond_free must be true")
             with tqdm(total=self.num_timesteps) as pbar:
+                if K_DIFFUSION_SAMPLERS is None:
+                    raise ModuleNotFoundError("Install k_diffusion for using k_diffusion samplers")
                 return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs)
         else:
             raise RuntimeError("sampler not impl")
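The tortoise diffusion module turns k-diffusion into an optional dependency: the import moves inside a try/except, the sampler table falls back to None, and the error is only raised when a k-diffusion sampler is actually requested. A small self-contained sketch of that lazy-failure pattern; get_sampler is a hypothetical helper, while the names it uses mirror the hunk above.

# Sketch of the optional-import pattern; runs even when k_diffusion is not installed.
try:
    from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral

    K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m}
except ImportError:
    K_DIFFUSION_SAMPLERS = None


def get_sampler(name: str):
    # Fail at call time, not at import time, so the rest of the package stays usable.
    if K_DIFFUSION_SAMPLERS is None:
        raise ModuleNotFoundError("Install k_diffusion for using k_diffusion samplers")
    return K_DIFFUSION_SAMPLERS[name]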
@@ -441,7 +441,9 @@ class GPT(nn.Module):
         audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)

         # Pad mel codes with stop_audio_token
-        audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3) # -3 to get the real code lengths without consider start and stop tokens that was not added yet
+        audio_codes = self.set_mel_padding(
+            audio_codes, code_lengths - 3
+        )  # -3 to get the real code lengths without consider start and stop tokens that was not added yet

         # Build input and target tensors
         # Prepend start token to inputs and append stop token to targets
@@ -1,6 +1,6 @@
-import json
 import os
 import re
+import textwrap
 from functools import cached_property

 import pypinyin
@@ -8,10 +8,66 @@ import torch
 from hangul_romanize import Transliter
 from hangul_romanize.rule import academic
 from num2words import num2words
+from spacy.lang.ar import Arabic
+from spacy.lang.en import English
+from spacy.lang.es import Spanish
+from spacy.lang.ja import Japanese
+from spacy.lang.zh import Chinese
 from tokenizers import Tokenizer

 from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words


+def get_spacy_lang(lang):
+    if lang == "zh":
+        return Chinese()
+    elif lang == "ja":
+        return Japanese()
+    elif lang == "ar":
+        return Arabic()
+    elif lang == "es":
+        return Spanish()
+    else:
+        # For most languages, Enlish does the job
+        return English()
+
+
+def split_sentence(text, lang, text_split_length=250):
+    """Preprocess the input text"""
+    text_splits = []
+    if text_split_length is not None and len(text) >= text_split_length:
+        text_splits.append("")
+        nlp = get_spacy_lang(lang)
+        nlp.add_pipe("sentencizer")
+        doc = nlp(text)
+        for sentence in doc.sents:
+            if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
+                # if the last sentence + the current sentence is less than the text_split_length
+                # then add the current sentence to the last sentence
+                text_splits[-1] += " " + str(sentence)
+                text_splits[-1] = text_splits[-1].lstrip()
+            elif len(str(sentence)) > text_split_length:
+                # if the current sentence is greater than the text_split_length
+                for line in textwrap.wrap(
+                    str(sentence),
+                    width=text_split_length,
+                    drop_whitespace=True,
+                    break_on_hyphens=False,
+                    tabsize=1,
+                ):
+                    text_splits.append(str(line))
+            else:
+                text_splits.append(str(sentence))
+
+        if len(text_splits) > 1:
+            if text_splits[0] == "":
+                del text_splits[0]
+    else:
+        text_splits = [text.lstrip()]
+
+    return text_splits
+
+
 _whitespace_re = re.compile(r"\s+")

 # List of (regular expression, replacement) pairs for abbreviations:
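The new split_sentence helper keeps each chunk under a per-language character budget by running spaCy's rule-based sentencizer and wrapping any single sentence that is still too long with textwrap. A quick usage sketch (requires spaCy; the sample string is arbitrary).

# Sketch: splitting long input text the way the new helper does.
from TTS.tts.layers.xtts.tokenizer import split_sentence

long_text = "This is a fairly ordinary sentence used for the demo. " * 20
chunks = split_sentence(long_text, "en", text_split_length=250)
for chunk in chunks:
    print(len(chunk), chunk[:40])  # each chunk stays near or under 250 characters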
@@ -115,7 +171,7 @@ _abbreviations = {
             # There are not many common abbreviations in Arabic as in English.
         ]
     ],
-    "zh-cn": [
+    "zh": [
         (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         for x in [
             # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
@@ -280,7 +336,7 @@ _symbols_multilingual = {
             ("°", " درجة "),
         ]
     ],
-    "zh-cn": [
+    "zh": [
        # Chinese
        (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
        for x in [
@@ -464,7 +520,7 @@ def _expand_number(m, lang="en"):


 def expand_numbers_multilingual(text, lang="en"):
-    if lang == "zh" or lang == "zh-cn":
+    if lang == "zh":
         text = zh_num2words()(text)
     else:
         if lang in ["en", "ru"]:
@@ -525,7 +581,7 @@ def japanese_cleaners(text, katsu):
     return text


-def korean_cleaners(text):
+def korean_transliterate(text):
     r = Transliter(academic)
     return r.translit(text)

@@ -546,7 +602,7 @@ class VoiceBpeTokenizer:
         "it": 213,
         "pt": 203,
         "pl": 224,
-        "zh-cn": 82,
+        "zh": 82,
         "ar": 166,
         "cs": 186,
         "ru": 182,
@@ -564,6 +620,7 @@ class VoiceBpeTokenizer:
         return cutlet.Cutlet()

     def check_input_length(self, txt, lang):
+        lang = lang.split("-")[0]  # remove the region
         limit = self.char_limits.get(lang, 250)
         if len(txt) > limit:
             print(
@@ -571,21 +628,23 @@ class VoiceBpeTokenizer:
             )

     def preprocess_text(self, txt, lang):
-        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}:
+        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
             txt = multilingual_cleaners(txt, lang)
-            if lang in {"zh", "zh-cn"}:
+            if lang == "zh":
                 txt = chinese_transliterate(txt)
+            if lang == "ko":
+                txt = korean_transliterate(txt)
         elif lang == "ja":
             txt = japanese_cleaners(txt, self.katsu)
-        elif lang == "ko":
-            txt = korean_cleaners(txt)
         else:
             raise NotImplementedError(f"Language '{lang}' is not supported.")
         return txt

     def encode(self, txt, lang):
+        lang = lang.split("-")[0]  # remove the region
         self.check_input_length(txt, lang)
         txt = self.preprocess_text(txt, lang)
+        lang = "zh-cn" if lang == "zh" else lang
         txt = f"[{lang}]{txt}"
         txt = txt.replace(" ", "[SPACE]")
         return self.tokenizer.encode(txt).ids
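After this change the tokenizer strips region codes up front ("zh-cn" becomes "zh" for cleaning and character limits) and only maps back to the "[zh-cn]" prompt tag at encode time, while Korean goes through transliteration. A short usage sketch, assuming an XTTS vocab file is available; the path is a placeholder.

# Sketch: encoding text with the updated VoiceBpeTokenizer; the vocab path is hypothetical.
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer

tokenizer = VoiceBpeTokenizer(vocab_file="path/to/vocab.json")
ids = tokenizer.encode("你好，世界。", lang="zh-cn")  # the region suffix is stripped internally
print(ids[:10])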
@@ -682,8 +741,8 @@ def test_expand_numbers_multilingual():
         ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
         ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
         # Chinese (Simplified)
-        ("在12.5秒内", "在十二点五秒内", "zh-cn"),
-        ("有50名士兵", "有五十名士兵", "zh-cn"),
+        ("在12.5秒内", "在十二点五秒内", "zh"),
+        ("有50名士兵", "有五十名士兵", "zh"),
         # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
         # ("那将是20€先生", '那将是二十欧元先生', 'zh'),
         # Turkish
@@ -764,7 +823,7 @@ def test_symbols_multilingual():
         ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
         ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
         ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
-        ("我的电量为 14%", "我的电量为 14 百分之", "zh-cn"),
+        ("我的电量为 14%", "我的电量为 14 百分之", "zh"),
         ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
         ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
         ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
@@ -318,9 +318,10 @@ class GPTTrainer(BaseTTS):
             batch["cond_idxs"] = None
         return self.train_step(batch, criterion)

-    def on_epoch_start(self, trainer):  # pylint: disable=W0613
-        # guarante that dvae will be in eval mode after .train() on evaluation end
-        self.dvae = self.dvae.eval()
+    def on_train_epoch_start(self, trainer):
+        trainer.model.eval()  # the whole model to eval
+        # put gpt model in training mode
+        trainer.model.xtts.gpt.train()

     def on_init_end(self, trainer):  # pylint: disable=W0613
         # ignore similarities.pth on clearml save/upload
@@ -10,7 +10,7 @@ from coqpit import Coqpit
 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
-from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
+from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec

@@ -396,7 +396,7 @@ class Xtts(BaseTTS):
        inference with config
        """
        assert (
-            language in self.config.languages
+            "zh-cn" if language == "zh" else language in self.config.languages
        ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}"
        # Use generally found best tuning knobs for generation.
        settings = {
@@ -420,9 +420,9 @@
         ref_audio_path,
         language,
         # GPT inference
-        temperature=0.65,
-        length_penalty=1,
-        repetition_penalty=2.0,
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
         top_k=50,
         top_p=0.85,
         do_sample=True,
@@ -502,71 +502,76 @@
         gpt_cond_latent,
         speaker_embedding,
         # GPT inference
-        temperature=0.65,
-        length_penalty=1,
-        repetition_penalty=2.0,
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
         top_k=50,
         top_p=0.85,
         do_sample=True,
         num_beams=1,
         speed=1.0,
+        enable_text_splitting=False,
         **hf_generate_kwargs,
     ):
+        language = language.split("-")[0]  # remove the country code
         length_scale = 1.0 / max(speed, 0.05)
-        text = text.strip().lower()
-        text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
-
-        # print(" > Input text: ", text)
-        # print(" > Input text preprocessed: ",self.tokenizer.preprocess_text(text, language))
-        # print(" > Input tokens: ", text_tokens)
-        # print(" > Decoded text: ", self.tokenizer.decode(text_tokens[0].cpu().numpy()))
-        assert (
-            text_tokens.shape[-1] < self.args.gpt_max_text_tokens
-        ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
-
-        with torch.no_grad():
-            gpt_codes = self.gpt.generate(
-                cond_latents=gpt_cond_latent,
-                text_inputs=text_tokens,
-                input_tokens=None,
-                do_sample=do_sample,
-                top_p=top_p,
-                top_k=top_k,
-                temperature=temperature,
-                num_return_sequences=self.gpt_batch_size,
-                num_beams=num_beams,
-                length_penalty=length_penalty,
-                repetition_penalty=repetition_penalty,
-                output_attentions=False,
-                **hf_generate_kwargs,
-            )
-            expected_output_len = torch.tensor(
-                [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
-            )
-
-            text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
-            gpt_latents = self.gpt(
-                text_tokens,
-                text_len,
-                gpt_codes,
-                expected_output_len,
-                cond_latents=gpt_cond_latent,
-                return_attentions=False,
-                return_latent=True,
-            )
-
-            if length_scale != 1.0:
-                gpt_latents = F.interpolate(
-                    gpt_latents.transpose(1, 2),
-                    scale_factor=length_scale,
-                    mode="linear"
-                ).transpose(1, 2)
-
-            wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
+        if enable_text_splitting:
+            text = split_sentence(text, language, self.tokenizer.char_limits[language])
+        else:
+            text = [text]
+
+        wavs = []
+        gpt_latents_list = []
+        for sent in text:
+            sent = sent.strip().lower()
+            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
+
+            assert (
+                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
+            ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
+
+            with torch.no_grad():
+                gpt_codes = self.gpt.generate(
+                    cond_latents=gpt_cond_latent,
+                    text_inputs=text_tokens,
+                    input_tokens=None,
+                    do_sample=do_sample,
+                    top_p=top_p,
+                    top_k=top_k,
+                    temperature=temperature,
+                    num_return_sequences=self.gpt_batch_size,
+                    num_beams=num_beams,
+                    length_penalty=length_penalty,
+                    repetition_penalty=repetition_penalty,
+                    output_attentions=False,
+                    **hf_generate_kwargs,
+                )
+                expected_output_len = torch.tensor(
+                    [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
+                )
+
+                text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
+                gpt_latents = self.gpt(
+                    text_tokens,
+                    text_len,
+                    gpt_codes,
+                    expected_output_len,
+                    cond_latents=gpt_cond_latent,
+                    return_attentions=False,
+                    return_latent=True,
+                )
+
+                if length_scale != 1.0:
+                    gpt_latents = F.interpolate(
+                        gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
+                    ).transpose(1, 2)
+
+                gpt_latents_list.append(gpt_latents.cpu())
+                wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())

         return {
-            "wav": wav.cpu().numpy().squeeze(),
-            "gpt_latents": gpt_latents,
+            "wav": torch.cat(wavs, dim=0).numpy(),
+            "gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(),
             "speaker_embedding": speaker_embedding,
         }
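inference now loops over sentence chunks when enable_text_splitting is set and concatenates the per-chunk latents and waveforms. A hedged usage sketch; it assumes an Xtts instance already loaded from a checkpoint, that get_conditioning_latents returns the two values used here, and a hypothetical reference clip.

# Sketch: long-text synthesis with the new enable_text_splitting flag (model loading omitted).
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
out = model.inference(
    text="A long paragraph that may exceed the per-language character limit. " * 10,
    language="en",
    gpt_cond_latent=gpt_cond_latent,
    speaker_embedding=speaker_embedding,
    enable_text_splitting=True,  # split into sentences, synthesize each, concatenate
)
wav = out["wav"]  # a single numpy array covering all chunks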
@@ -606,66 +611,76 @@
         stream_chunk_size=20,
         overlap_wav_len=1024,
         # GPT inference
-        temperature=0.65,
-        length_penalty=1,
-        repetition_penalty=2.0,
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
         top_k=50,
         top_p=0.85,
         do_sample=True,
         speed=1.0,
+        enable_text_splitting=False,
         **hf_generate_kwargs,
     ):
+        language = language.split("-")[0]  # remove the country code
         length_scale = 1.0 / max(speed, 0.05)
-        text = text.strip().lower()
-        text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
-
-        fake_inputs = self.gpt.compute_embeddings(
-            gpt_cond_latent.to(self.device),
-            text_tokens,
-        )
-        gpt_generator = self.gpt.get_generator(
-            fake_inputs=fake_inputs,
-            top_k=top_k,
-            top_p=top_p,
-            temperature=temperature,
-            do_sample=do_sample,
-            num_beams=1,
-            num_return_sequences=1,
-            length_penalty=float(length_penalty),
-            repetition_penalty=float(repetition_penalty),
-            output_attentions=False,
-            output_hidden_states=True,
-            **hf_generate_kwargs,
-        )
-
-        last_tokens = []
-        all_latents = []
-        wav_gen_prev = None
-        wav_overlap = None
-        is_end = False
-
-        while not is_end:
-            try:
-                x, latent = next(gpt_generator)
-                last_tokens += [x]
-                all_latents += [latent]
-            except StopIteration:
-                is_end = True
-
-            if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
-                gpt_latents = torch.cat(all_latents, dim=0)[None, :]
-                if length_scale != 1.0:
-                    gpt_latents = F.interpolate(
-                        gpt_latents.transpose(1, 2),
-                        scale_factor=length_scale,
-                        mode="linear"
-                    ).transpose(1, 2)
-                wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
-                wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
-                    wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
-                )
-                last_tokens = []
-                yield wav_chunk
+        if enable_text_splitting:
+            text = split_sentence(text, language, self.tokenizer.char_limits[language])
+        else:
+            text = [text]
+
+        for sent in text:
+            sent = sent.strip().lower()
+            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
+
+            assert (
+                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
+            ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
+
+            fake_inputs = self.gpt.compute_embeddings(
+                gpt_cond_latent.to(self.device),
+                text_tokens,
+            )
+            gpt_generator = self.gpt.get_generator(
+                fake_inputs=fake_inputs,
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
+                do_sample=do_sample,
+                num_beams=1,
+                num_return_sequences=1,
+                length_penalty=float(length_penalty),
+                repetition_penalty=float(repetition_penalty),
+                output_attentions=False,
+                output_hidden_states=True,
+                **hf_generate_kwargs,
+            )
+
+            last_tokens = []
+            all_latents = []
+            wav_gen_prev = None
+            wav_overlap = None
+            is_end = False
+
+            while not is_end:
+                try:
+                    x, latent = next(gpt_generator)
+                    last_tokens += [x]
+                    all_latents += [latent]
+                except StopIteration:
+                    is_end = True
+
+                if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
+                    gpt_latents = torch.cat(all_latents, dim=0)[None, :]
+                    if length_scale != 1.0:
+                        gpt_latents = F.interpolate(
+                            gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
+                        ).transpose(1, 2)
+                    wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                    wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
+                        wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
+                    )
+                    last_tokens = []
+                    yield wav_chunk

     def forward(self):
         raise NotImplementedError(
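inference_stream gets the same new defaults and text-splitting option and keeps yielding waveform chunks as GPT latents accumulate. A usage sketch under the same assumptions as the previous one (loaded model, precomputed conditioning latents):

# Sketch: consuming the streaming generator chunk by chunk.
import torch

chunks = []
for chunk in model.inference_stream(
    text="A fairly long paragraph to be streamed sentence by sentence.",
    language="en",
    gpt_cond_latent=gpt_cond_latent,
    speaker_embedding=speaker_embedding,
    stream_chunk_size=20,  # decode roughly every 20 generated tokens
    enable_text_splitting=True,
):
    chunks.append(chunk)  # each chunk is a 1-D tensor of audio samples
wav = torch.cat(chunks, dim=0)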
@@ -201,7 +201,6 @@ def stft(
 def istft(
     *,
     y: np.ndarray = None,
-    fft_size: int = None,
     hop_length: int = None,
     win_length: int = None,
     window: str = "hann",
@@ -5,10 +5,26 @@ import librosa
 import numpy as np
 import scipy.io.wavfile
 import scipy.signal
-import soundfile as sf

 from TTS.tts.utils.helpers import StandardScaler
-from TTS.utils.audio.numpy_transforms import compute_f0
+from TTS.utils.audio.numpy_transforms import (
+    amp_to_db,
+    build_mel_basis,
+    compute_f0,
+    db_to_amp,
+    deemphasis,
+    find_endpoint,
+    griffin_lim,
+    load_wav,
+    mel_to_spec,
+    millisec_to_length,
+    preemphasis,
+    rms_volume_norm,
+    spec_to_mel,
+    stft,
+    trim_silence,
+    volume_norm,
+)

 # pylint: disable=too-many-public-methods
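From here on the AudioProcessor delegates to the stateless helpers in TTS.utils.audio.numpy_transforms; judging by the call sites in this commit, the helpers are keyword-only. A hedged sketch of the functional pipeline the class now wraps; the parameter values are typical LJSpeech-style settings, not taken from any config in this diff.

# Sketch: computing a mel spectrogram with the functional API (values are illustrative).
import numpy as np
from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, spec_to_mel, stft

sample_rate, fft_size, hop_length, win_length, num_mels = 22050, 1024, 256, 1024, 80
wav = np.random.uniform(-1.0, 1.0, size=sample_rate).astype(np.float32)  # stand-in waveform

mel_basis = build_mel_basis(
    sample_rate=sample_rate, fft_size=fft_size, num_mels=num_mels, mel_fmin=0, mel_fmax=8000
)
D = stft(y=wav, fft_size=fft_size, hop_length=hop_length, win_length=win_length, pad_mode="reflect")
mel = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis)
mel_db = amp_to_db(x=mel, gain=20, base=10)  # gain/base mirror the spec_gain/base attributes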
@@ -200,7 +216,9 @@ class AudioProcessor(object):
         # setup stft parameters
         if hop_length is None:
             # compute stft parameters from given time values
-            self.hop_length, self.win_length = self._stft_parameters()
+            self.win_length, self.hop_length = millisec_to_length(
+                frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate
+            )
         else:
             # use stft parameters from config file
             self.hop_length = hop_length
@@ -215,8 +233,13 @@
         for key, value in members.items():
             print(" | > {}:{}".format(key, value))
         # create spectrogram utils
-        self.mel_basis = self._build_mel_basis()
-        self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis())
+        self.mel_basis = build_mel_basis(
+            sample_rate=self.sample_rate,
+            fft_size=self.fft_size,
+            num_mels=self.num_mels,
+            mel_fmax=self.mel_fmax,
+            mel_fmin=self.mel_fmin,
+        )
         # setup scaler
         if stats_path and signal_norm:
             mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
@@ -232,35 +255,6 @@ class AudioProcessor(object):
             return AudioProcessor(verbose=verbose, **config.audio)
         return AudioProcessor(verbose=verbose, **config)

-    ### setting up the parameters ###
-    def _build_mel_basis(
-        self,
-    ) -> np.ndarray:
-        """Build melspectrogram basis.
-
-        Returns:
-            np.ndarray: melspectrogram basis.
-        """
-        if self.mel_fmax is not None:
-            assert self.mel_fmax <= self.sample_rate // 2
-        return librosa.filters.mel(
-            sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
-        )
-
-    def _stft_parameters(
-        self,
-    ) -> Tuple[int, int]:
-        """Compute the real STFT parameters from the time values.
-
-        Returns:
-            Tuple[int, int]: hop length and window length for STFT.
-        """
-        factor = self.frame_length_ms / self.frame_shift_ms
-        assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
-        hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
-        win_length = int(hop_length * factor)
-        return hop_length, win_length
-
     ### normalization ###
     def normalize(self, S: np.ndarray) -> np.ndarray:
         """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`
@@ -386,31 +380,6 @@ class AudioProcessor(object):
             self.linear_scaler = StandardScaler()
             self.linear_scaler.set_stats(linear_mean, linear_std)

-    ### DB and AMP conversion ###
-    # pylint: disable=no-self-use
-    def _amp_to_db(self, x: np.ndarray) -> np.ndarray:
-        """Convert amplitude values to decibels.
-
-        Args:
-            x (np.ndarray): Amplitude spectrogram.
-
-        Returns:
-            np.ndarray: Decibels spectrogram.
-        """
-        return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
-
-    # pylint: disable=no-self-use
-    def _db_to_amp(self, x: np.ndarray) -> np.ndarray:
-        """Convert decibels spectrogram to amplitude spectrogram.
-
-        Args:
-            x (np.ndarray): Decibels spectrogram.
-
-        Returns:
-            np.ndarray: Amplitude spectrogram.
-        """
-        return _exp(x / self.spec_gain, self.base)
-
     ### Preemphasis ###
     def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:
         """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
@@ -424,32 +393,13 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Decorrelated audio signal.
         """
-        if self.preemphasis == 0:
-            raise RuntimeError(" [!] Preemphasis is set 0.0.")
-        return scipy.signal.lfilter([1, -self.preemphasis], [1], x)
+        return preemphasis(x=x, coef=self.preemphasis)

     def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:
         """Reverse pre-emphasis."""
-        if self.preemphasis == 0:
-            raise RuntimeError(" [!] Preemphasis is set 0.0.")
-        return scipy.signal.lfilter([1], [1, -self.preemphasis], x)
+        return deemphasis(x=x, coef=self.preemphasis)

     ### SPECTROGRAMs ###
-    def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray:
-        """Project a full scale spectrogram to a melspectrogram.
-
-        Args:
-            spectrogram (np.ndarray): Full scale spectrogram.
-
-        Returns:
-            np.ndarray: Melspectrogram
-        """
-        return np.dot(self.mel_basis, spectrogram)
-
-    def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray:
-        """Convert a melspectrogram to full scale spectrogram."""
-        return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec))
-
     def spectrogram(self, y: np.ndarray) -> np.ndarray:
         """Compute a spectrogram from a waveform.
@@ -460,11 +410,16 @@ class AudioProcessor(object):
             np.ndarray: Spectrogram.
         """
         if self.preemphasis != 0:
-            D = self._stft(self.apply_preemphasis(y))
-        else:
-            D = self._stft(y)
+            y = self.apply_preemphasis(y)
+        D = stft(
+            y=y,
+            fft_size=self.fft_size,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            pad_mode=self.stft_pad_mode,
+        )
         if self.do_amp_to_db_linear:
-            S = self._amp_to_db(np.abs(D))
+            S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base)
         else:
             S = np.abs(D)
         return self.normalize(S).astype(np.float32)
@@ -472,32 +427,35 @@ class AudioProcessor(object):
     def melspectrogram(self, y: np.ndarray) -> np.ndarray:
         """Compute a melspectrogram from a waveform."""
         if self.preemphasis != 0:
-            D = self._stft(self.apply_preemphasis(y))
-        else:
-            D = self._stft(y)
+            y = self.apply_preemphasis(y)
+        D = stft(
+            y=y,
+            fft_size=self.fft_size,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            pad_mode=self.stft_pad_mode,
+        )
+        S = spec_to_mel(spec=np.abs(D), mel_basis=self.mel_basis)
         if self.do_amp_to_db_mel:
-            S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
-        else:
-            S = self._linear_to_mel(np.abs(D))
+            S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
         return self.normalize(S).astype(np.float32)

     def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
         """Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
         S = self.denormalize(spectrogram)
-        S = self._db_to_amp(S)
+        S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
         # Reconstruct phase
-        if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
-        return self._griffin_lim(S**self.power)
+        W = self._griffin_lim(S**self.power)
+        return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W

     def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
         """Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
         D = self.denormalize(mel_spectrogram)
-        S = self._db_to_amp(D)
-        S = self._mel_to_linear(S)  # Convert back to linear
-        if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
-        return self._griffin_lim(S**self.power)
+        S = db_to_amp(x=D, gain=self.spec_gain, base=self.base)
+        S = mel_to_spec(mel=S, mel_basis=self.mel_basis)  # Convert back to linear
+        W = self._griffin_lim(S**self.power)
+        return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W

     def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
         """Convert a full scale linear spectrogram output of a network to a melspectrogram.
@@ -509,60 +467,22 @@ class AudioProcessor(object):
             np.ndarray: Normalized melspectrogram.
         """
         S = self.denormalize(linear_spec)
-        S = self._db_to_amp(S)
-        S = self._linear_to_mel(np.abs(S))
-        S = self._amp_to_db(S)
+        S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
+        S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis)
+        S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
         mel = self.normalize(S)
         return mel

-    ### STFT and ISTFT ###
-    def _stft(self, y: np.ndarray) -> np.ndarray:
-        """Librosa STFT wrapper.
-
-        Args:
-            y (np.ndarray): Audio signal.
-
-        Returns:
-            np.ndarray: Complex number array.
-        """
-        return librosa.stft(
-            y=y,
-            n_fft=self.fft_size,
-            hop_length=self.hop_length,
-            win_length=self.win_length,
-            pad_mode=self.stft_pad_mode,
-            window="hann",
-            center=True,
-        )
-
-    def _istft(self, y: np.ndarray) -> np.ndarray:
-        """Librosa iSTFT wrapper."""
-        return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
-
-    def _griffin_lim(self, S):
-        angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
-        try:
-            S_complex = np.abs(S).astype(np.complex)
-        except AttributeError:  # np.complex is deprecated since numpy 1.20.0
-            S_complex = np.abs(S).astype(complex)
-        y = self._istft(S_complex * angles)
-        if not np.isfinite(y).all():
-            print(" [!] Waveform is not finite everywhere. Skipping the GL.")
-            return np.array([0.0])
-        for _ in range(self.griffin_lim_iters):
-            angles = np.exp(1j * np.angle(self._stft(y)))
-            y = self._istft(S_complex * angles)
-        return y
-
-    def compute_stft_paddings(self, x, pad_sides=1):
-        """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
-        (first and final frames)"""
-        assert pad_sides in (1, 2)
-        pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
-        if pad_sides == 1:
-            return 0, pad
-        return pad // 2, pad // 2 + pad % 2
+    def _griffin_lim(self, S):
+        return griffin_lim(
+            spec=S,
+            num_iter=self.griffin_lim_iters,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            fft_size=self.fft_size,
+            pad_mode=self.stft_pad_mode,
+        )

     def compute_f0(self, x: np.ndarray) -> np.ndarray:
         """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
@@ -581,8 +501,6 @@ class AudioProcessor(object):
         >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
         >>> pitch = ap.compute_f0(wav)
         """
-        assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
-        assert self.pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`."
         # align F0 length to the spectrogram length
         if len(x) % self.hop_length == 0:
             x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode)
@@ -612,21 +530,24 @@ class AudioProcessor(object):
         Returns:
             int: Last point without silence.
         """
-        window_length = int(self.sample_rate * min_silence_sec)
-        hop_length = int(window_length / 4)
-        threshold = self._db_to_amp(-self.trim_db)
-        for x in range(hop_length, len(wav) - window_length, hop_length):
-            if np.max(wav[x : x + window_length]) < threshold:
-                return x + hop_length
-        return len(wav)
+        return find_endpoint(
+            wav=wav,
+            trim_db=self.trim_db,
+            sample_rate=self.sample_rate,
+            min_silence_sec=min_silence_sec,
+            gain=self.spec_gain,
+            base=self.base,
+        )

     def trim_silence(self, wav):
         """Trim silent parts with a threshold and 0.01 sec margin"""
-        margin = int(self.sample_rate * 0.01)
-        wav = wav[margin:-margin]
-        return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[
-            0
-        ]
+        return trim_silence(
+            wav=wav,
+            sample_rate=self.sample_rate,
+            trim_db=self.trim_db,
+            win_length=self.win_length,
+            hop_length=self.hop_length,
+        )

     @staticmethod
     def sound_norm(x: np.ndarray) -> np.ndarray:
@@ -638,13 +559,7 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Volume normalized waveform.
         """
-        return x / abs(x).max() * 0.95
-
-    @staticmethod
-    def _rms_norm(wav, db_level=-27):
-        r = 10 ** (db_level / 20)
-        a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
-        return wav * a
+        return volume_norm(x=x)

     def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
         """Normalize the volume based on RMS of the signal.
@@ -657,9 +572,7 @@ class AudioProcessor(object):
         """
         if db_level is None:
             db_level = self.db_level
-        assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
-        wav = self._rms_norm(x, db_level)
-        return wav
+        return rms_volume_norm(x=x, db_level=db_level)

     ### save and load ###
     def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
@@ -674,15 +587,10 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Loaded waveform.
         """
-        if self.resample:
-            # loading with resampling. It is significantly slower.
-            x, sr = librosa.load(filename, sr=self.sample_rate)
-        elif sr is None:
-            # SF is faster than librosa for loading files
-            x, sr = sf.read(filename)
-            assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
+        if sr is not None:
+            x = load_wav(filename=filename, sample_rate=sr, resample=True)
         else:
-            x, sr = librosa.load(filename, sr=sr)
+            x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample)
         if self.do_trim_silence:
             try:
                 x = self.trim_silence(x)
@@ -723,55 +631,3 @@ class AudioProcessor(object):
             filename (str): Path to the wav file.
         """
         return librosa.get_duration(filename=filename)
-
-    @staticmethod
-    def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
-        mu = 2**qc - 1
-        # wav_abs = np.minimum(np.abs(wav), 1.0)
-        signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
-        # Quantize signal to the specified number of levels.
-        signal = (signal + 1) / 2 * mu + 0.5
-        return np.floor(
-            signal,
-        )
-
-    @staticmethod
-    def mulaw_decode(wav, qc):
-        """Recovers waveform from quantized values."""
-        mu = 2**qc - 1
-        x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
-        return x
-
-    @staticmethod
-    def encode_16bits(x):
-        return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
-
-    @staticmethod
-    def quantize(x: np.ndarray, bits: int) -> np.ndarray:
-        """Quantize a waveform to a given number of bits.
-
-        Args:
-            x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
-            bits (int): Number of quantization bits.
-
-        Returns:
-            np.ndarray: Quantized waveform.
-        """
-        return (x + 1.0) * (2**bits - 1) / 2
-
-    @staticmethod
-    def dequantize(x, bits):
-        """Dequantize a waveform from the given number of bits."""
-        return 2 * x / (2**bits - 1) - 1
-
-
-def _log(x, base):
-    if base == 10:
-        return np.log10(x)
-    return np.log(x)
-
-
-def _exp(x, base):
-    if base == 10:
-        return np.power(10, x)
-    return np.exp(x)
@@ -94,6 +94,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
     use_noise_augment: bool = False
     use_cache: bool = True
     steps_to_start_discriminator: int = 200000
+    target_loss: str = "loss_1"

     # LOSS PARAMETERS - overrides
     use_stft_loss: bool = True
@@ -7,6 +7,7 @@ from coqpit import Coqpit
 from tqdm import tqdm

 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize


 def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
@@ -29,7 +30,11 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
         mel = ap.melspectrogram(y)
         np.save(mel_path, mel)
         if isinstance(config.mode, int):
-            quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode)
+            quant = (
+                mulaw_encode(wav=y, mulaw_qc=config.mode)
+                if config.model_args.mulaw
+                else quantize(x=y, quantize_bits=config.mode)
+            )
             np.save(quant_path, quant)

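Vocoder preprocessing now calls the module-level mulaw_encode and quantize helpers instead of the AudioProcessor static methods removed above; mulaw_decode (imported in the WaveRNN hunk below) is the matching inverse. A small sketch of both encoding paths with arbitrary bit depths.

# Sketch: mu-law encoding and linear quantization via the functional helpers.
import numpy as np
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize

wav = np.sin(np.linspace(0, 2 * np.pi, 22050)).astype(np.float32)  # toy waveform in [-1, 1]

encoded = mulaw_encode(wav=wav, mulaw_qc=9)  # mu-law path, used when model_args.mulaw is True
quant = quantize(x=wav, quantize_bits=10)    # plain linear quantization path
print(encoded.min(), encoded.max(), quant.min(), quant.max())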
@@ -2,6 +2,8 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset

+from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
+

 class WaveRNNDataset(Dataset):
     """
@@ -66,7 +68,9 @@ class WaveRNNDataset(Dataset):
             x_input = audio
         elif isinstance(self.mode, int):
             x_input = (
-                self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode)
+                mulaw_encode(wav=audio, mulaw_qc=self.mode)
+                if self.mulaw
+                else quantize(x=audio, quantize_bits=self.mode)
             )
         else:
             raise RuntimeError("Unknown dataset mode - ", self.mode)
@@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler

 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import mulaw_decode
 from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
 from TTS.vocoder.layers.losses import WaveRNNLoss
@@ -399,7 +400,7 @@ class Wavernn(BaseVocoder):
         output = output[0]

         if self.args.mulaw and isinstance(self.args.mode, int):
-            output = AudioProcessor.mulaw_decode(output, self.args.mode)
+            output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)

         # Fade-out at the end to avoid signal cutting out suddenly
         fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
@@ -13,23 +13,28 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import os\n",
-    "import sys\n",
-    "import torch\n",
     "import importlib\n",
-    "import numpy as np\n",
-    "from tqdm import tqdm\n",
-    "from torch.utils.data import DataLoader\n",
-    "import soundfile as sf\n",
+    "import os\n",
     "import pickle\n",
+    "\n",
+    "import numpy as np\n",
+    "import soundfile as sf\n",
+    "import torch\n",
+    "from matplotlib import pylab as plt\n",
+    "from torch.utils.data import DataLoader\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "from TTS.config import load_config\n",
+    "from TTS.tts.configs.shared_configs import BaseDatasetConfig\n",
+    "from TTS.tts.datasets import load_tts_samples\n",
     "from TTS.tts.datasets.dataset import TTSDataset\n",
     "from TTS.tts.layers.losses import L1LossMasked\n",
-    "from TTS.utils.audio import AudioProcessor\n",
-    "from TTS.config import load_config\n",
-    "from TTS.tts.utils.visual import plot_spectrogram\n",
-    "from TTS.tts.utils.helpers import sequence_mask\n",
     "from TTS.tts.models import setup_model\n",
-    "from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
+    "from TTS.tts.utils.helpers import sequence_mask\n",
+    "from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
+    "from TTS.tts.utils.visual import plot_spectrogram\n",
+    "from TTS.utils.audio import AudioProcessor\n",
+    "from TTS.utils.audio.numpy_transforms import quantize\n",
     "\n",
     "%matplotlib inline\n",
     "\n",
@@ -49,11 +54,9 @@
     " file_name = wav_file.split('.')[0]\n",
     " os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
     " os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
-    " os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n",
     " wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
     " mel_path = os.path.join(out_path, \"mel\", file_name)\n",
-    " wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
-    " return file_name, wavq_path, mel_path, wav_path"
+    " return file_name, wavq_path, mel_path"
     ]
    },
    {
@@ -65,14 +68,14 @@
     "# Paths and configurations\n",
     "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
     "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
+    "PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n",
     "DATASET = \"ljspeech\"\n",
     "METADATA_FILE = \"metadata.csv\"\n",
     "CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
     "MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
     "BATCH_SIZE = 32\n",
     "\n",
-    "QUANTIZED_WAV = False\n",
-    "QUANTIZE_BIT = None\n",
+    "QUANTIZE_BITS = 0 # if non-zero, quantize wav files with the given number of bits\n",
     "DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
     "\n",
     "# Check CUDA availability\n",
@@ -80,10 +83,10 @@
     "print(\" > CUDA enabled: \", use_cuda)\n",
     "\n",
     "# Load the configuration\n",
+    "dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n",
     "C = load_config(CONFIG_PATH)\n",
     "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
-    "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
-    "print(C['r'])"
+    "ap = AudioProcessor(**C.audio)"
     ]
    },
    {
@@ -92,12 +95,10 @@
    "metadata": {},
    "outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# If the vocabulary was passed, replace the default\n",
|
"# Initialize the tokenizer\n",
|
||||||
"if 'characters' in C and C['characters']:\n",
|
"tokenizer, C = TTSTokenizer.init_from_config(C)\n",
|
||||||
" symbols, phonemes = make_symbols(**C.characters)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Load the model\n",
|
"# Load the model\n",
|
||||||
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
|
|
||||||
"# TODO: multiple speakers\n",
|
"# TODO: multiple speakers\n",
|
||||||
"model = setup_model(C)\n",
|
"model = setup_model(C)\n",
|
||||||
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
|
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
|
||||||
|
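The rewritten cell replaces the manual `make_symbols` bookkeeping with `TTSTokenizer.init_from_config`, which now owns the character set and may patch the config it is given. A minimal sketch of the new initialization order, using placeholder paths in place of the notebook's CONFIG_PATH and MODEL_FILE:

from TTS.config import load_config
from TTS.tts.models import setup_model
from TTS.tts.utils.text.tokenizer import TTSTokenizer

# placeholder paths standing in for the notebook's CONFIG_PATH / MODEL_FILE
CONFIG_PATH = "/path/to/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json"
MODEL_FILE = "/path/to/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth"

C = load_config(CONFIG_PATH)

# the tokenizer may update the config (e.g. its character set), hence the tuple return
tokenizer, C = TTSTokenizer.init_from_config(C)

# build the model from the config and load the released checkpoint for evaluation
model = setup_model(C)
model.load_checkpoint(C, MODEL_FILE, eval=True)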
@ -109,42 +110,21 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Load the preprocessor based on the dataset\n",
|
"# Load data instances\n",
|
||||||
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
|
"meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n",
|
||||||
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
|
"meta_data = meta_data_train + meta_data_eval\n",
|
||||||
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
|
"\n",
|
||||||
"dataset = TTSDataset(\n",
|
"dataset = TTSDataset(\n",
|
||||||
" C,\n",
|
" outputs_per_step=C[\"r\"],\n",
|
||||||
" C.text_cleaner,\n",
|
" compute_linear_spec=False,\n",
|
||||||
" False,\n",
|
" ap=ap,\n",
|
||||||
" ap,\n",
|
" samples=meta_data,\n",
|
||||||
" meta_data,\n",
|
" tokenizer=tokenizer,\n",
|
||||||
" characters=C.get('characters', None),\n",
|
" phoneme_cache_path=PHONEME_CACHE_PATH,\n",
|
||||||
" use_phonemes=C.use_phonemes,\n",
|
|
||||||
" phoneme_cache_path=C.phoneme_cache_path,\n",
|
|
||||||
" enable_eos_bos=C.enable_eos_bos_chars,\n",
|
|
||||||
")\n",
|
")\n",
|
||||||
"loader = DataLoader(\n",
|
"loader = DataLoader(\n",
|
||||||
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
|
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
|
||||||
")\n"
|
")"
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Initialize lists for storing results\n",
|
|
||||||
"file_idxs = []\n",
|
|
||||||
"metadata = []\n",
|
|
||||||
"losses = []\n",
|
|
||||||
"postnet_losses = []\n",
|
|
||||||
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
|
|
||||||
"\n",
|
|
||||||
"# Create log file\n",
|
|
||||||
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
|
|
||||||
"log_file = open(log_file_path, \"w\")"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
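The data-loading cell now builds its sample list with `load_tts_samples` and passes everything to `TTSDataset` by keyword instead of by position. A condensed, self-contained sketch of the same flow; the paths, batch size, and worker count are placeholders, not values taken from the repository:

from torch.utils.data import DataLoader

from TTS.config import load_config
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

# placeholder locations; swap in your own dataset and model config
CONFIG_PATH = "/path/to/config.json"
DATA_PATH = "/path/to/LJSpeech-1.1/"
PHONEME_CACHE_PATH = "/path/to/phoneme_cache"

C = load_config(CONFIG_PATH)
C.audio["do_trim_silence"] = False  # keep the mels aligned with the source wavs
ap = AudioProcessor(**C.audio)
tokenizer, C = TTSTokenizer.init_from_config(C)

dataset_config = BaseDatasetConfig(formatter="ljspeech", meta_file_train="metadata.csv", path=DATA_PATH)
meta_data_train, meta_data_eval = load_tts_samples(dataset_config)

dataset = TTSDataset(
    outputs_per_step=C["r"],
    compute_linear_spec=False,
    ap=ap,
    samples=meta_data_train + meta_data_eval,
    tokenizer=tokenizer,
    phoneme_cache_path=PHONEME_CACHE_PATH,
)
loader = DataLoader(dataset, batch_size=32, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)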
@ -160,26 +140,33 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# Initialize lists for storing results\n",
|
||||||
|
"file_idxs = []\n",
|
||||||
|
"metadata = []\n",
|
||||||
|
"losses = []\n",
|
||||||
|
"postnet_losses = []\n",
|
||||||
|
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
|
||||||
|
"\n",
|
||||||
"# Start processing with a progress bar\n",
|
"# Start processing with a progress bar\n",
|
||||||
"with torch.no_grad():\n",
|
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
|
||||||
|
"with torch.no_grad() and open(log_file_path, \"w\") as log_file:\n",
|
||||||
" for data in tqdm(loader, desc=\"Processing\"):\n",
|
" for data in tqdm(loader, desc=\"Processing\"):\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" # setup input data\n",
|
|
||||||
" text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
|
|
||||||
"\n",
|
|
||||||
" # dispatch data to GPU\n",
|
" # dispatch data to GPU\n",
|
||||||
" if use_cuda:\n",
|
" if use_cuda:\n",
|
||||||
" text_input = text_input.cuda()\n",
|
" data[\"token_id\"] = data[\"token_id\"].cuda()\n",
|
||||||
" text_lengths = text_lengths.cuda()\n",
|
" data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n",
|
||||||
" mel_input = mel_input.cuda()\n",
|
" data[\"mel\"] = data[\"mel\"].cuda()\n",
|
||||||
" mel_lengths = mel_lengths.cuda()\n",
|
" data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" mask = sequence_mask(text_lengths)\n",
|
" mask = sequence_mask(data[\"token_id_lengths\"])\n",
|
||||||
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
|
" outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n",
|
||||||
|
" mel_outputs = outputs[\"decoder_outputs\"]\n",
|
||||||
|
" postnet_outputs = outputs[\"model_outputs\"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # compute loss\n",
|
" # compute loss\n",
|
||||||
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
|
" loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
|
||||||
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
|
" loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
|
||||||
" losses.append(loss.item())\n",
|
" losses.append(loss.item())\n",
|
||||||
" postnet_losses.append(loss_postnet.item())\n",
|
" postnet_losses.append(loss_postnet.item())\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
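The processing cell now enters `torch.no_grad()` and opens the log file in a single `with` statement. Chaining context managers takes a comma; writing `A and B` evaluates a boolean expression first, so only one of the two objects would actually be entered as a context manager. A small self-contained illustration:

import os

import torch

OUT_PATH = "/tmp/specs"  # placeholder output directory
os.makedirs(OUT_PATH, exist_ok=True)

log_file_path = os.path.join(OUT_PATH, "log.txt")
with torch.no_grad(), open(log_file_path, "w") as log_file:
    # both contexts are active here: autograd is disabled and the file is open
    log_file.write("processing started\n")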
@ -193,28 +180,27 @@
|
||||||
" postnet_outputs = torch.stack(mel_specs)\n",
|
" postnet_outputs = torch.stack(mel_specs)\n",
|
||||||
" elif C.model == \"Tacotron2\":\n",
|
" elif C.model == \"Tacotron2\":\n",
|
||||||
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
|
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
|
||||||
" alignments = alignments.detach().cpu().numpy()\n",
|
" alignments = outputs[\"alignments\"].detach().cpu().numpy()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" if not DRY_RUN:\n",
|
" if not DRY_RUN:\n",
|
||||||
" for idx in range(text_input.shape[0]):\n",
|
" for idx in range(data[\"token_id\"].shape[0]):\n",
|
||||||
" wav_file_path = item_idx[idx]\n",
|
" wav_file_path = data[\"item_idxs\"][idx]\n",
|
||||||
" wav = ap.load_wav(wav_file_path)\n",
|
" wav = ap.load_wav(wav_file_path)\n",
|
||||||
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
|
" file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n",
|
||||||
" file_idxs.append(file_name)\n",
|
" file_idxs.append(file_name)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # quantize and save wav\n",
|
" # quantize and save wav\n",
|
||||||
" if QUANTIZED_WAV:\n",
|
" if QUANTIZE_BITS > 0:\n",
|
||||||
" wavq = ap.quantize(wav)\n",
|
" wavq = quantize(wav, QUANTIZE_BITS)\n",
|
||||||
" np.save(wavq_path, wavq)\n",
|
" np.save(wavq_path, wavq)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # save TTS mel\n",
|
" # save TTS mel\n",
|
||||||
" mel = postnet_outputs[idx]\n",
|
" mel = postnet_outputs[idx]\n",
|
||||||
" mel_length = mel_lengths[idx]\n",
|
" mel_length = data[\"mel_lengths\"][idx]\n",
|
||||||
" mel = mel[:mel_length, :].T\n",
|
" mel = mel[:mel_length, :].T\n",
|
||||||
" np.save(mel_path, mel)\n",
|
" np.save(mel_path, mel)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" metadata.append([wav_file_path, mel_path])\n",
|
" metadata.append([wav_file_path, mel_path])\n",
|
||||||
"\n",
|
|
||||||
" except Exception as e:\n",
|
" except Exception as e:\n",
|
||||||
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
|
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
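The save step above trims each predicted mel to its true length before writing it out for the vocoder, and only touches the quantized-wav branch when `QUANTIZE_BITS` is non-zero. A compact sketch of the bookkeeping with placeholder arrays; the `quantize(wav, QUANTIZE_BITS)` call is kept as a comment because only the notebook's spelling of it, not the helper's full signature, appears in this diff:

import os

import numpy as np

OUT_PATH = "/tmp/specs"    # placeholder output directory
QUANTIZE_BITS = 0          # 0 disables the quantized-wav branch
os.makedirs(os.path.join(OUT_PATH, "mel"), exist_ok=True)

# stand-ins for one batch item
wav_file_path = "/tmp/LJ001-0001.wav"
postnet_output = np.random.rand(400, 80).astype(np.float32)  # [T, n_mels]
mel_length = 350

# with QUANTIZE_BITS > 0 the notebook also saves a quantized wav:
#     wavq = quantize(wav, QUANTIZE_BITS)  # helper from TTS.utils.audio.numpy_transforms
#     np.save(wavq_path, wavq)

# drop the padded frames, transpose to [n_mels, T], and save for the vocoder
mel = postnet_output[:mel_length, :].T
mel_path = os.path.join(OUT_PATH, "mel", "LJ001-0001")
np.save(mel_path, mel)

metadata = [[wav_file_path, mel_path]]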
@ -224,35 +210,20 @@
|
||||||
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
|
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
|
||||||
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
|
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Close the log file\n",
|
|
||||||
"log_file.close()\n",
|
|
||||||
"\n",
|
|
||||||
"# For wavernn\n",
|
"# For wavernn\n",
|
||||||
"if not DRY_RUN:\n",
|
"if not DRY_RUN:\n",
|
||||||
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
|
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# For pwgan\n",
|
"# For pwgan\n",
|
||||||
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
||||||
" for data in metadata:\n",
|
" for wav_file_path, mel_path in metadata:\n",
|
||||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
|
" f.write(f\"{wav_file_path[0]}|{mel_path[1]+'.npy'}\\n\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Print mean losses\n",
|
"# Print mean losses\n",
|
||||||
"print(f\"Mean Loss: {mean_loss}\")\n",
|
"print(f\"Mean Loss: {mean_loss}\")\n",
|
||||||
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
|
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# for pwgan\n",
|
|
||||||
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
|
||||||
" for data in metadata:\n",
|
|
||||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -267,7 +238,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"idx = 1\n",
|
"idx = 1\n",
|
||||||
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
|
"ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -276,10 +247,9 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import soundfile as sf\n",
|
"wav, sr = sf.read(data[\"item_idxs\"][idx])\n",
|
||||||
"wav, sr = sf.read(item_idx[idx])\n",
|
"mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n",
|
||||||
"mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n",
|
"mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n",
|
||||||
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
|
|
||||||
"mel_truth = ap.melspectrogram(wav)\n",
|
"mel_truth = ap.melspectrogram(wav)\n",
|
||||||
"print(mel_truth.shape)"
|
"print(mel_truth.shape)"
|
||||||
]
|
]
|
||||||
|
@ -291,7 +261,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# plot posnet output\n",
|
"# plot posnet output\n",
|
||||||
"print(mel_postnet[:mel_lengths[idx], :].shape)\n",
|
"print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n",
|
||||||
"plot_spectrogram(mel_postnet, ap)"
|
"plot_spectrogram(mel_postnet, ap)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -324,10 +294,9 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# postnet, decoder diff\n",
|
"# postnet, decoder diff\n",
|
||||||
"from matplotlib import pylab as plt\n",
|
|
||||||
"mel_diff = mel_decoder - mel_postnet\n",
|
"mel_diff = mel_decoder - mel_postnet\n",
|
||||||
"plt.figure(figsize=(16, 10))\n",
|
"plt.figure(figsize=(16, 10))\n",
|
||||||
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
|
"plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T,aspect=\"auto\", origin=\"lower\")\n",
|
||||||
"plt.colorbar()\n",
|
"plt.colorbar()\n",
|
||||||
"plt.tight_layout()"
|
"plt.tight_layout()"
|
||||||
]
|
]
|
||||||
|
@ -339,10 +308,9 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# PLOT GT SPECTROGRAM diff\n",
|
"# PLOT GT SPECTROGRAM diff\n",
|
||||||
"from matplotlib import pylab as plt\n",
|
|
||||||
"mel_diff2 = mel_truth.T - mel_decoder\n",
|
"mel_diff2 = mel_truth.T - mel_decoder\n",
|
||||||
"plt.figure(figsize=(16, 10))\n",
|
"plt.figure(figsize=(16, 10))\n",
|
||||||
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
|
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
|
||||||
"plt.colorbar()\n",
|
"plt.colorbar()\n",
|
||||||
"plt.tight_layout()"
|
"plt.tight_layout()"
|
||||||
]
|
]
|
||||||
|
@ -354,21 +322,13 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# PLOT GT SPECTROGRAM diff\n",
|
"# PLOT GT SPECTROGRAM diff\n",
|
||||||
"from matplotlib import pylab as plt\n",
|
|
||||||
"mel = postnet_outputs[idx]\n",
|
"mel = postnet_outputs[idx]\n",
|
||||||
"mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
|
"mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
|
||||||
"plt.figure(figsize=(16, 10))\n",
|
"plt.figure(figsize=(16, 10))\n",
|
||||||
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
|
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
|
||||||
"plt.colorbar()\n",
|
"plt.colorbar()\n",
|
||||||
"plt.tight_layout()"
|
"plt.tight_layout()"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
|
@ -1,33 +1,33 @@
|
||||||
# core deps
|
# core deps
|
||||||
numpy==1.22.0;python_version<="3.10"
|
numpy==1.22.0;python_version<="3.10"
|
||||||
numpy==1.24.3;python_version>"3.10"
|
numpy>=1.24.3;python_version>"3.10"
|
||||||
cython==0.29.30
|
cython>=0.29.30
|
||||||
scipy>=1.11.2
|
scipy>=1.11.2
|
||||||
torch>=2.1
|
torch>=2.1
|
||||||
torchaudio
|
torchaudio
|
||||||
soundfile==0.12.*
|
soundfile>=0.12.0
|
||||||
librosa==0.10.*
|
librosa>=0.10.0
|
||||||
scikit-learn==1.3.0
|
scikit-learn>=1.3.0
|
||||||
numba==0.55.1;python_version<"3.9"
|
numba==0.55.1;python_version<"3.9"
|
||||||
numba==0.57.0;python_version>="3.9"
|
numba>=0.57.0;python_version>="3.9"
|
||||||
inflect==5.6.*
|
inflect>=5.6.0
|
||||||
tqdm==4.64.*
|
tqdm>=4.64.1
|
||||||
anyascii==0.3.*
|
anyascii>=0.3.0
|
||||||
pyyaml==6.*
|
pyyaml>=6.0
|
||||||
fsspec==2023.6.0 # <= 2023.9.1 makes aux tests fail
|
fsspec>=2023.6.0 # <= 2023.9.1 makes aux tests fail
|
||||||
aiohttp==3.8.*
|
aiohttp>=3.8.1
|
||||||
packaging==23.1
|
packaging>=23.1
|
||||||
# deps for examples
|
# deps for examples
|
||||||
flask==2.*
|
flask>=2.0.1
|
||||||
# deps for inference
|
# deps for inference
|
||||||
pysbd==0.3.4
|
pysbd>=0.3.4
|
||||||
# deps for notebooks
|
# deps for notebooks
|
||||||
umap-learn==0.5.*
|
umap-learn>=0.5.1
|
||||||
pandas>=1.4,<2.0
|
pandas>=1.4,<2.0
|
||||||
# deps for training
|
# deps for training
|
||||||
matplotlib==3.7.*
|
matplotlib>=3.7.0
|
||||||
# coqui stack
|
# coqui stack
|
||||||
trainer
|
trainer>=0.0.32
|
||||||
# config management
|
# config management
|
||||||
coqpit>=0.0.16
|
coqpit>=0.0.16
|
||||||
# chinese g2p deps
|
# chinese g2p deps
|
||||||
|
@ -46,11 +46,11 @@ bangla
|
||||||
bnnumerizer
|
bnnumerizer
|
||||||
bnunicodenormalizer
|
bnunicodenormalizer
|
||||||
#deps for tortoise
|
#deps for tortoise
|
||||||
k_diffusion
|
einops>=0.6.0
|
||||||
einops==0.6.*
|
transformers>=4.33.0
|
||||||
transformers==4.33.*
|
|
||||||
#deps for bark
|
#deps for bark
|
||||||
encodec==0.1.*
|
encodec>=0.1.1
|
||||||
# deps for XTTS
|
# deps for XTTS
|
||||||
unidecode==1.3.*
|
unidecode>=1.3.2
|
||||||
num2words
|
num2words
|
||||||
|
spacy[ja]>=3
|
|
@ -5,6 +5,7 @@ import torch
|
||||||
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
|
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
|
||||||
from TTS.config import BaseAudioConfig
|
from TTS.config import BaseAudioConfig
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.audio.numpy_transforms import stft
|
||||||
from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT
|
from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT
|
||||||
|
|
||||||
TESTS_PATH = get_tests_path()
|
TESTS_PATH = get_tests_path()
|
||||||
|
@ -21,7 +22,7 @@ def test_torch_stft():
|
||||||
torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
|
torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
|
||||||
# librosa stft
|
# librosa stft
|
||||||
wav = ap.load_wav(WAV_FILE)
|
wav = ap.load_wav(WAV_FILE)
|
||||||
M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access
|
M_librosa = abs(stft(y=wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length))
|
||||||
# torch stft
|
# torch stft
|
||||||
wav = torch.from_numpy(wav[None, :]).float()
|
wav = torch.from_numpy(wav[None, :]).float()
|
||||||
M_torch = torch_stft(wav)
|
M_torch = torch_stft(wav)
|
||||||
|
|
|
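The test now computes the reference magnitude with the functional `stft` from `numpy_transforms` instead of the private `ap._stft`. A hedged sketch of the comparison on a synthetic tone; the sample rate and STFT parameters are placeholders standing in for the test's WAV file and AudioProcessor config:

import numpy as np
import torch

from TTS.utils.audio.numpy_transforms import stft
from TTS.vocoder.layers.losses import TorchSTFT

# synthetic one-second 440 Hz tone instead of the test WAV file
sample_rate, fft_size, hop_length, win_length = 22050, 1024, 256, 1024
wav = np.sin(2 * np.pi * 440 * np.arange(sample_rate) / sample_rate).astype(np.float32)

# numpy path: magnitude of the functional stft, as in the updated test
M_librosa = abs(stft(y=wav, fft_size=fft_size, hop_length=hop_length, win_length=win_length))

# torch path: the vocoder-side STFT layer used for the comparison
torch_stft = TorchSTFT(fft_size, hop_length, win_length)
M_torch = torch_stft(torch.from_numpy(wav[None, :]).float())

print(M_librosa.shape, tuple(M_torch.shape))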
@ -186,7 +186,7 @@ def test_xtts_v2_streaming():
|
||||||
"en",
|
"en",
|
||||||
gpt_cond_latent,
|
gpt_cond_latent,
|
||||||
speaker_embedding,
|
speaker_embedding,
|
||||||
speed=1.5
|
speed=1.5,
|
||||||
)
|
)
|
||||||
wav_chuncks = []
|
wav_chuncks = []
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
|
@ -198,7 +198,7 @@ def test_xtts_v2_streaming():
|
||||||
"en",
|
"en",
|
||||||
gpt_cond_latent,
|
gpt_cond_latent,
|
||||||
speaker_embedding,
|
speaker_embedding,
|
||||||
speed=0.66
|
speed=0.66,
|
||||||
)
|
)
|
||||||
wav_chuncks = []
|
wav_chuncks = []
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
|
|
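The two hunks above only add trailing commas after the `speed` keyword in the XTTS streaming test; the surrounding call is not shown in this diff. For orientation, a hedged sketch of how such a streaming call is typically assembled, assuming the `Xtts` loading and `inference_stream` API used elsewhere in the test suite; every path here is a placeholder:

import torch

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# placeholder paths; the streaming call mirrors the pattern the test exercises,
# including the keyword `speed` argument touched by the hunks above
config = XttsConfig()
config.load_json("/path/to/XTTS-v2/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/XTTS-v2/", eval=True)

gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["/path/to/reference.wav"])

chunks = model.inference_stream(
    "It took me quite a long time to develop a voice.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    speed=1.5,
)
wav_chunks = [chunk for chunk in chunks]
wav = torch.cat(wav_chunks, dim=0)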