diff --git a/.travis/script b/.travis/script index 0c24a221..0860f9cf 100755 --- a/.travis/script +++ b/.travis/script @@ -17,5 +17,6 @@ fi if [[ "$TEST_SUITE" == "testscripts" ]]; then # test model training scripts ./tests/test_tts_train.sh - ./tests/test_vocoder_train.sh + ./tests/test_vocoder_gan_train.sh + ./tests/test_vocoder_wavernn_train.sh fi diff --git a/README.md b/README.md index 31348f28..53ec1aad 100644 --- a/README.md +++ b/README.md @@ -11,19 +11,22 @@
+Mozilla TTS is a deep learning based Text2Speech project, low in cost and high in quality. + This project is a part of [Mozilla Common Voice](https://voice.mozilla.org/en). -Mozilla TTS aims a deep learning based Text2Speech engine, low in cost and high in quality. +English Voice Samples: https://erogol.github.io/ddc-samples/ -You can check some of synthesized voice samples from [here](https://erogol.github.io/ddc-samples/). +TTS training recipes: https://github.com/erogol/TTS_recipes -If you are new, you can also find [here](http://www.erogol.com/text-speech-deep-learning-architectures/) a brief post about some of TTS architectures and [here](https://github.com/erogol/TTS-papers) list of up-to-date research papers. +TTS paper collection: https://github.com/erogol/TTS-papers [![](https://sourcerer.io/fame/erogol/erogol/TTS/images/0)](https://sourcerer.io/fame/erogol/erogol/TTS/links/0)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/1)](https://sourcerer.io/fame/erogol/erogol/TTS/links/1)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/2)](https://sourcerer.io/fame/erogol/erogol/TTS/links/2)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/3)](https://sourcerer.io/fame/erogol/erogol/TTS/links/3)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/4)](https://sourcerer.io/fame/erogol/erogol/TTS/links/4)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/5)](https://sourcerer.io/fame/erogol/erogol/TTS/links/5)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/6)](https://sourcerer.io/fame/erogol/erogol/TTS/links/6)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/7)](https://sourcerer.io/fame/erogol/erogol/TTS/links/7) ## TTS Performance -

+

+"Mozilla*" and "Judy*" are our models. [Details...](https://github.com/mozilla/TTS/wiki/Mean-Opinion-Score-Results) ## Provided Models and Methods diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py index 1c6ef94d..ca089d3e 100755 --- a/TTS/bin/compute_statistics.py +++ b/TTS/bin/compute_statistics.py @@ -11,6 +11,7 @@ from TTS.tts.datasets.preprocess import load_meta_data from TTS.utils.io import load_config from TTS.utils.audio import AudioProcessor + def main(): """Run preprocessing process.""" parser = argparse.ArgumentParser( diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_gan_vocoder.py similarity index 98% rename from TTS/bin/train_vocoder.py rename to TTS/bin/train_gan_vocoder.py index b51a55a3..12edf048 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_gan_vocoder.py @@ -326,7 +326,6 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) y_hat = model_G.pqmf_synthesis(y_hat) y_G_sub = model_G.pqmf_analysis(y_G) - scores_fake, feats_fake, feats_real = None, None, None if global_step > c.steps_to_start_discriminator: @@ -403,7 +402,6 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) else: loss_dict[key] = value.item() - step_time = time.time() - start_time epoch_time += step_time @@ -443,7 +441,8 @@ def main(args): # pylint: disable=redefined-outer-name print(f" > Loading wavs from: {c.data_path}") if c.feature_path is not None: print(f" > Loading features from: {c.feature_path}") - eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size) + eval_data, train_data = load_wav_feat_data( + c.data_path, c.feature_path, c.eval_split_size) else: eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) @@ -470,10 +469,12 @@ def main(args): # pylint: disable=redefined-outer-name scheduler_disc = None if 'lr_scheduler_gen' in c: scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen) - scheduler_gen = 
scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params) + scheduler_gen = scheduler_gen( + optimizer_gen, **c.lr_scheduler_gen_params) if 'lr_scheduler_disc' in c: scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc) - scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params) + scheduler_disc = scheduler_disc( + optimizer_disc, **c.lr_scheduler_disc_params) # setup criterion criterion_gen = GeneratorLoss(c) @@ -572,8 +573,7 @@ if __name__ == '__main__': parser.add_argument( '--continue_path', type=str, - help= - 'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', + help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', default='', required='--config_path' not in sys.argv) parser.add_argument( diff --git a/TTS/bin/train_wavernn_vocoder.py b/TTS/bin/train_wavernn_vocoder.py new file mode 100644 index 00000000..61664a65 --- /dev/null +++ b/TTS/bin/train_wavernn_vocoder.py @@ -0,0 +1,515 @@ +import argparse +import os +import sys +import traceback +import time +import glob +import random + +import torch +from torch.utils.data import DataLoader + +# from torch.utils.data.distributed import DistributedSampler + +from TTS.tts.utils.visual import plot_spectrogram +from TTS.utils.audio import AudioProcessor +from TTS.utils.radam import RAdam +from TTS.utils.io import copy_config_file, load_config +from TTS.utils.training import setup_torch_training_env +from TTS.utils.console_logger import ConsoleLogger +from TTS.utils.tensorboard_logger import TensorboardLogger +from TTS.utils.generic_utils import ( + KeepAverage, + count_parameters, + create_experiment_folder, + get_git_branch, + remove_experiment_folder, + set_init_dict, +) +from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset +from TTS.vocoder.datasets.preprocess import ( + load_wav_data, + load_wav_feat_data +) +from 
def setup_loader(ap, is_val=False, verbose=False):
    """Build the ``DataLoader`` for the training or evaluation split.

    Reads the module-level config ``c`` and the ``train_data`` / ``eval_data``
    item lists.

    Args:
        ap: audio processor handed to ``WaveRNNDataset``.
        is_val: build the loader for ``eval_data`` instead of ``train_data``.
        verbose: forwarded to the dataset.

    Returns:
        A ``DataLoader``, or ``None`` when ``is_val`` is set and
        ``c.run_eval`` is false.
    """
    if is_val and not c.run_eval:
        return None

    dataset = WaveRNNDataset(
        ap=ap,
        items=eval_data if is_val else train_data,
        seq_len=c.seq_len,
        hop_len=ap.hop_length,
        pad=c.padding,
        mode=c.mode,
        mulaw=c.mulaw,
        is_training=not is_val,
        verbose=verbose,
    )
    # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
    num_workers = c.num_val_loader_workers if is_val else c.num_loader_workers
    return DataLoader(
        dataset,
        shuffle=True,
        collate_fn=dataset.collate,
        batch_size=c.batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )


