mirror of https://github.com/coqui-ai/TTS.git

commit 995d84f6d7 (parent 9a120f28ed)

    added feature preprocessing if not set in config
[file 1 of 4: WaveRNN vocoder training script — path not preserved in this view]

@@ -29,7 +29,12 @@ from TTS.utils.generic_utils import (
     set_init_dict,
 )
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
-from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
+from TTS.vocoder.datasets.preprocess import (
+    load_wav_data,
+    find_feat_files,
+    load_wav_feat_data,
+    preprocess_wav_files,
+)
 from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
 from TTS.vocoder.utils.generic_utils import setup_wavernn
 from TTS.vocoder.utils.io import save_best_model, save_checkpoint
@@ -192,15 +197,17 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
             )
             predict_mel = ap.melspectrogram(sample_wav)
 
-            # Sample audio
-            tb_logger.tb_train_audios(
-                global_step, {"eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
-            )
             # compute spectrograms
             figures = {
-                "prediction": plot_spectrogram(predict_mel.T),
-                "ground_truth": plot_spectrogram(ground_mel.T),
+                "train/ground_truth": plot_spectrogram(ground_mel.T),
+                "train/prediction": plot_spectrogram(predict_mel.T),
             }
 
+            # Sample audio
+            tb_logger.tb_train_audios(
+                global_step, {"train/audio": sample_wav}, CONFIG.audio["sample_rate"]
+            )
+
             tb_logger.tb_train_figures(global_step, figures)
         end_time = time.time()
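Note that the ordering changes along with the tag names: the spectrogram figures are now built before the audio is logged, and the audio tag is corrected from eval/audio to train/audio, so training artifacts no longer land under the evaluation namespace in TensorBoard.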
@@ -235,7 +242,6 @@ def evaluate(model, criterion, ap, global_step, epoch):
             global_step += 1
 
             y_hat = model(x, m)
-            y_hat_viz = y_hat  # for vizualization
             if isinstance(model.mode, int):
                 y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
             else:
@@ -263,11 +269,11 @@ def evaluate(model, criterion, ap, global_step, epoch):
             if CONFIG.print_eval:
                 c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
 
-    if epoch > CONFIG.test_delay_epochs:
-        # synthesize a full voice
-        wav_path = train_data[random.randrange(0, len(train_data))][0]
+    if epoch % CONFIG.test_every_epochs == 0:
+        # synthesize a part of the data
+        wav_path = eval_data[random.randrange(0, len(eval_data))][0]
         wav = ap.load_wav(wav_path)
-        ground_mel = ap.melspectrogram(wav)
+        ground_mel = ap.melspectrogram(wav[:22000])
         sample_wav = model.generate(
             ground_mel,
             CONFIG.batched,
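The new gate swaps a one-time delay for a periodic schedule: previously a test synthesis ran on every epoch once test_delay_epochs had passed; now it runs only when the epoch index is divisible by test_every_epochs. A minimal sketch of the difference, assuming zero-based epoch counting (in which case the new check also fires on epoch 0):

    # Illustration only; the CONFIG values are assumed, not taken from a real run.
    test_delay_epochs = 10
    test_every_epochs = 10

    for epoch in range(25):
        old_gate = epoch > test_delay_epochs       # True on 11, 12, 13, ... (every epoch)
        new_gate = epoch % test_every_epochs == 0  # True on 0, 10, 20 only
        if new_gate:
            print(f"epoch {epoch}: synthesize a test sample")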
@@ -276,15 +282,17 @@ def evaluate(model, criterion, ap, global_step, epoch):
         )
         predict_mel = ap.melspectrogram(sample_wav)
 
+        # compute spectrograms
+        figures = {
+            "eval/ground_truth": plot_spectrogram(ground_mel.T),
+            "eval/prediction": plot_spectrogram(predict_mel.T),
+        }
+
         # Sample audio
         tb_logger.tb_eval_audios(
             global_step, {"eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
         )
-        # compute spectrograms
-        figures = {
-            "eval/prediction": plot_spectrogram(predict_mel.T),
-            "eval/ground_truth": plot_spectrogram(ground_mel.T),
-        }
         tb_logger.tb_eval_figures(global_step, figures)
 
     tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
@@ -296,6 +304,9 @@ def main(args): # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
     global train_data, eval_data
 
+    # setup audio processor
+    ap = AudioProcessor(**CONFIG.audio)
+
     print(f" > Loading wavs from: {CONFIG.data_path}")
     if CONFIG.feature_path is not None:
         print(f" > Loading features from: {CONFIG.feature_path}")
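Moving the AudioProcessor construction up is a prerequisite for the next hunk: the preprocessing fallback added below needs ap before the train/eval split is loaded, whereas previously the processor was only created after the data paths were resolved.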
@@ -303,11 +314,20 @@ def main(args): # pylint: disable=redefined-outer-name
             CONFIG.data_path, CONFIG.feature_path, CONFIG.eval_split_size
         )
     else:
-        eval_data, train_data = load_wav_data(CONFIG.data_path, CONFIG.eval_split_size)
+        mel_feat_path = os.path.join(OUT_PATH, "mel")
+        feat_data = find_feat_files(mel_feat_path)
+        if feat_data:
+            print(f" > Loading features from: {mel_feat_path}")
+            eval_data, train_data = load_wav_feat_data(
+                CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
+            )
+        else:
+            print(" > No feature data found. Preprocessing...")
+            # preprocess feature data from the given wav files
+            preprocess_wav_files(OUT_PATH, CONFIG, ap)
+            eval_data, train_data = load_wav_feat_data(
+                CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
+            )
 
-    # setup audio processor
-    ap = AudioProcessor(**CONFIG.audio)
     # setup model
     model_wavernn = setup_wavernn(CONFIG)
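Condensed, the new resolution order in main() is: an explicit feature_path wins; otherwise features already sitting under <OUT_PATH>/mel (e.g. from a previous run) are reused; otherwise the wavs are preprocessed once and then loaded. A sketch of that flow using the names from the diff (not the verbatim script):

    # Sketch of the feature-resolution logic added to main(); names as in the diff.
    if CONFIG.feature_path is not None:
        # 1) explicit feature directory from the config
        feat_path = CONFIG.feature_path
    else:
        feat_path = os.path.join(OUT_PATH, "mel")
        if not find_feat_files(feat_path):
            # 3) nothing cached yet: compute mels (and quantized wavs) once
            preprocess_wav_files(OUT_PATH, CONFIG, ap)
        # 2)/3) reuse whatever now lives under <OUT_PATH>/mel
    eval_data, train_data = load_wav_feat_data(
        CONFIG.data_path, feat_path, CONFIG.eval_split_size
    )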
[file 2 of 4: WaveRNN vocoder config (JSON) — path not preserved in this view]

@@ -55,18 +55,17 @@
     "padding": 2, // pad the input for resnet to see wider input length
 
     // DATASET
-    "use_gta": true, // use computed gta features from the tts model
-    "data_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech/", // path containing training wav files
-    "feature_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech_Computed/mel/", // path containing computed features .npy (mels / quant)
+    //"use_gta": true, // use computed gta features from the tts model
+    "data_path": "path/to/wav/files", // path containing training wav files
+    "feature_path": null, // path to precomputed .npy features; if null, compute them from the wavs
 
     // TRAINING
     "batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "epochs": 10000, // total number of epochs to train.
-    "warmup_steps": 10,
 
     // VALIDATION
     "run_eval": true,
-    "test_delay_epochs": 10, // early testing only wastes computation time.
+    "test_every_epochs": 10, // run a test synthesis every N epochs (here: every 10 epochs)
 
     // OPTIMIZER
     "grad_clip": 4, // apply gradient clipping if > 0
@@ -90,6 +89,6 @@
     "eval_split_size": 50, // number of samples for testing
 
     // PATHS
-    "output_path": "/media/alexander/LinuxFS/Projects/wavernn/Trainings/"
+    "output_path": "output/training/path"
 }
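In practice the only dataset knob left to set is feature_path; both modes, with placeholder paths (an illustrative fragment in the config's own format, not a complete config):

    // DATASET
    "data_path": "path/to/wav/files",  // searched recursively for *.wav
    "feature_path": null,              // null: compute features into the training output folder (mel/, quant/)
    // "feature_path": "path/to/mel/", // or: reuse precomputed .npy features from here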
[file 3 of 4: TTS/vocoder/datasets/preprocess.py]

@@ -1,17 +1,38 @@
 import glob
 import os
 from pathlib import Path
 
+from tqdm import tqdm
+
 import numpy as np
 
 
+def preprocess_wav_files(out_path, config, ap):
+    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
+    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
+    wav_files = find_wav_files(config.data_path)
+    for path in tqdm(wav_files):
+        wav_name = Path(path).stem
+        quant_path = os.path.join(out_path, "quant", wav_name + ".npy")
+        mel_path = os.path.join(out_path, "mel", wav_name + ".npy")
+        y = ap.load_wav(path)
+        mel = ap.melspectrogram(y)
+        np.save(mel_path, mel)
+        if isinstance(config.mode, int):
+            quant = (
+                ap.mulaw_encode(y, qc=config.mode)
+                if config.mulaw
+                else ap.quantize(y, bits=config.mode)
+            )
+            np.save(quant_path, quant)
+
+
 def find_wav_files(data_path):
-    wav_paths = glob.glob(os.path.join(data_path, '**', '*.wav'), recursive=True)
+    wav_paths = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
     return wav_paths
 
 
 def find_feat_files(data_path):
-    feat_paths = glob.glob(os.path.join(data_path, '**', '*.npy'), recursive=True)
+    feat_paths = glob.glob(os.path.join(data_path, "**", "*.npy"), recursive=True)
     return feat_paths
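A minimal driver for the new helper, for reference; the SimpleNamespace config and the paths are hypothetical stand-ins, and AudioProcessor is assumed to live at TTS.utils.audio as elsewhere in the codebase:

    # Hypothetical usage sketch; config fields mirror the ones preprocess_wav_files reads.
    from types import SimpleNamespace

    from TTS.utils.audio import AudioProcessor
    from TTS.vocoder.datasets.preprocess import preprocess_wav_files

    config = SimpleNamespace(
        data_path="path/to/wav/files",  # scanned recursively for *.wav
        mode=10,                        # int -> also write quantized wavs (10 bits here)
        mulaw=True,                     # mu-law encode instead of linear quantization
    )
    ap = AudioProcessor(sample_rate=22050, num_mels=80)  # assumed kwargs; real runs pass **CONFIG.audio
    preprocess_wav_files("output/training/path", config, ap)
    # -> output/training/path/mel/*.npy and, since mode is an int, quant/*.npy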
[file 4 of 4: TTS/vocoder/datasets/wavernn_dataset.py]

@@ -48,6 +48,7 @@ class WaveRNNDataset(Dataset):
         feat_path = self.item_list[index]
         m = np.load(feat_path.replace("/quant/", "/mel/"))
         if self.mode in ["gauss", "mold"]:
+            # x = np.load(feat_path.replace("/mel/", "/quant/"))
             x = self.ap.load_wav(wavpath)
         elif isinstance(self.mode, int):
             x = np.load(feat_path.replace("/mel/", "/quant/"))