compute audio feat on dataload

This commit is contained in:
sanjaesc 2020-10-25 09:45:37 +01:00 committed by erogol
parent 7c72562fe7
commit bef3f2020b
4 changed files with 243 additions and 203 deletions

View File

@@ -29,8 +29,8 @@ from TTS.utils.generic_utils import (
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
 from TTS.vocoder.datasets.preprocess import (
     find_feat_files,
-    load_wav_feat_data,
-    preprocess_wav_files,
+    load_wav_data,
+    load_wav_feat_data
 )
 from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
 from TTS.vocoder.utils.generic_utils import setup_wavernn
@@ -41,15 +41,16 @@ use_cuda, num_gpus = setup_torch_training_env(True, True)

 def setup_loader(ap, is_val=False, verbose=False):
-    if is_val and not CONFIG.run_eval:
+    if is_val and not c.run_eval:
         loader = None
     else:
         dataset = WaveRNNDataset(ap=ap,
                                  items=eval_data if is_val else train_data,
-                                 seq_len=CONFIG.seq_len,
+                                 seq_len=c.seq_len,
                                  hop_len=ap.hop_length,
-                                 pad=CONFIG.padding,
-                                 mode=CONFIG.mode,
+                                 pad=c.padding,
+                                 mode=c.mode,
+                                 mulaw=c.mulaw,
                                  is_training=not is_val,
                                  verbose=verbose,
                                  )
@@ -57,10 +58,10 @@ def setup_loader(ap, is_val=False, verbose=False):
         loader = DataLoader(dataset,
                             shuffle=True,
                             collate_fn=dataset.collate,
-                            batch_size=CONFIG.batch_size,
-                            num_workers=CONFIG.num_val_loader_workers
+                            batch_size=c.batch_size,
+                            num_workers=c.num_val_loader_workers
                             if is_val
-                            else CONFIG.num_loader_workers,
+                            else c.num_loader_workers,
                             pin_memory=True,
                             )
     return loader
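
For orientation, a minimal sketch (not part of the commit) of driving the reworked dataset and loader by hand. The config file name and wav paths are placeholders, and it assumes the load_config/AudioProcessor helpers used elsewhere in the script:

    # Stand-alone sketch of the new loader wiring; values and paths are illustrative.
    from torch.utils.data import DataLoader

    from TTS.utils.audio import AudioProcessor
    from TTS.utils.io import load_config
    from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset

    c = load_config("wavernn_config.json")     # same structure as the config shown below
    ap = AudioProcessor(**c.audio)

    # wav-only items: the dataset computes mels (and quantized targets) on the fly
    items = ["/data/wavs/clip_0001.wav", "/data/wavs/clip_0002.wav"]

    dataset = WaveRNNDataset(ap=ap,
                             items=items,
                             seq_len=c.seq_len,       # must be divisible by hop_length
                             hop_len=ap.hop_length,
                             pad=c.padding,
                             mode=c.mode,             # "mold", "gauss" or an int bit depth
                             mulaw=c.mulaw,           # new argument introduced by this commit
                             is_training=True)
    loader = DataLoader(dataset,
                        shuffle=True,
                        collate_fn=dataset.collate,
                        batch_size=c.batch_size,
                        num_workers=c.num_loader_workers,
                        pin_memory=True)
    x_input, mels, y_coarse = next(iter(loader))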
@@ -89,9 +90,9 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
     keep_avg = KeepAverage()
     if use_cuda:
         batch_n_iter = int(len(data_loader.dataset) /
-                           (CONFIG.batch_size * num_gpus))
+                           (c.batch_size * num_gpus))
     else:
-        batch_n_iter = int(len(data_loader.dataset) / CONFIG.batch_size)
+        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
     end_time = time.time()
     c_logger.print_train_start()
     # train loop
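
The batch_n_iter value only feeds the progress logger; a worked example with made-up numbers:

    # Illustrative numbers only.
    dataset_size = 13_100                          # e.g. number of training items
    batch_size = 64                                # c.batch_size
    num_gpus = 2

    int(dataset_size / (batch_size * num_gpus))    # 102 iterations per epoch on 2 GPUs
    int(dataset_size / batch_size)                 # 204 iterations per epoch on 1 GPU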
@@ -102,9 +103,6 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
         loader_time = time.time() - end_time
         global_step += 1

-        ##################
-        # MODEL TRAINING #
-        ##################
         y_hat = model(x_input, mels)

         if isinstance(model.mode, int):
@@ -112,7 +110,6 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
         else:
             y_coarse = y_coarse.float()
             y_coarse = y_coarse.unsqueeze(-1)
-            # m_scaled, _ = model.upsample(m)

         # compute losses
         loss = criterion(y_hat, y_coarse)
@@ -120,11 +117,11 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
             raise RuntimeError(" [!] None loss. Exiting ...")
         optimizer.zero_grad()
         loss.backward()
-        if CONFIG.grad_clip > 0:
+        if c.grad_clip > 0:
             torch.nn.utils.clip_grad_norm_(
-                model.parameters(), CONFIG.grad_clip)
+                model.parameters(), c.grad_clip)
         optimizer.step()

         if scheduler is not None:
             scheduler.step()
@@ -144,7 +141,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
         keep_avg.update_values(update_train_values)

         # print training stats
-        if global_step % CONFIG.print_step == 0:
+        if global_step % c.print_step == 0:
             log_dict = {"step_time": [step_time, 2],
                         "loader_time": [loader_time, 4],
                         "current_lr": cur_lr,
@@ -164,8 +161,8 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
             tb_logger.tb_train_iter_stats(global_step, iter_stats)

         # save checkpoint
-        if global_step % CONFIG.save_step == 0:
-            if CONFIG.checkpoint:
+        if global_step % c.save_step == 0:
+            if c.checkpoint:
                 # save model
                 save_checkpoint(model,
                                 optimizer,
@@ -180,28 +177,30 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
                                 )

             # synthesize a full voice
-            wav_path = train_data[random.randrange(0, len(train_data))][0]
+            rand_idx = random.randrange(0, len(train_data))
+            wav_path = train_data[rand_idx] if not isinstance(
+                train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
             wav = ap.load_wav(wav_path)
             ground_mel = ap.melspectrogram(wav)
             sample_wav = model.generate(ground_mel,
-                                        CONFIG.batched,
-                                        CONFIG.target_samples,
-                                        CONFIG.overlap_samples,
+                                        c.batched,
+                                        c.target_samples,
+                                        c.overlap_samples,
+                                        use_cuda
                                         )
             predict_mel = ap.melspectrogram(sample_wav)

             # compute spectrograms
             figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
-                       "train/prediction": plot_spectrogram(predict_mel.T),
+                       "train/prediction": plot_spectrogram(predict_mel.T)
                        }
-            tb_logger.tb_train_figures(global_step, figures)

             # Sample audio
             tb_logger.tb_train_audios(
                 global_step, {
-                    "train/audio": sample_wav}, CONFIG.audio["sample_rate"]
+                    "train/audio": sample_wav}, c.audio["sample_rate"]
             )
+            tb_logger.tb_train_figures(global_step, figures)

     end_time = time.time()

     # print epoch stats
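
The new rand_idx/isinstance handling exists because training items can now be either plain wav paths or (wav, feature) pairs. A small sketch of the two layouts it has to accept (paths are hypothetical):

    # Hypothetical item layouts; only the structure matters.
    train_data_wav_only = [
        "/data/wavs/clip_0001.wav",                               # feature_path is null
    ]
    train_data_precomputed = [
        ("/data/wavs/clip_0001.wav", "/data/mel/clip_0001.npy"),  # precomputed features
    ]

    def pick_wav_path(train_data, rand_idx):
        # mirrors the isinstance() check used above
        item = train_data[rand_idx]
        return item if not isinstance(item, (tuple, list)) else item[0]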
@@ -259,34 +258,35 @@ def evaluate(model, criterion, ap, global_step, epoch):
             keep_avg.update_values(update_eval_values)

             # print eval stats
-            if CONFIG.print_eval:
+            if c.print_eval:
                 c_logger.print_eval_step(
                     num_iter, loss_dict, keep_avg.avg_values)

-    if epoch % CONFIG.test_every_epochs == 0 and epoch != 0:
-        # synthesize a part of data
-        wav_path = eval_data[random.randrange(0, len(eval_data))][0]
+    if epoch % c.test_every_epochs == 0 and epoch != 0:
+        # synthesize a full voice
+        rand_idx = random.randrange(0, len(eval_data))
+        wav_path = eval_data[rand_idx] if not isinstance(
+            eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
         wav = ap.load_wav(wav_path)
-        ground_mel = ap.melspectrogram(wav[:22000])
+        ground_mel = ap.melspectrogram(wav)
         sample_wav = model.generate(ground_mel,
-                                    CONFIG.batched,
-                                    CONFIG.target_samples,
-                                    CONFIG.overlap_samples,
+                                    c.batched,
+                                    c.target_samples,
+                                    c.overlap_samples,
                                     use_cuda
                                     )
         predict_mel = ap.melspectrogram(sample_wav)

-        # compute spectrograms
-        figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
-                   "eval/prediction": plot_spectrogram(predict_mel.T),
-                   }
-
         # Sample audio
         tb_logger.tb_eval_audios(
             global_step, {
-                "eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
+                "eval/audio": sample_wav}, c.audio["sample_rate"]
         )

+        # compute spectrograms
+        figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
+                   "eval/prediction": plot_spectrogram(predict_mel.T)
+                   }
         tb_logger.tb_eval_figures(global_step, figures)
         tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
@@ -299,53 +299,62 @@ def main(args): # pylint: disable=redefined-outer-name
     global train_data, eval_data

     # setup audio processor
-    ap = AudioProcessor(**CONFIG.audio)
+    ap = AudioProcessor(**c.audio)

-    print(f" > Loading wavs from: {CONFIG.data_path}")
-    if CONFIG.feature_path is not None:
-        print(f" > Loading features from: {CONFIG.feature_path}")
-        eval_data, train_data = load_wav_feat_data(
-            CONFIG.data_path, CONFIG.feature_path, CONFIG.eval_split_size
-        )
-    else:
-        mel_feat_path = os.path.join(OUT_PATH, "mel")
-        feat_data = find_feat_files(mel_feat_path)
-        if feat_data:
-            print(f" > Loading features from: {mel_feat_path}")
-            eval_data, train_data = load_wav_feat_data(
-                CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
-            )
-        else:
-            print(" > No feature data found. Preprocessing...")
-            # preprocessing feature data from given wav files
-            preprocess_wav_files(OUT_PATH, CONFIG, ap)
-            eval_data, train_data = load_wav_feat_data(
-                CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
-            )
+    # print(f" > Loading wavs from: {c.data_path}")
+    # if c.feature_path is not None:
+    #     print(f" > Loading features from: {c.feature_path}")
+    #     eval_data, train_data = load_wav_feat_data(
+    #         c.data_path, c.feature_path, c.eval_split_size
+    #     )
+    # else:
+    #     mel_feat_path = os.path.join(OUT_PATH, "mel")
+    #     feat_data = find_feat_files(mel_feat_path)
+    #     if feat_data:
+    #         print(f" > Loading features from: {mel_feat_path}")
+    #         eval_data, train_data = load_wav_feat_data(
+    #             c.data_path, mel_feat_path, c.eval_split_size
+    #         )
+    #     else:
+    #         print(" > No feature data found. Preprocessing...")
+    #         # preprocessing feature data from given wav files
+    #         preprocess_wav_files(OUT_PATH, CONFIG, ap)
+    #         eval_data, train_data = load_wav_feat_data(
+    #             c.data_path, mel_feat_path, c.eval_split_size
+    #         )
+
+    print(f" > Loading wavs from: {c.data_path}")
+    if c.feature_path is not None:
+        print(f" > Loading features from: {c.feature_path}")
+        eval_data, train_data = load_wav_feat_data(
+            c.data_path, c.feature_path, c.eval_split_size)
+    else:
+        eval_data, train_data = load_wav_data(
+            c.data_path, c.eval_split_size)

     # setup model
-    model_wavernn = setup_wavernn(CONFIG)
+    model_wavernn = setup_wavernn(c)

     # define train functions
-    if CONFIG.mode == "mold":
+    if c.mode == "mold":
         criterion = discretized_mix_logistic_loss
-    elif CONFIG.mode == "gauss":
+    elif c.mode == "gauss":
         criterion = gaussian_loss
-    elif isinstance(CONFIG.mode, int):
+    elif isinstance(c.mode, int):
         criterion = torch.nn.CrossEntropyLoss()

     if use_cuda:
         model_wavernn.cuda()
-        if isinstance(CONFIG.mode, int):
+        if isinstance(c.mode, int):
             criterion.cuda()

-    optimizer = RAdam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)
+    optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)

     scheduler = None
-    if "lr_scheduler" in CONFIG:
-        scheduler = getattr(torch.optim.lr_scheduler, CONFIG.lr_scheduler)
-        scheduler = scheduler(optimizer, **CONFIG.lr_scheduler_params)
+    if "lr_scheduler" in c:
+        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
+        scheduler = scheduler(optimizer, **c.lr_scheduler_params)
     # slow start for the first 5 epochs
-    # lr_lambda = lambda epoch: min(epoch / CONFIG.warmup_steps, 1)
+    # lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
     # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

     # restore any checkpoint
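
A condensed sketch of the new data-loading branch in main(); the comments on the returned item shapes are inferred from how the rest of the script consumes them, not from the loader implementations themselves:

    if c.feature_path is not None:
        # features precomputed on disk -> items are (wav_path, feat_path) pairs
        eval_data, train_data = load_wav_feat_data(
            c.data_path, c.feature_path, c.eval_split_size)
    else:
        # wavs only -> items are plain paths; features are computed inside WaveRNNDataset
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # Either layout works downstream because WaveRNNDataset checks
    # isinstance(items[0], (tuple, list)) to decide whether to compute features on the fly.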
@@ -366,7 +375,7 @@ def main(args): # pylint: disable=redefined-outer-name
             # retore only matching layers.
             print(" > Partial model initialization...")
             model_dict = model_wavernn.state_dict()
-            model_dict = set_init_dict(model_dict, checkpoint["model"], CONFIG)
+            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
             model_wavernn.load_state_dict(model_dict)

         print(" > Model restored from step %d" %
@@ -386,11 +395,10 @@ def main(args): # pylint: disable=redefined-outer-name
     best_loss = float("inf")

     global_step = args.restore_step
-    for epoch in range(0, CONFIG.epochs):
-        c_logger.print_epoch_start(epoch, CONFIG.epochs)
-        _, global_step = train(
-            model_wavernn, optimizer, criterion, scheduler, ap, global_step, epoch
-        )
+    for epoch in range(0, c.epochs):
+        c_logger.print_epoch_start(epoch, c.epochs)
+        _, global_step = train(model_wavernn, optimizer,
+                               criterion, scheduler, ap, global_step, epoch)
         eval_avg_loss_dict = evaluate(
             model_wavernn, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
@@ -462,14 +470,14 @@ if __name__ == "__main__":
         print(f" > Training continues for {args.restore_path}")

     # setup output paths and read configs
-    CONFIG = load_config(args.config_path)
+    c = load_config(args.config_path)
     # check_config(c)
     _ = os.path.dirname(os.path.realpath(__file__))

     OUT_PATH = args.continue_path
     if args.continue_path == "":
         OUT_PATH = create_experiment_folder(
-            CONFIG.output_path, CONFIG.run_name, args.debug
+            c.output_path, c.run_name, args.debug
         )

     AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")
@@ -483,7 +491,7 @@ if __name__ == "__main__":
         new_fields["restore_path"] = args.restore_path
         new_fields["github_branch"] = get_git_branch()
         copy_config_file(
-            args.config_path, os.path.join(OUT_PATH, "config.json"), new_fields
+            args.config_path, os.path.join(OUT_PATH, "c.json"), new_fields
         )
         os.chmod(AUDIO_PATH, 0o775)
         os.chmod(OUT_PATH, 0o775)
@@ -492,8 +500,7 @@ if __name__ == "__main__":
     tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")

     # write model desc to tensorboard
-    tb_logger.tb_add_text("model-description",
-                          CONFIG["run_description"], 0)
+    tb_logger.tb_add_text("model-description", c["run_description"], 0)

     try:
         main(args)

View File

@@ -2,29 +2,25 @@
     "run_name": "wavernn_test",
     "run_description": "wavernn_test training",

// AUDIO PARAMETERS
-    "audio":{
+    "audio": {
         "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
         "win_length": 1024, // stft window length in ms.
         "hop_length": 256, // stft window hop-lengh in ms.
         "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
         "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
         // Audio processing parameters
         "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
         "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
         "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
         // Silence trimming
-        "do_trim_silence": false,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
+        "do_trim_silence": false, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
         "trim_db": 60, // threshold for timming silence. Set this according to your dataset.
         // MelSpectrogram parameters
         "num_mels": 80, // size of the mel spec frame.
         "mel_fmin": 40.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
         "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
         "spec_gain": 20.0, // scaler value appplied after log transform of spectrogram.
         // Normalization parameters
         "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
         "min_level_db": -100, // lower bound for normalization
@@ -34,40 +30,48 @@
         "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
     },

// Generating / Synthesizing
     "batched": true,
     "target_samples": 11000, // target number of samples to be generated in each batch entry
     "overlap_samples": 550, // number of samples for crossfading between batches

// DISTRIBUTED TRAINING
     // "distributed":{
     //     "backend": "nccl",
     //     "url": "tcp:\/\/localhost:54321"
     // },

-// MODEL PARAMETERS
-    "use_aux_net": true,
-    "use_upsample_net": true,
-    "upsample_factors": [4, 8, 8], // this needs to correctly factorise hop_length
-    "seq_len": 1280, // has to be devideable by hop_length
-    "mode": "mold", // mold [string], gauss [string], bits [int]
-    "mulaw": false, // apply mulaw if mode is bits
-    "padding": 2, // pad the input for resnet to see wider input length
+// MODEL MODE
+    "mode": 10, // mold [string], gauss [string], bits [int]
+    "mulaw": true, // apply mulaw if mode is bits

-// DATASET
-    //"use_gta": true, // use computed gta features from the tts model
-    "data_path": "path/to/wav/files", // path containing training wav files
-    "feature_path": null, // path containing computed features from wav files if null compute them
+// MODEL PARAMETERS
+    "wavernn_model_params": {
+        "rnn_dims": 512,
+        "fc_dims": 512,
+        "compute_dims": 128,
+        "res_out_dims": 128,
+        "num_res_blocks": 10,
+        "use_aux_net": true,
+        "use_upsample_net": true,
+        "upsample_factors": [4, 8, 8] // this needs to correctly factorise hop_length
+    },

+// DATASET
+    //"use_gta": true, // use computed gta features from the tts model
+    "data_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech", // path containing training wav files
+    "feature_path": null, // path containing computed features from wav files if null compute them
+    "seq_len": 1280, // has to be devideable by hop_length
+    "padding": 2, // pad the input for resnet to see wider input length

// TRAINING
-    "batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention.
+    "batch_size": 64, // Batch size for training.
     "epochs": 10000, // total number of epochs to train.

// VALIDATION
     "run_eval": true,
-    "test_every_epochs": 10, // Test after set number of epochs (Test every 20 epochs for example)
+    "test_every_epochs": 10, // Test after set number of epochs (Test every 10 epochs for example)

// OPTIMIZER
     "grad_clip": 4, // apply gradient clipping if > 0
     "lr_scheduler": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
     "lr_scheduler_params": {
@@ -76,19 +80,18 @@
     },
     "lr": 1e-4, // initial learning rate

// TENSORBOARD and LOGGING
     "print_step": 25, // Number of steps to log traning on console.
     "print_eval": false, // If True, it prints loss values for each step in eval run.
     "save_step": 25000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
     "checkpoint": true, // If true, it saves checkpoints per "save_step"
     "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

// DATA LOADING
     "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "num_val_loader_workers": 4, // number of evaluation data loader processes.
     "eval_split_size": 50, // number of samples for testing

// PATHS
     "output_path": "output/training/path"
 }
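
The new "mode"/"mulaw" keys drive both the loss selection in the training script and the model's output size. A hedged sketch of that mapping, assuming the usual WaveRNN conventions (30 mixture parameters for "mold", mean and log-scale for "gauss", one class per level for integer bit depths); the authoritative version lives in the model code below:

    def mode_to_output_size(mode):
        # Illustrative mapping only.
        if mode == "mold":           # discretized mixture of logistics
            return 30
        if mode == "gauss":          # single Gaussian
            return 2
        if isinstance(mode, int):    # discrete output, one class per quantization level
            return 2 ** mode
        raise RuntimeError("Unknown model mode value - ", mode)

    mode_to_output_size(10)          # 1024 classes for the "mode": 10 example above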

View File

@@ -1,11 +1,13 @@
 import torch
 import numpy as np
 from torch.utils.data import Dataset
+from multiprocessing import Manager


 class WaveRNNDataset(Dataset):
     """
-    WaveRNN Dataset searchs for all the wav files under root path.
+    WaveRNN Dataset searchs for all the wav files under root path
+    and converts them to acoustic features on the fly.
     """

     def __init__(self,
@@ -15,16 +17,19 @@ class WaveRNNDataset(Dataset):
                  hop_len,
                  pad,
                  mode,
+                 mulaw,
                  is_training=True,
                  verbose=False,
                  ):

         self.ap = ap
+        self.compute_feat = not isinstance(items[0], (tuple, list))
         self.item_list = items
         self.seq_len = seq_len
         self.hop_len = hop_len
         self.pad = pad
         self.mode = mode
+        self.mulaw = mulaw
         self.is_training = is_training
         self.verbose = verbose
@@ -36,22 +41,47 @@ class WaveRNNDataset(Dataset):
         return item

     def load_item(self, index):
-        wavpath, feat_path = self.item_list[index]
-        m = np.load(feat_path.replace("/quant/", "/mel/"))
-        # x = self.wav_cache[index]
-        if m.shape[-1] < 5:
-            print(" [!] Instance is too short! : {}".format(wavpath))
-            self.item_list[index] = self.item_list[index + 1]
-            feat_path = self.item_list[index]
-            m = np.load(feat_path.replace("/quant/", "/mel/"))
-        if self.mode in ["gauss", "mold"]:
-            # x = np.load(feat_path.replace("/mel/", "/quant/"))
-            x = self.ap.load_wav(wavpath)
-        elif isinstance(self.mode, int):
-            x = np.load(feat_path.replace("/mel/", "/quant/"))
-        else:
-            raise RuntimeError("Unknown dataset mode - ", self.mode)
-        return m, x
+        """
+        load (audio, feat) couple if feature_path is set
+        else compute it on the fly
+        """
+        if self.compute_feat:
+            wavpath = self.item_list[index]
+            audio = self.ap.load_wav(wavpath)
+            mel = self.ap.melspectrogram(audio)
+            if mel.shape[-1] < 5:
+                print(" [!] Instance is too short! : {}".format(wavpath))
+                self.item_list[index] = self.item_list[index + 1]
+                audio = self.ap.load_wav(wavpath)
+                mel = self.ap.melspectrogram(audio)
+            if self.mode in ["gauss", "mold"]:
+                x_input = audio
+            elif isinstance(self.mode, int):
+                x_input = (self.ap.mulaw_encode(audio, qc=self.mode)
+                           if self.mulaw else self.ap.quantize(audio, bits=self.mode))
+            else:
+                raise RuntimeError("Unknown dataset mode - ", self.mode)
+        else:
+            wavpath, feat_path = self.item_list[index]
+            mel = np.load(feat_path.replace("/quant/", "/mel/"))
+            if mel.shape[-1] < 5:
+                print(" [!] Instance is too short! : {}".format(wavpath))
+                self.item_list[index] = self.item_list[index + 1]
+                feat_path = self.item_list[index]
+                mel = np.load(feat_path.replace("/quant/", "/mel/"))
+            if self.mode in ["gauss", "mold"]:
+                x_input = self.ap.load_wav(wavpath)
+            elif isinstance(self.mode, int):
+                x_input = np.load(feat_path.replace("/mel/", "/quant/"))
+            else:
+                raise RuntimeError("Unknown dataset mode - ", self.mode)
+
+        return mel, x_input

     def collate(self, batch):
         mel_win = self.seq_len // self.hop_len + 2 * self.pad
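
In the new compute_feat branch the target encoding comes from ap.mulaw_encode or ap.quantize. Rough reference implementations of those two encodings, for illustration only (the real ones live on AudioProcessor):

    import numpy as np

    def mulaw_encode_ref(wav, qc):
        # mu-law companding, then quantization to 2**qc levels
        mu = 2 ** qc - 1
        signal = np.sign(wav) * np.log1p(mu * np.abs(wav)) / np.log1p(mu)
        return np.floor((signal + 1) / 2 * mu + 0.5)

    def quantize_ref(wav, bits):
        # plain linear quantization of a [-1, 1] waveform to 2**bits levels
        return (wav + 1.0) * (2 ** bits - 1) / 2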
@@ -79,10 +109,8 @@ class WaveRNNDataset(Dataset):
         elif isinstance(self.mode, int):
             coarse = np.stack(coarse).astype(np.int64)
             coarse = torch.LongTensor(coarse)
-            x_input = (
-                2 * coarse[:, : self.seq_len].float() /
-                (2 ** self.mode - 1.0) - 1.0
-            )
+            x_input = (2 * coarse[:, : self.seq_len].float() /
+                       (2 ** self.mode - 1.0) - 1.0)
         y_coarse = coarse[:, 1:]
         mels = torch.FloatTensor(mels)
         return x_input, mels, y_coarse
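
The collate change only reflows the rescaling expression; the math maps integer quantization levels back onto [-1, 1]. A quick check for a 10-bit model:

    import torch

    bits = 10
    coarse = torch.tensor([[0, 512, 1023]])
    x_input = 2 * coarse[:, :3].float() / (2 ** bits - 1.0) - 1.0
    # -> approximately [-1.0000, 0.0010, 1.0000]: level 0 maps to -1, the mid level to ~0, the top level to +1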

View File

@@ -36,14 +36,14 @@ class ResBlock(nn.Module):

 class MelResNet(nn.Module):
-    def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
+    def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad):
         super().__init__()
         k_size = pad * 2 + 1
         self.conv_in = nn.Conv1d(
             in_dims, compute_dims, kernel_size=k_size, bias=False)
         self.batch_norm = nn.BatchNorm1d(compute_dims)
         self.layers = nn.ModuleList()
-        for _ in range(res_blocks):
+        for _ in range(num_res_blocks):
             self.layers.append(ResBlock(compute_dims))
         self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
@@ -76,7 +76,7 @@ class UpsampleNetwork(nn.Module):
         feat_dims,
         upsample_scales,
         compute_dims,
-        res_blocks,
+        num_res_blocks,
         res_out_dims,
         pad,
         use_aux_net,
@@ -87,7 +87,7 @@ class UpsampleNetwork(nn.Module):
         self.use_aux_net = use_aux_net
         if use_aux_net:
             self.resnet = MelResNet(
-                res_blocks, feat_dims, compute_dims, res_out_dims, pad
+                num_res_blocks, feat_dims, compute_dims, res_out_dims, pad
             )
             self.resnet_stretch = Stretch2d(self.total_scale, 1)
         self.up_layers = nn.ModuleList()
@@ -118,14 +118,14 @@ class UpsampleNetwork(nn.Module):

 class Upsample(nn.Module):
     def __init__(
-        self, scale, pad, res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net
+        self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net
     ):
         super().__init__()
         self.scale = scale
         self.pad = pad
         self.indent = pad * scale
         self.use_aux_net = use_aux_net
-        self.resnet = MelResNet(res_blocks, feat_dims,
+        self.resnet = MelResNet(num_res_blocks, feat_dims,
                                 compute_dims, res_out_dims, pad)

     def forward(self, m):
@@ -147,8 +147,7 @@ class Upsample(nn.Module):

 class WaveRNN(nn.Module):
-    def __init__(
-        self,
+    def __init__(self,
         rnn_dims,
         fc_dims,
         mode,
@@ -160,7 +159,7 @@ class WaveRNN(nn.Module):
         feat_dims,
         compute_dims,
         res_out_dims,
-        res_blocks,
+        num_res_blocks,
         hop_length,
         sample_rate,
     ):
@@ -177,7 +176,7 @@ class WaveRNN(nn.Module):
         elif self.mode == "gauss":
             self.n_classes = 2
         else:
-            raise RuntimeError(" > Unknown training mode")
+            raise RuntimeError("Unknown model mode value - ", self.mode)

         self.rnn_dims = rnn_dims
         self.aux_dims = res_out_dims // 4
@@ -192,7 +191,7 @@ class WaveRNN(nn.Module):
                 feat_dims,
                 upsample_factors,
                 compute_dims,
-                res_blocks,
+                num_res_blocks,
                 res_out_dims,
                 pad,
                 use_aux_net,
@@ -201,7 +200,7 @@ class WaveRNN(nn.Module):
             self.upsample = Upsample(
                 hop_length,
                 pad,
-                res_blocks,
+                num_res_blocks,
                 feat_dims,
                 compute_dims,
                 res_out_dims,
@@ -260,7 +259,7 @@ class WaveRNN(nn.Module):
         x = F.relu(self.fc2(x))
         return self.fc3(x)

-    def generate(self, mels, batched, target, overlap, use_cuda):
+    def generate(self, mels, batched, target, overlap, use_cuda=False):
         self.eval()
         device = 'cuda' if use_cuda else 'cpu'
@@ -360,6 +359,8 @@ class WaveRNN(nn.Module):
         # Fade-out at the end to avoid signal cutting out suddenly
         fade_out = np.linspace(1, 0, 20 * self.hop_length)
         output = output[:wave_len]
-        output[-20 * self.hop_length:] *= fade_out
+
+        if wave_len > len(fade_out):
+            output[-20 * self.hop_length:] *= fade_out

         self.train()
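
The added length check guards against very short generated clips, where the last-20-frames slice is shorter than the fade-out ramp and the in-place multiply would fail. A minimal repro of the failure mode (numbers are illustrative):

    import numpy as np

    hop_length = 256
    fade_out = np.linspace(1, 0, 20 * hop_length)   # 5120 samples
    output = np.ones(3000)                          # a clip shorter than the ramp
    wave_len = len(output)

    if wave_len > len(fade_out):                    # the new guard
        output[-20 * hop_length:] *= fade_out       # without the guard: ValueError (3000 vs 5120)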
@@ -405,7 +406,8 @@ class WaveRNN(nn.Module):
             padding = target + 2 * overlap - remaining
             x = self.pad_tensor(x, padding, side="after")

-        folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)
+        folded = torch.zeros(num_folds, target + 2 *
+                             overlap, features).to(x.device)

         # Get the values for the folded tensor
         for i in range(num_folds):
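
For context, the reflowed allocation sits inside the fold-with-overlap bookkeeping used by batched generation. A worked example with the example config's target_samples=11000 and overlap_samples=550; the num_folds/extended_len lines follow the usual WaveRNN fold logic and are not shown in this hunk:

    target, overlap = 11000, 550
    total_len = 60000                                         # hypothetical upsampled length

    num_folds = (total_len - overlap) // (target + overlap)   # 5
    extended_len = num_folds * (overlap + target) + overlap   # 58300
    remaining = total_len - extended_len                      # 1700
    if remaining != 0:
        num_folds += 1                                        # 6
        padding = target + 2 * overlap - remaining            # 10400 samples appended
    # folded shape: (num_folds, target + 2 * overlap, features) = (6, 12100, features)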