def format_data(data):
    """Unpack a collated batch and move it to the GPU when CUDA is enabled.

    Args:
        data: batch produced by the dataset's ``collate``; indices 0, 1 and 2
            are used as network input, mel features and target samples.

    Returns:
        Tuple ``(x_input, mels, y_coarse)``.
    """
    x_input, mels, y_coarse = data[0], data[1], data[2]

    # dispatch data to GPU
    if use_cuda:
        x_input = x_input.cuda(non_blocking=True)
        mels = mels.cuda(non_blocking=True)
        y_coarse = y_coarse.cuda(non_blocking=True)

    return x_input, mels, y_coarse


def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
    """Run one training epoch.

    Logs to the module-level ``c_logger`` / ``tb_logger`` and, every
    ``c.save_step`` steps, optionally saves a checkpoint and synthesizes one
    random training utterance for Tensorboard.

    Args:
        model: WaveRNN model; ``model.mode`` selects how targets are shaped.
        optimizer: optimizer stepped once per batch.
        criterion: callable ``criterion(y_hat, y_coarse)`` returning the loss.
        scheduler: LR scheduler stepped once per batch, or ``None``.
        ap: audio processor used to load wavs and compute spectrograms.
        global_step: step counter value at the start of the epoch.
        epoch: current epoch index.

    Returns:
        Tuple ``(keep_avg.avg_values, global_step)`` with the running averages
        of the epoch and the updated step counter.

    Raises:
        RuntimeError: if the loss becomes NaN or infinite.
    """
    # create train loader
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) /
                           (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    # train loop
    print(" > Training", flush=True)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
        x_input, mels, y_coarse = format_data(data)
        loader_time = time.time() - end_time
        global_step += 1

        y_hat = model(x_input, mels)

        if isinstance(model.mode, int):
            y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
        else:
            y_coarse = y_coarse.float()
        y_coarse = y_coarse.unsqueeze(-1)

        # compute losses
        loss = criterion(y_hat, y_coarse)
        # ``loss.item()`` is never None, so test for a diverged loss instead;
        # stepping on NaN/Inf gradients would corrupt the weights.
        if not torch.isfinite(loss).all():
            raise RuntimeError(" [!] NaN/Inf loss. Exiting ...")
        optimizer.zero_grad()
        loss.backward()
        if c.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # get the current learning rate
        cur_lr = list(optimizer.param_groups)[0]["lr"]

        step_time = time.time() - start_time
        epoch_time += step_time

        loss_dict = {"model_loss": loss.item()}
        update_train_values = {
            "avg_" + key: value for key, value in loss_dict.items()
        }
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                "step_time": [step_time, 2],
                "loader_time": [loader_time, 4],
                "current_lr": cur_lr,
            }
            c_logger.print_train_step(
                batch_n_iter,
                num_iter,
                global_step,
                log_dict,
                loss_dict,
                keep_avg.avg_values,
            )

        # plot step stats
        if global_step % 10 == 0:
            iter_stats = {"lr": cur_lr, "step_time": step_time}
            iter_stats.update(loss_dict)
            tb_logger.tb_train_iter_stats(global_step, iter_stats)

        # save checkpoint
        if global_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(
                    model,
                    optimizer,
                    scheduler,
                    None,
                    None,
                    None,
                    global_step,
                    epoch,
                    OUT_PATH,
                    model_losses=loss_dict,
                )

            # synthesize a full voice; items are either wav paths or
            # (wav_path, feat_path) pairs
            rand_idx = random.randrange(0, len(train_data))
            item = train_data[rand_idx]
            wav_path = item[0] if isinstance(item, (tuple, list)) else item
            wav = ap.load_wav(wav_path)
            ground_mel = ap.melspectrogram(wav)
            sample_wav = model.generate(
                ground_mel,
                c.batched,
                c.target_samples,
                c.overlap_samples,
                use_cuda,
            )
            predict_mel = ap.melspectrogram(sample_wav)

            # compute spectrograms
            figures = {
                "train/ground_truth": plot_spectrogram(ground_mel.T),
                "train/prediction": plot_spectrogram(predict_mel.T),
            }
            tb_logger.tb_train_figures(global_step, figures)

            # Sample audio
            tb_logger.tb_train_audios(
                global_step, {"train/audio": sample_wav}, c.audio["sample_rate"]
            )
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    #     tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    """Set up data, model, optimizer and scheduler, then run the epoch loop.

    Relies on the module-level config ``c``, ``OUT_PATH`` and ``c_logger``
    created in the script section, and publishes ``train_data`` /
    ``eval_data`` as module globals for the loader helpers.

    Args:
        args: parsed command-line arguments; ``args.restore_path`` is read
            and ``args.restore_step`` is set here.

    Raises:
        RuntimeError: if ``c.mode`` is not "mold", "gauss" or an int.
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(
            c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(
            c.data_path, c.eval_split_size)

    # setup model
    model_wavernn = setup_wavernn(c)

    # pick the loss matching the output distribution of the model
    if c.mode == "mold":
        criterion = discretized_mix_logistic_loss
    elif c.mode == "gauss":
        criterion = gaussian_loss
    elif isinstance(c.mode, int):
        criterion = torch.nn.CrossEntropyLoss()
    else:
        # fail here instead of with a NameError on the first training step
        raise RuntimeError(" [!] Unknown model mode: {}".format(c.mode))

    if use_cuda:
        model_wavernn.cuda()
        if isinstance(c.mode, int):
            criterion.cuda()

    optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)

    scheduler = None
    if "lr_scheduler" in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)
    # slow start for the first 5 epochs
    # lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
    # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # restore any checkpoint
    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model_wavernn.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            # the checkpoint may carry scheduler state even though the
            # current config defines no scheduler
            if "scheduler" in checkpoint and scheduler is not None:
                print(" > Restoring Generator LR Scheduler...")
                scheduler.load_state_dict(checkpoint["scheduler"])
                scheduler.optimizer = optimizer
            # TODO: fix resetting restored optimizer lr
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_wavernn.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model_wavernn.load_state_dict(model_dict)

        print(" > Model restored from step %d" %
              checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    num_parameters = count_parameters(model_wavernn)
    print(" > Model has {} parameters".format(num_parameters), flush=True)

    best_loss = float("inf")

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_wavernn, optimizer,
                               criterion, scheduler, ap, global_step, epoch)
        eval_avg_loss_dict = evaluate(
            model_wavernn, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict["avg_model_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_wavernn,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            model_losses=eval_avg_loss_dict,
        )


def _str2bool(value):
    """Parse a command-line boolean (``type=bool`` treats any non-empty string as True)."""
    return str(value).lower() in ("1", "true", "t", "yes", "y")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--continue_path",
        type=str,
        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default="",
        required="--config_path" not in sys.argv,
    )
    parser.add_argument(
        "--restore_path",
        type=str,
        help="Model file to be restored. Use to finetune a model.",
        default="",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to config file for training.",
        required="--continue_path" not in sys.argv,
    )
    parser.add_argument(
        "--debug",
        type=_str2bool,
        default=False,
        help="Do not verify commit integrity to run training.",
    )

    # DISTRIBUTED
    parser.add_argument(
        "--rank",
        type=int,
        default=0,
        help="DISTRIBUTED: process rank for distributed training.",
    )
    parser.add_argument(
        "--group_id", type=str, default="", help="DISTRIBUTED: process group id."
    )
    args = parser.parse_args()

    if args.continue_path != "":
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        list_of_files = glob.glob(args.continue_path + "/*.pth.tar")
        if not list_of_files:
            raise FileNotFoundError(
                " [!] No checkpoint (*.pth.tar) found in {}".format(
                    args.continue_path))
        # continue from the most recently created checkpoint
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues for {args.restore_path}")

    # setup output paths and read configs
    c = load_config(args.config_path)
    # check_config(c)

    OUT_PATH = args.continue_path
    if args.continue_path == "":
        OUT_PATH = create_experiment_folder(
            c.output_path, c.run_name, args.debug
        )

    AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")

    c_logger = ConsoleLogger()

    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        # must be "config.json": that is the file --continue_path loads
        copy_config_file(
            args.config_path, os.path.join(OUT_PATH, "config.json"), new_fields
        )
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

        LOG_DIR = OUT_PATH
        tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")

        # write model desc to tensorboard
        tb_logger.tb_add_text("model-description", c["run_description"], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params. + "min_level_db": -100, // lower bound for normalization + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored + }, + +// Generating / Synthesizing + "batched": true, + "target_samples": 11000, // target number of samples to be generated in each batch entry + "overlap_samples": 550, // number of samples for crossfading between batches + // DISTRIBUTED TRAINING + // "distributed":{ + // "backend": "nccl", + // "url": "tcp:\/\/localhost:54321" + // }, + +// MODEL MODE + "mode": 10, // mold [string], gauss [string], bits [int] + "mulaw": true, // apply mulaw if mode is bits + +// MODEL PARAMETERS + "wavernn_model_params": { + "rnn_dims": 512, + "fc_dims": 512, + "compute_dims": 128, + "res_out_dims": 128, + "num_res_blocks": 10, + "use_aux_net": true, + "use_upsample_net": true, + "upsample_factors": [4, 8, 8] // this needs to correctly factorise hop_length + }, + +// DATASET + //"use_gta": true, // use computed gta features from the tts model + "data_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech", // path containing training wav files + "feature_path": null, // path containing computed features from wav files if null compute them + "seq_len": 1280, // has to be divisible by hop_length + "padding": 2, // pad the input for resnet to see wider input length + +// TRAINING + "batch_size": 64, // Batch size for training. + "epochs": 10000, // total number of epochs to train. 
+ +// VALIDATION + "run_eval": true, + "test_every_epochs": 10, // Test after set number of epochs (Test every 10 epochs for example) + +// OPTIMIZER + "grad_clip": 4, // apply gradient clipping if > 0 + "lr_scheduler": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate + "lr_scheduler_params": { + "gamma": 0.5, + "milestones": [200000, 400000, 600000] + }, + "lr": 1e-4, // initial learning rate + +// TENSORBOARD and LOGGING + "print_step": 25, // Number of steps to log training on console. + "print_eval": false, // If True, it prints loss values for each step in eval run. + "save_step": 25000, // Number of training steps expected to plot training stats on TB and save model checkpoints. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + +// DATA LOADING + "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "num_val_loader_workers": 4, // number of evaluation data loader processes. 
def preprocess_wav_files(out_path, config, ap):
    """Pre-compute mel (and, for integer modes, quantized) features.

    For every wav found under ``config.data_path`` a mel spectrogram is saved
    to ``<out_path>/mel/<stem>.npy``. When ``config.mode`` is an int the
    signal is additionally mu-law encoded (``config.mulaw`` true) or
    quantized and saved to ``<out_path>/quant/<stem>.npy``.

    Args:
        out_path: root folder receiving the "mel" and "quant" sub-folders.
        config: object exposing ``data_path``, ``mode`` and ``mulaw``.
        ap: audio processor providing ``load_wav``, ``melspectrogram``,
            ``mulaw_encode`` and ``quantize``.
    """
    quant_dir = os.path.join(out_path, "quant")
    mel_dir = os.path.join(out_path, "mel")
    os.makedirs(quant_dir, exist_ok=True)
    os.makedirs(mel_dir, exist_ok=True)

    for path in tqdm(find_wav_files(config.data_path)):
        file_name = Path(path).stem + ".npy"
        y = ap.load_wav(path)
        np.save(os.path.join(mel_dir, file_name), ap.melspectrogram(y))
        if isinstance(config.mode, int):
            if config.mulaw:
                quant = ap.mulaw_encode(y, qc=config.mode)
            else:
                quant = ap.quantize(y, bits=config.mode)
            np.save(os.path.join(quant_dir, file_name), quant)


def find_wav_files(data_path):
    """Return every ``*.wav`` path found recursively under ``data_path``."""
    pattern = os.path.join(data_path, "**", "*.wav")
    return glob.glob(pattern, recursive=True)


def find_feat_files(data_path):
    """Return every ``*.npy`` path found recursively under ``data_path``."""
    pattern = os.path.join(data_path, "**", "*.npy")
    return glob.glob(pattern, recursive=True)
class WaveRNNDataset(Dataset):
    """
    WaveRNN Dataset serving (mel, signal) pairs.

    ``items`` is either a list of wav paths, in which case features are
    computed on the fly with ``ap``, or a list of ``(wav_path, feat_path)``
    pairs, in which case features are loaded from ``.npy`` files.
    """

    # spectrograms with fewer frames than this are considered unusable
    MIN_MEL_FRAMES = 5

    def __init__(self,
                 ap,
                 items,
                 seq_len,
                 hop_len,
                 pad,
                 mode,
                 mulaw,
                 is_training=True,
                 verbose=False,
                 ):

        self.ap = ap
        # plain paths mean there are no pre-computed features to load
        self.compute_feat = not isinstance(items[0], (tuple, list))
        self.item_list = items
        self.seq_len = seq_len
        self.hop_len = hop_len
        self.pad = pad
        self.mode = mode
        self.mulaw = mulaw
        self.is_training = is_training
        self.verbose = verbose

    def __len__(self):
        return len(self.item_list)

    def __getitem__(self, index):
        item = self.load_item(index)
        return item

    def load_item(self, index):
        """
        Load the (mel, signal) couple for ``index``.

        Features are read from disk if feature paths were given, otherwise
        they are computed on the fly. An instance whose spectrogram has fewer
        than ``MIN_MEL_FRAMES`` frames is reported and replaced in
        ``item_list`` by the next usable instance (wrapping around at the end
        of the list), and that replacement is returned.

        Raises:
            RuntimeError: if the mode is unknown or no instance in the
                dataset is long enough.
        """
        num_items = len(self.item_list)
        for shift in range(num_items):
            item = self.item_list[(index + shift) % num_items]
            if self.compute_feat:
                wavpath, feat_path = item, None
                audio = self.ap.load_wav(wavpath)
                mel = self.ap.melspectrogram(audio)
            else:
                wavpath, feat_path = item
                audio = None
                mel = np.load(feat_path.replace("/quant/", "/mel/"))

            if mel.shape[-1] < self.MIN_MEL_FRAMES:
                print(" [!] Instance is too short! : {}".format(wavpath))
                continue

            if shift:
                # remember the replacement so the short item is not retried
                self.item_list[index] = item
            return mel, self._load_signal(wavpath, feat_path, audio)

        raise RuntimeError(" [!] No instance with enough frames in the dataset.")

    def _load_signal(self, wavpath, feat_path, audio):
        """Return the network input signal matching ``self.mode``."""
        if self.mode in ["gauss", "mold"]:
            return audio if audio is not None else self.ap.load_wav(wavpath)
        if isinstance(self.mode, int):
            if feat_path is not None:
                return np.load(feat_path.replace("/mel/", "/quant/"))
            if self.mulaw:
                return self.ap.mulaw_encode(audio, qc=self.mode)
            return self.ap.quantize(audio, bits=self.mode)
        raise RuntimeError("Unknown dataset mode - ", self.mode)

    def collate(self, batch):
        """
        Crop a random aligned window from every (mel, signal) pair and stack.

        Returns:
            Tuple ``(x_input, mels, y_coarse)`` of tensors, where ``y_coarse``
            is the signal window shifted by one sample relative to ``x_input``.

        Raises:
            ValueError: if a spectrogram is shorter than the crop window.
        """
        mel_win = self.seq_len // self.hop_len + 2 * self.pad
        min_frames = mel_win + 2 * self.pad

        mel_offsets = []
        for mel, _ in batch:
            max_offset = mel.shape[-1] - min_frames
            if max_offset < 0:
                raise ValueError(
                    " [!] Instance has {} frames but at least {} are needed.".format(
                        mel.shape[-1], min_frames))
            # randint(0, 0) raises, yet offset 0 is valid for an exact-fit clip
            mel_offsets.append(
                np.random.randint(0, max_offset) if max_offset > 0 else 0)
        sig_offsets = [(offset + self.pad) * self.hop_len
                       for offset in mel_offsets]

        mels = [
            x[0][:, mel_offsets[i]: mel_offsets[i] + mel_win]
            for i, x in enumerate(batch)
        ]

        # one extra sample so input and target can be shifted by one
        coarse = [
            x[1][sig_offsets[i]: sig_offsets[i] + self.seq_len + 1]
            for i, x in enumerate(batch)
        ]

        mels = np.stack(mels).astype(np.float32)
        if self.mode in ["gauss", "mold"]:
            coarse = np.stack(coarse).astype(np.float32)
            coarse = torch.FloatTensor(coarse)
            x_input = coarse[:, : self.seq_len]
        elif isinstance(self.mode, int):
            coarse = np.stack(coarse).astype(np.int64)
            coarse = torch.LongTensor(coarse)
            # rescale class indices from [0, 2**bits - 1] to [-1, 1]
            x_input = (2 * coarse[:, : self.seq_len].float() /
                       (2 ** self.mode - 1.0) - 1.0)
        y_coarse = coarse[:, 1:]
        mels = torch.FloatTensor(mels)
        return x_input, mels, y_coarse
class ResBlock(nn.Module):
    """Residual block of two 1x1 convolutions, each followed by batch norm."""

    def __init__(self, dims):
        super().__init__()
        self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
        self.batch_norm1 = nn.BatchNorm1d(dims)
        self.batch_norm2 = nn.BatchNorm1d(dims)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.batch_norm1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.batch_norm2(x)
        return x + residual


class MelResNet(nn.Module):
    """Stack of ``ResBlock`` layers between an input and an output convolution.

    The input convolution has kernel size ``2 * pad + 1`` and no padding, so
    the output is ``2 * pad`` frames shorter than the input.
    """

    def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad):
        super().__init__()
        k_size = pad * 2 + 1
        self.conv_in = nn.Conv1d(
            in_dims, compute_dims, kernel_size=k_size, bias=False)
        self.batch_norm = nn.BatchNorm1d(compute_dims)
        self.layers = nn.ModuleList()
        for _ in range(num_res_blocks):
            self.layers.append(ResBlock(compute_dims))
        self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)

    def forward(self, x):
        x = self.conv_in(x)
        x = self.batch_norm(x)
        x = F.relu(x)
        for f in self.layers:
            x = f(x)
        x = self.conv_out(x)
        return x


class Stretch2d(nn.Module):
    """Repeat every element ``y_scale`` times along H and ``x_scale`` times along W."""

    def __init__(self, x_scale, y_scale):
        super().__init__()
        self.x_scale = x_scale
        self.y_scale = y_scale

    def forward(self, x):
        b, c, h, w = x.size()
        # (b, c, h, 1, w, 1) -> repeat -> (b, c, h, y_scale, w, x_scale)
        x = x.unsqueeze(-1).unsqueeze(3)
        x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
        return x.view(b, c, h * self.y_scale, w * self.x_scale)


class UpsampleNetwork(nn.Module):
    """Upsample features with stretch + smoothing-conv stages.

    ``forward`` returns ``(upsampled, aux)``, both with time on dim 1; ``aux``
    comes from a ``MelResNet`` and is ``None`` when ``use_aux_net`` is false.
    """

    def __init__(
            self,
            feat_dims,
            upsample_scales,
            compute_dims,
            num_res_blocks,
            res_out_dims,
            pad,
            use_aux_net,
    ):
        super().__init__()
        # np.cumproduct was removed in NumPy 2.0; the last cumulative product
        # is simply the product of all scales
        self.total_scale = int(np.prod(upsample_scales))
        self.indent = pad * self.total_scale
        self.use_aux_net = use_aux_net
        if use_aux_net:
            self.resnet = MelResNet(
                num_res_blocks, feat_dims, compute_dims, res_out_dims, pad
            )
            self.resnet_stretch = Stretch2d(self.total_scale, 1)
        self.up_layers = nn.ModuleList()
        for scale in upsample_scales:
            k_size = (1, scale * 2 + 1)
            padding = (0, scale)
            stretch = Stretch2d(scale, 1)
            conv = nn.Conv2d(1, 1, kernel_size=k_size,
                             padding=padding, bias=False)
            # initialise as a moving average over the stretched samples
            conv.weight.data.fill_(1.0 / k_size[1])
            self.up_layers.append(stretch)
            self.up_layers.append(conv)

    def forward(self, m):
        if self.use_aux_net:
            aux = self.resnet(m).unsqueeze(1)
            aux = self.resnet_stretch(aux)
            aux = aux.squeeze(1)
            aux = aux.transpose(1, 2)
        else:
            aux = None
        m = m.unsqueeze(1)
        for f in self.up_layers:
            m = f(m)
        m = m.squeeze(1)
        # explicit end index: ``[indent:-indent]`` is empty when indent == 0
        m = m[:, :, self.indent: m.size(-1) - self.indent]
        return m.transpose(1, 2), aux


class Upsample(nn.Module):
    """Upsample features by linear interpolation.

    ``forward`` returns ``(upsampled, aux)``, both with time on dim 1; ``aux``
    is ``None`` when ``use_aux_net`` is false.
    """

    def __init__(
            self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net
    ):
        super().__init__()
        self.scale = scale
        self.pad = pad
        self.indent = pad * scale
        self.use_aux_net = use_aux_net
        # built even when unused so the module's state_dict keys do not
        # depend on ``use_aux_net``
        self.resnet = MelResNet(num_res_blocks, feat_dims,
                                compute_dims, res_out_dims, pad)

    def forward(self, m):
        if self.use_aux_net:
            aux = self.resnet(m)
            aux = torch.nn.functional.interpolate(
                aux, scale_factor=self.scale, mode="linear", align_corners=True
            )
            aux = aux.transpose(1, 2)
        else:
            aux = None
        m = torch.nn.functional.interpolate(
            m, scale_factor=self.scale, mode="linear", align_corners=True
        )
        # explicit end index: ``[indent:-indent]`` is empty when indent == 0
        m = m[:, :, self.indent: m.size(-1) - self.indent]
        m = m * 0.045  # empirically found

        return m.transpose(1, 2), aux
raise RuntimeError("Unknown model mode value - ", self.mode) + + self.rnn_dims = rnn_dims + self.aux_dims = res_out_dims // 4 + self.hop_length = hop_length + self.sample_rate = sample_rate + + if self.use_upsample_net: + assert ( + np.cumproduct(upsample_factors)[-1] == self.hop_length + ), " [!] upsample scales needs to be equal to hop_length" + self.upsample = UpsampleNetwork( + feat_dims, + upsample_factors, + compute_dims, + num_res_blocks, + res_out_dims, + pad, + use_aux_net, + ) + else: + self.upsample = Upsample( + hop_length, + pad, + num_res_blocks, + feat_dims, + compute_dims, + res_out_dims, + use_aux_net, + ) + if self.use_aux_net: + self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims) + self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, + rnn_dims, batch_first=True) + self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims) + self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims) + self.fc3 = nn.Linear(fc_dims, self.n_classes) + else: + self.I = nn.Linear(feat_dims + 1, rnn_dims) + self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) + self.fc1 = nn.Linear(rnn_dims, fc_dims) + self.fc2 = nn.Linear(fc_dims, fc_dims) + self.fc3 = nn.Linear(fc_dims, self.n_classes) + + def forward(self, x, mels): + bsize = x.size(0) + h1 = torch.zeros(1, bsize, self.rnn_dims).to(x.device) + h2 = torch.zeros(1, bsize, self.rnn_dims).to(x.device) + mels, aux = self.upsample(mels) + + if self.use_aux_net: + aux_idx = [self.aux_dims * i for i in range(5)] + a1 = aux[:, :, aux_idx[0]: aux_idx[1]] + a2 = aux[:, :, aux_idx[1]: aux_idx[2]] + a3 = aux[:, :, aux_idx[2]: aux_idx[3]] + a4 = aux[:, :, aux_idx[3]: aux_idx[4]] + + x = ( + torch.cat([x.unsqueeze(-1), mels, a1], dim=2) + if self.use_aux_net + else torch.cat([x.unsqueeze(-1), mels], dim=2) + ) + x = self.I(x) + res = x + self.rnn1.flatten_parameters() + x, _ = self.rnn1(x, h1) + + x = x + res + res = 
x + x = torch.cat([x, a2], dim=2) if self.use_aux_net else x + self.rnn2.flatten_parameters() + x, _ = self.rnn2(x, h2) + + x = x + res + x = torch.cat([x, a3], dim=2) if self.use_aux_net else x + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4], dim=2) if self.use_aux_net else x + x = F.relu(self.fc2(x)) + return self.fc3(x) + + def generate(self, mels, batched, target, overlap, use_cuda=False): + + self.eval() + device = 'cuda' if use_cuda else 'cpu' + output = [] + start = time.time() + rnn1 = self.get_gru_cell(self.rnn1) + rnn2 = self.get_gru_cell(self.rnn2) + + with torch.no_grad(): + mels = torch.FloatTensor(mels).unsqueeze(0).to(device) + #mels = torch.FloatTensor(mels).cuda().unsqueeze(0) + wave_len = (mels.size(-1) - 1) * self.hop_length + mels = self.pad_tensor(mels.transpose( + 1, 2), pad=self.pad, side="both") + mels, aux = self.upsample(mels.transpose(1, 2)) + + if batched: + mels = self.fold_with_overlap(mels, target, overlap) + if aux is not None: + aux = self.fold_with_overlap(aux, target, overlap) + + b_size, seq_len, _ = mels.size() + + h1 = torch.zeros(b_size, self.rnn_dims).to(device) + h2 = torch.zeros(b_size, self.rnn_dims).to(device) + x = torch.zeros(b_size, 1).to(device) + + if self.use_aux_net: + d = self.aux_dims + aux_split = [aux[:, :, d * i: d * (i + 1)] for i in range(4)] + + for i in range(seq_len): + + m_t = mels[:, i, :] + + if self.use_aux_net: + a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split) + + x = ( + torch.cat([x, m_t, a1_t], dim=1) + if self.use_aux_net + else torch.cat([x, m_t], dim=1) + ) + x = self.I(x) + h1 = rnn1(x, h1) + + x = x + h1 + inp = torch.cat([x, a2_t], dim=1) if self.use_aux_net else x + h2 = rnn2(inp, h2) + + x = x + h2 + x = torch.cat([x, a3_t], dim=1) if self.use_aux_net else x + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4_t], dim=1) if self.use_aux_net else x + x = F.relu(self.fc2(x)) + + logits = self.fc3(x) + + if self.mode == "mold": + sample = sample_from_discretized_mix_logistic( + 
logits.unsqueeze(0).transpose(1, 2) + ) + output.append(sample.view(-1)) + x = sample.transpose(0, 1).to(device) + elif self.mode == "gauss": + sample = sample_from_gaussian( + logits.unsqueeze(0).transpose(1, 2)) + output.append(sample.view(-1)) + x = sample.transpose(0, 1).to(device) + elif isinstance(self.mode, int): + posterior = F.softmax(logits, dim=1) + distrib = torch.distributions.Categorical(posterior) + + sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0 + output.append(sample) + x = sample.unsqueeze(-1) + else: + raise RuntimeError( + "Unknown model mode value - ", self.mode) + + if i % 100 == 0: + self.gen_display(i, seq_len, b_size, start) + + output = torch.stack(output).transpose(0, 1) + output = output.cpu().numpy() + output = output.astype(np.float64) + + if batched: + output = self.xfade_and_unfold(output, target, overlap) + else: + output = output[0] + + if self.mulaw and isinstance(self.mode, int): + output = ap.mulaw_decode(output, self.mode) + + # Fade-out at the end to avoid signal cutting out suddenly + fade_out = np.linspace(1, 0, 20 * self.hop_length) + output = output[:wave_len] + + if wave_len > len(fade_out): + output[-20 * self.hop_length:] *= fade_out + + self.train() + return output + + def gen_display(self, i, seq_len, b_size, start): + gen_rate = (i + 1) / (time.time() - start) * b_size / 1000 + realtime_ratio = gen_rate * 1000 / self.sample_rate + stream( + "%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ", + (i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio), + ) + + def fold_with_overlap(self, x, target, overlap): + """Fold the tensor with overlap for quick batched inference. + Overlap will be used for crossfading in xfade_and_unfold() + Args: + x (tensor) : Upsampled conditioning features. 
+ shape=(1, timesteps, features) + target (int) : Target timesteps for each index of batch + overlap (int) : Timesteps for both xfade and rnn warmup + Return: + (tensor) : shape=(num_folds, target + 2 * overlap, features) + Details: + x = [[h1, h2, ... hn]] + Where each h is a vector of conditioning features + Eg: target=2, overlap=1 with x.size(1)=10 + folded = [[h1, h2, h3, h4], + [h4, h5, h6, h7], + [h7, h8, h9, h10]] + """ + + _, total_len, features = x.size() + + # Calculate variables needed + num_folds = (total_len - overlap) // (target + overlap) + extended_len = num_folds * (overlap + target) + overlap + remaining = total_len - extended_len + + # Pad if some time steps poking out + if remaining != 0: + num_folds += 1 + padding = target + 2 * overlap - remaining + x = self.pad_tensor(x, padding, side="after") + + folded = torch.zeros(num_folds, target + 2 * + overlap, features).to(x.device) + + # Get the values for the folded tensor + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + folded[i] = x[:, start:end, :] + + return folded + + @staticmethod + def get_gru_cell(gru): + gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size) + gru_cell.weight_hh.data = gru.weight_hh_l0.data + gru_cell.weight_ih.data = gru.weight_ih_l0.data + gru_cell.bias_hh.data = gru.bias_hh_l0.data + gru_cell.bias_ih.data = gru.bias_ih_l0.data + return gru_cell + + @staticmethod + def pad_tensor(x, pad, side="both"): + # NB - this is just a quick method i need right now + # i.e., it won't generalise to other shapes/dims + b, t, c = x.size() + total = t + 2 * pad if side == "both" else t + pad + padded = torch.zeros(b, total, c).to(x.device) + if side in ("before", "both"): + padded[:, pad: pad + t, :] = x + elif side == "after": + padded[:, :t, :] = x + return padded + + @staticmethod + def xfade_and_unfold(y, target, overlap): + """Applies a crossfade and unfolds into a 1d array. 
+ Args: + y (ndarry) : Batched sequences of audio samples + shape=(num_folds, target + 2 * overlap) + dtype=np.float64 + overlap (int) : Timesteps for both xfade and rnn warmup + Return: + (ndarry) : audio samples in a 1d array + shape=(total_len) + dtype=np.float64 + Details: + y = [[seq1], + [seq2], + [seq3]] + Apply a gain envelope at both ends of the sequences + y = [[seq1_in, seq1_target, seq1_out], + [seq2_in, seq2_target, seq2_out], + [seq3_in, seq3_target, seq3_out]] + Stagger and add up the groups of samples: + [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...] + """ + + num_folds, length = y.shape + target = length - 2 * overlap + total_len = num_folds * (target + overlap) + overlap + + # Need some silence for the rnn warmup + silence_len = overlap // 2 + fade_len = overlap - silence_len + silence = np.zeros((silence_len), dtype=np.float64) + + # Equal power crossfade + t = np.linspace(-1, 1, fade_len, dtype=np.float64) + fade_in = np.sqrt(0.5 * (1 + t)) + fade_out = np.sqrt(0.5 * (1 - t)) + + # Concat the silence to the fades + fade_in = np.concatenate([silence, fade_in]) + fade_out = np.concatenate([fade_out, silence]) + + # Apply the gain to the overlap samples + y[:, :overlap] *= fade_in + y[:, -overlap:] *= fade_out + + unfolded = np.zeros((total_len), dtype=np.float64) + + # Loop to add up all the samples + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + unfolded[start:end] += y[i] + + return unfolded diff --git a/TTS/vocoder/utils/distribution.py b/TTS/vocoder/utils/distribution.py new file mode 100644 index 00000000..6aba5e34 --- /dev/null +++ b/TTS/vocoder/utils/distribution.py @@ -0,0 +1,168 @@ +import numpy as np +import math +import torch +from torch.distributions.normal import Normal +import torch.nn.functional as F + + +def gaussian_loss(y_hat, y, log_std_min=-7.0): + assert y_hat.dim() == 3 + assert y_hat.size(2) == 2 + mean = y_hat[:, :, :1] + log_std = torch.clamp(y_hat[:, :, 
1:], min=log_std_min) + # TODO: replace with pytorch dist + log_probs = -0.5 * ( + -math.log(2.0 * math.pi) + - 2.0 * log_std + - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std)) + ) + return log_probs.squeeze().mean() + + +def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0): + assert y_hat.size(2) == 2 + mean = y_hat[:, :, :1] + log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min) + dist = Normal( + mean, + torch.exp(log_std), + ) + sample = dist.sample() + sample = torch.clamp(torch.clamp( + sample, min=-scale_factor), max=scale_factor) + del dist + return sample + + +def log_sum_exp(x): + """ numerically stable log_sum_exp implementation that prevents overflow """ + # TF ordering + axis = len(x.size()) - 1 + m, _ = torch.max(x, dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) + + +# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py +def discretized_mix_logistic_loss( + y_hat, y, num_classes=65536, log_scale_min=None, reduce=True +): + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + y_hat = y_hat.permute(0, 2, 1) + assert y_hat.dim() == 3 + assert y_hat.size(1) % 3 == 0 + nr_mix = y_hat.size(1) // 3 + + # (B x T x C) + y_hat = y_hat.transpose(1, 2) + + # unpack parameters. 
(B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix: 2 * nr_mix] + log_scales = torch.clamp( + y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=log_scale_min) + + # B x T x 1 -> B x T x num_mixtures + y = y.expand_as(means) + + centered_y = y - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1)) + cdf_plus = torch.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1)) + cdf_min = torch.sigmoid(min_in) + + # log probability for edge case of 0 (before scaling) + # equivalent: torch.log(F.sigmoid(plus_in)) + log_cdf_plus = plus_in - F.softplus(plus_in) + + # log probability for edge case of 255 (before scaling) + # equivalent: (1 - F.sigmoid(min_in)).log() + log_one_minus_cdf_min = -F.softplus(min_in) + + # probability for all other cases + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + # log probability in the center of the bin, to be used in extreme cases + # (not actually used in our code) + log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in) + + # tf equivalent + + # log_probs = tf.where(x < -0.999, log_cdf_plus, + # tf.where(x > 0.999, log_one_minus_cdf_min, + # tf.where(cdf_delta > 1e-5, + # tf.log(tf.maximum(cdf_delta, 1e-12)), + # log_pdf_mid - np.log(127.5)))) + + # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value + # for num_classes=65536 case? 1e-7? not sure.. 
+ inner_inner_cond = (cdf_delta > 1e-5).float() + + inner_inner_out = inner_inner_cond * torch.log( + torch.clamp(cdf_delta, min=1e-12) + ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) + inner_cond = (y > 0.999).float() + inner_out = ( + inner_cond * log_one_minus_cdf_min + + (1.0 - inner_cond) * inner_inner_out + ) + cond = (y < -0.999).float() + log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out + + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if reduce: + return -torch.mean(log_sum_exp(log_probs)) + return -log_sum_exp(log_probs).unsqueeze(-1) + + +def sample_from_discretized_mix_logistic(y, log_scale_min=None): + """ + Sample from discretized mixture of logistic distributions + Args: + y (Tensor): B x C x T + log_scale_min (float): Log scale minimum value + Returns: + Tensor: sample in range of [-1, 1]. + """ + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + assert y.size(1) % 3 == 0 + nr_mix = y.size(1) // 3 + + # B x T x C + y = y.transpose(1, 2) + logit_probs = y[:, :, :nr_mix] + + # sample mixture indicator from softmax + temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) + temp = logit_probs.data - torch.log(-torch.log(temp)) + _, argmax = temp.max(dim=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = to_one_hot(argmax, nr_mix) + # select logistic parameters + means = torch.sum(y[:, :, nr_mix: 2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.clamp( + torch.sum(y[:, :, 2 * nr_mix: 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min + ) + # sample from logistic & clip to interval + # we don't actually round to the nearest 8bit value when sampling + u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) + x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u)) + + x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0) + + return x + + +def to_one_hot(tensor, n, fill_with=1.0): + # we perform one hot encore with respect to the last axis + one_hot = 
torch.FloatTensor(tensor.size() + (n,)).zero_() + if tensor.is_cuda: + one_hot = one_hot.cuda() + one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) + return one_hot diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 89dc68fb..f9fbba52 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -42,6 +42,29 @@ def to_camel(text): return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) +def setup_wavernn(c): + print(" > Model: WaveRNN") + MyModel = importlib.import_module("TTS.vocoder.models.wavernn") + MyModel = getattr(MyModel, "WaveRNN") + model = MyModel( + rnn_dims=c.wavernn_model_params['rnn_dims'], + fc_dims=c.wavernn_model_params['fc_dims'], + mode=c.mode, + mulaw=c.mulaw, + pad=c.padding, + use_aux_net=c.wavernn_model_params['use_aux_net'], + use_upsample_net=c.wavernn_model_params['use_upsample_net'], + upsample_factors=c.wavernn_model_params['upsample_factors'], + feat_dims=c.audio['num_mels'], + compute_dims=c.wavernn_model_params['compute_dims'], + res_out_dims=c.wavernn_model_params['res_out_dims'], + num_res_blocks=c.wavernn_model_params['num_res_blocks'], + hop_length=c.audio["hop_length"], + sample_rate=c.audio["sample_rate"], + ) + return model + + def setup_generator(c): print(" > Generator Model: {}".format(c.generator_model)) MyModel = importlib.import_module('TTS.vocoder.models.' + diff --git a/tests/inputs/test_vocoder_wavernn_config.json b/tests/inputs/test_vocoder_wavernn_config.json new file mode 100644 index 00000000..28c0f059 --- /dev/null +++ b/tests/inputs/test_vocoder_wavernn_config.json @@ -0,0 +1,94 @@ +{ + "run_name": "wavernn_test", + "run_description": "wavernn_test training", + + // AUDIO PARAMETERS + "audio":{ + "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. + "win_length": 1024, // stft window length in ms. + "hop_length": 256, // stft window hop-lengh in ms. 
+ "frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-length in ms. If null, 'hop_length' is used. + + // Audio processing parameters + "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. + "preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis. + "ref_level_db": 0, // reference level db, theoretically 20db is the sound of air. + + // Silence trimming + "do_trim_silence": true,// enable trimming of silence of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60, // threshold for trimming silence. Set this according to your dataset. + + // MelSpectrogram parameters + "num_mels": 80, // size of the mel spec frame. + "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! + "spec_gain": 20.0, // scaler value applied after log transform of spectrogram. + + // Normalization parameters + "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params. + "min_level_db": -100, // lower bound for normalization + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'.
If it is defined, mean-std based normalization is used and other normalization params are ignored + }, + + // Generating / Synthesizing + "batched": true, + "target_samples": 11000, // target number of samples to be generated in each batch entry + "overlap_samples": 550, // number of samples for crossfading between batches + + // DISTRIBUTED TRAINING + // "distributed":{ + // "backend": "nccl", + // "url": "tcp:\/\/localhost:54321" + // }, + + // MODEL PARAMETERS + "use_aux_net": true, + "use_upsample_net": true, + "upsample_factors": [4, 8, 8], // this needs to correctly factorise hop_length + "seq_len": 1280, // has to be divisible by hop_length + "mode": "mold", // mold [string], gauss [string], bits [int] + "mulaw": false, // apply mulaw if mode is bits + "padding": 2, // pad the input for resnet to see wider input length + + // DATASET + //"use_gta": true, // use computed gta features from the tts model + "data_path": "tests/data/ljspeech/wavs/", // path containing training wav files + "feature_path": null, // path containing computed features from wav files; if null, compute them + + // TRAINING + "batch_size": 4, // Batch size for training. Lower values than 32 might cause hard to learn attention. + "epochs": 1, // total number of epochs to train. + + // VALIDATION + "run_eval": true, + "test_every_epochs": 10, // Test after set number of epochs (Test every 20 epochs for example) + + // OPTIMIZER + "grad_clip": 4, // apply gradient clipping if > 0 + "lr_scheduler": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate + "lr_scheduler_params": { + "gamma": 0.5, + "milestones": [200000, 400000, 600000] + }, + "lr": 1e-4, // initial learning rate + + // TENSORBOARD and LOGGING + "print_step": 25, // Number of steps to log training on console. + "print_eval": false, // If True, it prints loss values for each step in eval run.
+ "save_step": 25000, // Number of training steps expected to plot training stats on TB and save model checkpoints. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + + // DATA LOADING + "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "num_val_loader_workers": 4, // number of evaluation data loader processes. + "eval_split_size": 10, // number of samples for testing + + // PATHS + "output_path": "tests/train_outputs/" +} + diff --git a/tests/test_vocoder_datasets.py b/tests/test_vocoder_gan_datasets.py similarity index 100% rename from tests/test_vocoder_datasets.py rename to tests/test_vocoder_gan_datasets.py diff --git a/tests/test_vocoder_train.sh b/tests/test_vocoder_gan_train.sh similarity index 57% rename from tests/test_vocoder_train.sh rename to tests/test_vocoder_gan_train.sh index fa99b4bd..75773cc3 100755 --- a/tests/test_vocoder_train.sh +++ b/tests/test_vocoder_gan_train.sh @@ -5,11 +5,11 @@ echo "$BASEDIR" # create run dir mkdir $BASEDIR/train_outputs # run training -CUDA_VISIBLE_DEVICES="" python TTS/bin/train_vocoder.py --config_path $BASEDIR/inputs/test_vocoder_multiband_melgan_config.json +CUDA_VISIBLE_DEVICES="" python TTS/bin/train_gan_vocoder.py --config_path $BASEDIR/inputs/test_vocoder_multiband_melgan_config.json # find the training folder LATEST_FOLDER=$(ls $BASEDIR/train_outputs/| sort | tail -1) echo $LATEST_FOLDER # continue the previous training -CUDA_VISIBLE_DEVICES="" python TTS/bin/train_vocoder.py --continue_path $BASEDIR/train_outputs/$LATEST_FOLDER +CUDA_VISIBLE_DEVICES="" python TTS/bin/train_gan_vocoder.py --continue_path $BASEDIR/train_outputs/$LATEST_FOLDER # remove all the outputs rm -rf $BASEDIR/train_outputs/$LATEST_FOLDER diff --git a/tests/test_vocoder_wavernn.py b/tests/test_vocoder_wavernn.py new 
file mode 100644 index 00000000..ccd71c56 --- /dev/null +++ b/tests/test_vocoder_wavernn.py @@ -0,0 +1,31 @@ +import numpy as np +import torch +import random +from TTS.vocoder.models.wavernn import WaveRNN + + +def test_wavernn(): + model = WaveRNN( + rnn_dims=512, + fc_dims=512, + mode=10, + mulaw=False, + pad=2, + use_aux_net=True, + use_upsample_net=True, + upsample_factors=[4, 8, 8], + feat_dims=80, + compute_dims=128, + res_out_dims=128, + num_res_blocks=10, + hop_length=256, + sample_rate=22050, + ) + dummy_x = torch.rand((2, 1280)) + dummy_m = torch.rand((2, 80, 9)) + y_size = random.randrange(20, 60) + dummy_y = torch.rand((80, y_size)) + output = model(dummy_x, dummy_m) + assert np.all(output.shape == (2, 1280, 4 * 256)), output.shape + output = model.generate(dummy_y, True, 5500, 550, False) + assert np.all(output.shape == (256 * (y_size - 1),)) diff --git a/tests/test_vocoder_wavernn_datasets.py b/tests/test_vocoder_wavernn_datasets.py new file mode 100644 index 00000000..a95e247a --- /dev/null +++ b/tests/test_vocoder_wavernn_datasets.py @@ -0,0 +1,92 @@ +import os +import shutil + +import numpy as np +from tests import get_tests_path, get_tests_input_path, get_tests_output_path +from torch.utils.data import DataLoader + +from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_config +from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset +from TTS.vocoder.datasets.preprocess import load_wav_feat_data, preprocess_wav_files + +file_path = os.path.dirname(os.path.realpath(__file__)) +OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/") +os.makedirs(OUTPATH, exist_ok=True) + +C = load_config(os.path.join(get_tests_input_path(), + "test_vocoder_wavernn_config.json")) + +test_data_path = os.path.join(get_tests_path(), "data/ljspeech/") +test_mel_feat_path = os.path.join(test_data_path, "mel") +test_quant_feat_path = os.path.join(test_data_path, "quant") +ok_ljspeech = os.path.exists(test_data_path) + + +def 
wavernn_dataset_case(batch_size, seq_len, hop_len, pad, mode, mulaw, num_workers): + """ run dataloader with given parameters and check conditions """ + ap = AudioProcessor(**C.audio) + + C.batch_size = batch_size + C.mode = mode + C.seq_len = seq_len + C.data_path = test_data_path + + preprocess_wav_files(test_data_path, C, ap) + _, train_items = load_wav_feat_data( + test_data_path, test_mel_feat_path, 5) + + dataset = WaveRNNDataset(ap=ap, + items=train_items, + seq_len=seq_len, + hop_len=hop_len, + pad=pad, + mode=mode, + mulaw=mulaw + ) + # sampler = DistributedSampler(dataset) if num_gpus > 1 else None + loader = DataLoader(dataset, + shuffle=True, + collate_fn=dataset.collate, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=True, + ) + + max_iter = 10 + count_iter = 0 + + try: + for data in loader: + x_input, mels, _ = data + expected_feat_shape = (ap.num_mels, + (x_input.shape[-1] // hop_len) + (pad * 2)) + assert np.all( + mels.shape[1:] == expected_feat_shape), f" [!] 
{mels.shape} vs {expected_feat_shape}" + + assert (mels.shape[2] - pad * 2) * hop_len == x_input.shape[1] + count_iter += 1 + if count_iter == max_iter: + break + # except AssertionError: + # shutil.rmtree(test_mel_feat_path) + # shutil.rmtree(test_quant_feat_path) + finally: + shutil.rmtree(test_mel_feat_path) + shutil.rmtree(test_quant_feat_path) + + +def test_parametrized_wavernn_dataset(): + ''' test dataloader with different parameters ''' + params = [ + [16, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, 10, True, 0], + [16, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, "mold", False, 4], + [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, 9, False, 0], + [1, C.audio['hop_length'], C.audio['hop_length'], 2, 10, True, 0], + [1, C.audio['hop_length'], C.audio['hop_length'], 2, "mold", False, 0], + [1, C.audio['hop_length'] * 5, C.audio['hop_length'], 4, 10, False, 2], + [1, C.audio['hop_length'] * 5, C.audio['hop_length'], 2, "mold", False, 0], + ] + for param in params: + print(param) + wavernn_dataset_case(*param) diff --git a/tests/test_vocoder_wavernn_train.sh b/tests/test_vocoder_wavernn_train.sh new file mode 100755 index 00000000..f2e32116 --- /dev/null +++ b/tests/test_vocoder_wavernn_train.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +BASEDIR=$(dirname "$0") +echo "$BASEDIR" +# create run dir +mkdir $BASEDIR/train_outputs +# run training +CUDA_VISIBLE_DEVICES="" python TTS/bin/train_wavernn_vocoder.py --config_path $BASEDIR/inputs/test_vocoder_wavernn_config.json +# find the training folder +LATEST_FOLDER=$(ls $BASEDIR/train_outputs/| sort | tail -1) +echo $LATEST_FOLDER +# continue the previous training +CUDA_VISIBLE_DEVICES="" python TTS/bin/train_wavernn_vocoder.py --continue_path $BASEDIR/train_outputs/$LATEST_FOLDER +# remove all the outputs +rm -rf $BASEDIR/train_outputs/$LATEST_FOLDER \ No newline at end of file