diff --git a/.travis/script b/.travis/script
index 0c24a221..0860f9cf 100755
--- a/.travis/script
+++ b/.travis/script
@@ -17,5 +17,6 @@ fi
if [[ "$TEST_SUITE" == "testscripts" ]]; then
# test model training scripts
./tests/test_tts_train.sh
- ./tests/test_vocoder_train.sh
+ ./tests/test_vocoder_gan_train.sh
+ ./tests/test_vocoder_wavernn_train.sh
fi
diff --git a/README.md b/README.md
index 31348f28..53ec1aad 100644
--- a/README.md
+++ b/README.md
@@ -11,19 +11,22 @@
+Mozilla TTS is a deep learning based Text2Speech engine, low in cost and high in quality.
+
This project is a part of [Mozilla Common Voice](https://voice.mozilla.org/en).
-Mozilla TTS aims a deep learning based Text2Speech engine, low in cost and high in quality.
+English Voice Samples: https://erogol.github.io/ddc-samples/
-You can check some of synthesized voice samples from [here](https://erogol.github.io/ddc-samples/).
+TTS training recipes: https://github.com/erogol/TTS_recipes
-If you are new, you can also find [here](http://www.erogol.com/text-speech-deep-learning-architectures/) a brief post about some of TTS architectures and [here](https://github.com/erogol/TTS-papers) list of up-to-date research papers.
+TTS paper collection: https://github.com/erogol/TTS-papers
[](https://sourcerer.io/fame/erogol/erogol/TTS/links/0)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/1)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/2)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/3)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/4)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/5)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/6)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/7)
## TTS Performance
-

+
+"Mozilla*" and "Judy*" are our models.
[Details...](https://github.com/mozilla/TTS/wiki/Mean-Opinion-Score-Results)
## Provided Models and Methods
diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py
index 1c6ef94d..ca089d3e 100755
--- a/TTS/bin/compute_statistics.py
+++ b/TTS/bin/compute_statistics.py
@@ -11,6 +11,7 @@ from TTS.tts.datasets.preprocess import load_meta_data
from TTS.utils.io import load_config
from TTS.utils.audio import AudioProcessor
+
def main():
"""Run preprocessing process."""
parser = argparse.ArgumentParser(
diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_gan_vocoder.py
similarity index 98%
rename from TTS/bin/train_vocoder.py
rename to TTS/bin/train_gan_vocoder.py
index b51a55a3..12edf048 100644
--- a/TTS/bin/train_vocoder.py
+++ b/TTS/bin/train_gan_vocoder.py
@@ -326,7 +326,6 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
y_hat = model_G.pqmf_synthesis(y_hat)
y_G_sub = model_G.pqmf_analysis(y_G)
-
scores_fake, feats_fake, feats_real = None, None, None
if global_step > c.steps_to_start_discriminator:
@@ -403,7 +402,6 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
else:
loss_dict[key] = value.item()
-
step_time = time.time() - start_time
epoch_time += step_time
@@ -443,7 +441,8 @@ def main(args): # pylint: disable=redefined-outer-name
print(f" > Loading wavs from: {c.data_path}")
if c.feature_path is not None:
print(f" > Loading features from: {c.feature_path}")
- eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
+ eval_data, train_data = load_wav_feat_data(
+ c.data_path, c.feature_path, c.eval_split_size)
else:
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
@@ -470,10 +469,12 @@ def main(args): # pylint: disable=redefined-outer-name
scheduler_disc = None
if 'lr_scheduler_gen' in c:
scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
- scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
+ scheduler_gen = scheduler_gen(
+ optimizer_gen, **c.lr_scheduler_gen_params)
if 'lr_scheduler_disc' in c:
scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
- scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)
+ scheduler_disc = scheduler_disc(
+ optimizer_disc, **c.lr_scheduler_disc_params)
# setup criterion
criterion_gen = GeneratorLoss(c)
@@ -572,8 +573,7 @@ if __name__ == '__main__':
parser.add_argument(
'--continue_path',
type=str,
- help=
- 'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
+        help='Training output folder to continue a previous training. If it is used, "config_path" is ignored.',
default='',
required='--config_path' not in sys.argv)
parser.add_argument(
diff --git a/TTS/bin/train_wavernn_vocoder.py b/TTS/bin/train_wavernn_vocoder.py
new file mode 100644
index 00000000..61664a65
--- /dev/null
+++ b/TTS/bin/train_wavernn_vocoder.py
@@ -0,0 +1,515 @@
+import argparse
+import os
+import sys
+import traceback
+import time
+import glob
+import random
+
+import torch
+from torch.utils.data import DataLoader
+
+# from torch.utils.data.distributed import DistributedSampler
+
+from TTS.tts.utils.visual import plot_spectrogram
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.radam import RAdam
+from TTS.utils.io import copy_config_file, load_config
+from TTS.utils.training import setup_torch_training_env
+from TTS.utils.console_logger import ConsoleLogger
+from TTS.utils.tensorboard_logger import TensorboardLogger
+from TTS.utils.generic_utils import (
+ KeepAverage,
+ count_parameters,
+ create_experiment_folder,
+ get_git_branch,
+ remove_experiment_folder,
+ set_init_dict,
+)
+from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
+from TTS.vocoder.datasets.preprocess import (
+ load_wav_data,
+ load_wav_feat_data
+)
+from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
+from TTS.vocoder.utils.generic_utils import setup_wavernn
+from TTS.vocoder.utils.io import save_best_model, save_checkpoint
+
+
+use_cuda, num_gpus = setup_torch_training_env(True, True)
+
+
+def setup_loader(ap, is_val=False, verbose=False):
+ if is_val and not c.run_eval:
+ loader = None
+ else:
+ dataset = WaveRNNDataset(ap=ap,
+ items=eval_data if is_val else train_data,
+ seq_len=c.seq_len,
+ hop_len=ap.hop_length,
+ pad=c.padding,
+ mode=c.mode,
+ mulaw=c.mulaw,
+ is_training=not is_val,
+ verbose=verbose,
+ )
+ # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+ loader = DataLoader(dataset,
+ shuffle=True,
+ collate_fn=dataset.collate,
+ batch_size=c.batch_size,
+ num_workers=c.num_val_loader_workers
+ if is_val
+ else c.num_loader_workers,
+ pin_memory=True,
+ )
+ return loader
+
+
+def format_data(data):
+    # unpack the batch produced by WaveRNNDataset.collate
+ x_input = data[0]
+ mels = data[1]
+ y_coarse = data[2]
+
+ # dispatch data to GPU
+ if use_cuda:
+ x_input = x_input.cuda(non_blocking=True)
+ mels = mels.cuda(non_blocking=True)
+ y_coarse = y_coarse.cuda(non_blocking=True)
+
+ return x_input, mels, y_coarse
+
+
+def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
+ # create train loader
+ data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
+ model.train()
+ epoch_time = 0
+ keep_avg = KeepAverage()
+ if use_cuda:
+ batch_n_iter = int(len(data_loader.dataset) /
+ (c.batch_size * num_gpus))
+ else:
+ batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
+ end_time = time.time()
+ c_logger.print_train_start()
+ # train loop
+ print(" > Training", flush=True)
+ for num_iter, data in enumerate(data_loader):
+ start_time = time.time()
+ x_input, mels, y_coarse = format_data(data)
+ loader_time = time.time() - end_time
+ global_step += 1
+
+ y_hat = model(x_input, mels)
+
+ if isinstance(model.mode, int):
+ y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+ else:
+ y_coarse = y_coarse.float()
+ y_coarse = y_coarse.unsqueeze(-1)
+
+ # compute losses
+ loss = criterion(y_hat, y_coarse)
+        if torch.isnan(loss):
+            raise RuntimeError(" [!] NaN loss. Exiting ...")
+ optimizer.zero_grad()
+ loss.backward()
+ if c.grad_clip > 0:
+ torch.nn.utils.clip_grad_norm_(
+ model.parameters(), c.grad_clip)
+ optimizer.step()
+
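+        # the scheduler is stepped every iteration, so scheduler milestones
+        # in the config are counted in global steps, not epochs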
+ if scheduler is not None:
+ scheduler.step()
+
+ # get the current learning rate
+ cur_lr = list(optimizer.param_groups)[0]["lr"]
+
+ step_time = time.time() - start_time
+ epoch_time += step_time
+
+ update_train_values = dict()
+ loss_dict = dict()
+ loss_dict["model_loss"] = loss.item()
+ for key, value in loss_dict.items():
+ update_train_values["avg_" + key] = value
+ update_train_values["avg_loader_time"] = loader_time
+ update_train_values["avg_step_time"] = step_time
+ keep_avg.update_values(update_train_values)
+
+ # print training stats
+ if global_step % c.print_step == 0:
+ log_dict = {"step_time": [step_time, 2],
+ "loader_time": [loader_time, 4],
+ "current_lr": cur_lr,
+ }
+ c_logger.print_train_step(batch_n_iter,
+ num_iter,
+ global_step,
+ log_dict,
+ loss_dict,
+ keep_avg.avg_values,
+ )
+
+ # plot step stats
+ if global_step % 10 == 0:
+ iter_stats = {"lr": cur_lr, "step_time": step_time}
+ iter_stats.update(loss_dict)
+ tb_logger.tb_train_iter_stats(global_step, iter_stats)
+
+ # save checkpoint
+ if global_step % c.save_step == 0:
+ if c.checkpoint:
+ # save model
+ save_checkpoint(model,
+ optimizer,
+ scheduler,
+ None,
+ None,
+ None,
+ global_step,
+ epoch,
+ OUT_PATH,
+ model_losses=loss_dict,
+ )
+
+ # synthesize a full voice
+ rand_idx = random.randrange(0, len(train_data))
+ wav_path = train_data[rand_idx] if not isinstance(
+ train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
+ wav = ap.load_wav(wav_path)
+ ground_mel = ap.melspectrogram(wav)
+ sample_wav = model.generate(ground_mel,
+ c.batched,
+ c.target_samples,
+ c.overlap_samples,
+ use_cuda
+ )
+ predict_mel = ap.melspectrogram(sample_wav)
+
+ # compute spectrograms
+ figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
+ "train/prediction": plot_spectrogram(predict_mel.T)
+ }
+ tb_logger.tb_train_figures(global_step, figures)
+
+ # Sample audio
+ tb_logger.tb_train_audios(
+ global_step, {
+ "train/audio": sample_wav}, c.audio["sample_rate"]
+ )
+ end_time = time.time()
+
+ # print epoch stats
+ c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
+
+ # Plot Training Epoch Stats
+ epoch_stats = {"epoch_time": epoch_time}
+ epoch_stats.update(keep_avg.avg_values)
+ tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
+ # TODO: plot model stats
+ # if c.tb_model_param_stats:
+ # tb_logger.tb_model_weights(model, global_step)
+ return keep_avg.avg_values, global_step
+
+
+@torch.no_grad()
+def evaluate(model, criterion, ap, global_step, epoch):
+    # create eval loader
+ data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
+ model.eval()
+ epoch_time = 0
+ keep_avg = KeepAverage()
+ end_time = time.time()
+ c_logger.print_eval_start()
+ with torch.no_grad():
+ for num_iter, data in enumerate(data_loader):
+ start_time = time.time()
+ # format data
+ x_input, mels, y_coarse = format_data(data)
+ loader_time = time.time() - end_time
+ global_step += 1
+
+ y_hat = model(x_input, mels)
+ if isinstance(model.mode, int):
+ y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+ else:
+ y_coarse = y_coarse.float()
+ y_coarse = y_coarse.unsqueeze(-1)
+ loss = criterion(y_hat, y_coarse)
+ # Compute avg loss
+ # if num_gpus > 1:
+ # loss = reduce_tensor(loss.data, num_gpus)
+ loss_dict = dict()
+ loss_dict["model_loss"] = loss.item()
+
+ step_time = time.time() - start_time
+ epoch_time += step_time
+
+ # update avg stats
+ update_eval_values = dict()
+ for key, value in loss_dict.items():
+ update_eval_values["avg_" + key] = value
+ update_eval_values["avg_loader_time"] = loader_time
+ update_eval_values["avg_step_time"] = step_time
+ keep_avg.update_values(update_eval_values)
+
+ # print eval stats
+ if c.print_eval:
+ c_logger.print_eval_step(
+ num_iter, loss_dict, keep_avg.avg_values)
+
+ if epoch % c.test_every_epochs == 0 and epoch != 0:
+ # synthesize a full voice
+ rand_idx = random.randrange(0, len(eval_data))
+ wav_path = eval_data[rand_idx] if not isinstance(
+ eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
+ wav = ap.load_wav(wav_path)
+ ground_mel = ap.melspectrogram(wav)
+ sample_wav = model.generate(ground_mel,
+ c.batched,
+ c.target_samples,
+ c.overlap_samples,
+ use_cuda
+ )
+ predict_mel = ap.melspectrogram(sample_wav)
+
+ # Sample audio
+ tb_logger.tb_eval_audios(
+ global_step, {
+ "eval/audio": sample_wav}, c.audio["sample_rate"]
+ )
+
+ # compute spectrograms
+ figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
+ "eval/prediction": plot_spectrogram(predict_mel.T)
+ }
+ tb_logger.tb_eval_figures(global_step, figures)
+
+ tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
+ return keep_avg.avg_values
+
+
+# FIXME: move args definition/parsing inside of main?
+def main(args): # pylint: disable=redefined-outer-name
+ # pylint: disable=global-variable-undefined
+ global train_data, eval_data
+
+ # setup audio processor
+ ap = AudioProcessor(**c.audio)
+
+ # print(f" > Loading wavs from: {c.data_path}")
+ # if c.feature_path is not None:
+ # print(f" > Loading features from: {c.feature_path}")
+ # eval_data, train_data = load_wav_feat_data(
+ # c.data_path, c.feature_path, c.eval_split_size
+ # )
+ # else:
+ # mel_feat_path = os.path.join(OUT_PATH, "mel")
+ # feat_data = find_feat_files(mel_feat_path)
+ # if feat_data:
+ # print(f" > Loading features from: {mel_feat_path}")
+ # eval_data, train_data = load_wav_feat_data(
+ # c.data_path, mel_feat_path, c.eval_split_size
+ # )
+ # else:
+ # print(" > No feature data found. Preprocessing...")
+ # # preprocessing feature data from given wav files
+ # preprocess_wav_files(OUT_PATH, CONFIG, ap)
+ # eval_data, train_data = load_wav_feat_data(
+ # c.data_path, mel_feat_path, c.eval_split_size
+ # )
+
+ print(f" > Loading wavs from: {c.data_path}")
+ if c.feature_path is not None:
+ print(f" > Loading features from: {c.feature_path}")
+ eval_data, train_data = load_wav_feat_data(
+ c.data_path, c.feature_path, c.eval_split_size)
+ else:
+ eval_data, train_data = load_wav_data(
+ c.data_path, c.eval_split_size)
+ # setup model
+ model_wavernn = setup_wavernn(c)
+
+ # define train functions
+ if c.mode == "mold":
+ criterion = discretized_mix_logistic_loss
+ elif c.mode == "gauss":
+ criterion = gaussian_loss
+ elif isinstance(c.mode, int):
+ criterion = torch.nn.CrossEntropyLoss()
+
+ if use_cuda:
+ model_wavernn.cuda()
+ if isinstance(c.mode, int):
+ criterion.cuda()
+
+ optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)
+
+ scheduler = None
+ if "lr_scheduler" in c:
+ scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
+ scheduler = scheduler(optimizer, **c.lr_scheduler_params)
+ # slow start for the first 5 epochs
+ # lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
+ # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+ # restore any checkpoint
+ if args.restore_path:
+ checkpoint = torch.load(args.restore_path, map_location="cpu")
+ try:
+ print(" > Restoring Model...")
+ model_wavernn.load_state_dict(checkpoint["model"])
+ print(" > Restoring Optimizer...")
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ if "scheduler" in checkpoint:
+ print(" > Restoring Generator LR Scheduler...")
+ scheduler.load_state_dict(checkpoint["scheduler"])
+ scheduler.optimizer = optimizer
+ # TODO: fix resetting restored optimizer lr
+ # optimizer.load_state_dict(checkpoint["optimizer"])
+ except RuntimeError:
+            # restore only matching layers.
+ print(" > Partial model initialization...")
+ model_dict = model_wavernn.state_dict()
+ model_dict = set_init_dict(model_dict, checkpoint["model"], c)
+ model_wavernn.load_state_dict(model_dict)
+
+ print(" > Model restored from step %d" %
+ checkpoint["step"], flush=True)
+ args.restore_step = checkpoint["step"]
+ else:
+ args.restore_step = 0
+
+ # DISTRIBUTED
+ # if num_gpus > 1:
+ # model = apply_gradient_allreduce(model)
+
+ num_parameters = count_parameters(model_wavernn)
+ print(" > Model has {} parameters".format(num_parameters), flush=True)
+
+ if "best_loss" not in locals():
+ best_loss = float("inf")
+
+ global_step = args.restore_step
+ for epoch in range(0, c.epochs):
+ c_logger.print_epoch_start(epoch, c.epochs)
+ _, global_step = train(model_wavernn, optimizer,
+ criterion, scheduler, ap, global_step, epoch)
+ eval_avg_loss_dict = evaluate(
+ model_wavernn, criterion, ap, global_step, epoch)
+ c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
+ target_loss = eval_avg_loss_dict["avg_model_loss"]
+ best_loss = save_best_model(
+ target_loss,
+ best_loss,
+ model_wavernn,
+ optimizer,
+ scheduler,
+ None,
+ None,
+ None,
+ global_step,
+ epoch,
+ OUT_PATH,
+ model_losses=eval_avg_loss_dict,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--continue_path",
+ type=str,
+        help='Training output folder to continue a previous training. If it is used, "config_path" is ignored.',
+ default="",
+ required="--config_path" not in sys.argv,
+ )
+ parser.add_argument(
+ "--restore_path",
+ type=str,
+ help="Model file to be restored. Use to finetune a model.",
+ default="",
+ )
+ parser.add_argument(
+ "--config_path",
+ type=str,
+ help="Path to config file for training.",
+ required="--continue_path" not in sys.argv,
+ )
+ parser.add_argument(
+ "--debug",
+ type=bool,
+ default=False,
+ help="Do not verify commit integrity to run training.",
+ )
+
+    # DISTRIBUTED
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=0,
+ help="DISTRIBUTED: process rank for distributed training.",
+ )
+ parser.add_argument(
+ "--group_id", type=str, default="", help="DISTRIBUTED: process group id."
+ )
+ args = parser.parse_args()
+
+ if args.continue_path != "":
+ args.output_path = args.continue_path
+ args.config_path = os.path.join(args.continue_path, "config.json")
+        list_of_files = glob.glob(
+            args.continue_path + "/*.pth.tar"
+        )  # match all checkpoint files
+ latest_model_file = max(list_of_files, key=os.path.getctime)
+ args.restore_path = latest_model_file
+ print(f" > Training continues for {args.restore_path}")
+
+ # setup output paths and read configs
+ c = load_config(args.config_path)
+ # check_config(c)
+ _ = os.path.dirname(os.path.realpath(__file__))
+
+ OUT_PATH = args.continue_path
+ if args.continue_path == "":
+ OUT_PATH = create_experiment_folder(
+ c.output_path, c.run_name, args.debug
+ )
+
+ AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")
+
+ c_logger = ConsoleLogger()
+
+ if args.rank == 0:
+ os.makedirs(AUDIO_PATH, exist_ok=True)
+ new_fields = {}
+ if args.restore_path:
+ new_fields["restore_path"] = args.restore_path
+ new_fields["github_branch"] = get_git_branch()
+        copy_config_file(
+            args.config_path, os.path.join(OUT_PATH, "config.json"), new_fields
+        )
+ os.chmod(AUDIO_PATH, 0o775)
+ os.chmod(OUT_PATH, 0o775)
+
+ LOG_DIR = OUT_PATH
+ tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
+
+ # write model desc to tensorboard
+ tb_logger.tb_add_text("model-description", c["run_description"], 0)
+
+ try:
+ main(args)
+ except KeyboardInterrupt:
+ remove_experiment_folder(OUT_PATH)
+ try:
+ sys.exit(0)
+ except SystemExit:
+ os._exit(0) # pylint: disable=protected-access
+ except Exception: # pylint: disable=broad-except
+ remove_experiment_folder(OUT_PATH)
+ traceback.print_exc()
+ sys.exit(1)
diff --git a/TTS/vocoder/configs/wavernn_config.json b/TTS/vocoder/configs/wavernn_config.json
new file mode 100644
index 00000000..9a9fbdae
--- /dev/null
+++ b/TTS/vocoder/configs/wavernn_config.json
@@ -0,0 +1,97 @@
+{
+ "run_name": "wavernn_test",
+ "run_description": "wavernn_test training",
+
+// AUDIO PARAMETERS
+ "audio": {
+ "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
+ "win_length": 1024, // stft window length in ms.
+ "hop_length": 256, // stft window hop-lengh in ms.
+ "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
+ "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
+ // Audio processing parameters
+ "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
+ "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
+ "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
+ // Silence trimming
+ "do_trim_silence": false, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
+ "trim_db": 60, // threshold for timming silence. Set this according to your dataset.
+ // MelSpectrogram parameters
+ "num_mels": 80, // size of the mel spec frame.
+ "mel_fmin": 40.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
+ "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
+ "spec_gain": 20.0, // scaler value appplied after log transform of spectrogram.
+ // Normalization parameters
+ "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
+ "min_level_db": -100, // lower bound for normalization
+ "symmetric_norm": true, // move normalization to range [-1, 1]
+ "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+ "clip_norm": true, // clip normalized values into the range.
+ "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
+ },
+
+// Generating / Synthesizing
+ "batched": true,
+ "target_samples": 11000, // target number of samples to be generated in each batch entry
+ "overlap_samples": 550, // number of samples for crossfading between batches
+ // DISTRIBUTED TRAINING
+ // "distributed":{
+ // "backend": "nccl",
+ // "url": "tcp:\/\/localhost:54321"
+ // },
+
+// MODEL MODE
+ "mode": 10, // mold [string], gauss [string], bits [int]
+ "mulaw": true, // apply mulaw if mode is bits
+
+// MODEL PARAMETERS
+ "wavernn_model_params": {
+ "rnn_dims": 512,
+ "fc_dims": 512,
+ "compute_dims": 128,
+ "res_out_dims": 128,
+ "num_res_blocks": 10,
+ "use_aux_net": true,
+ "use_upsample_net": true,
+ "upsample_factors": [4, 8, 8] // this needs to correctly factorise hop_length
+ },
+
+// DATASET
+ //"use_gta": true, // use computed gta features from the tts model
+ "data_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech", // path containing training wav files
+ "feature_path": null, // path containing computed features from wav files if null compute them
+ "seq_len": 1280, // has to be devideable by hop_length
+ "padding": 2, // pad the input for resnet to see wider input length
+
+// TRAINING
+ "batch_size": 64, // Batch size for training.
+ "epochs": 10000, // total number of epochs to train.
+
+// VALIDATION
+ "run_eval": true,
+ "test_every_epochs": 10, // Test after set number of epochs (Test every 10 epochs for example)
+
+// OPTIMIZER
+ "grad_clip": 4, // apply gradient clipping if > 0
+ "lr_scheduler": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
+ "lr_scheduler_params": {
+ "gamma": 0.5,
+ "milestones": [200000, 400000, 600000]
+ },
+ "lr": 1e-4, // initial learning rate
+
+// TENSORBOARD and LOGGING
+ "print_step": 25, // Number of steps to log traning on console.
+ "print_eval": false, // If True, it prints loss values for each step in eval run.
+ "save_step": 25000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
+ "checkpoint": true, // If true, it saves checkpoints per "save_step"
+ "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
+
+// DATA LOADING
+ "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
+ "num_val_loader_workers": 4, // number of evaluation data loader processes.
+ "eval_split_size": 50, // number of samples for testing
+
+// PATHS
+ "output_path": "output/training/path"
+}
diff --git a/TTS/vocoder/datasets/preprocess.py b/TTS/vocoder/datasets/preprocess.py
index be60c13a..afea45fd 100644
--- a/TTS/vocoder/datasets/preprocess.py
+++ b/TTS/vocoder/datasets/preprocess.py
@@ -1,17 +1,38 @@
import glob
import os
from pathlib import Path
+from tqdm import tqdm
import numpy as np
+def preprocess_wav_files(out_path, config, ap):
+ os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
+ os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
+ wav_files = find_wav_files(config.data_path)
+ for path in tqdm(wav_files):
+ wav_name = Path(path).stem
+ quant_path = os.path.join(out_path, "quant", wav_name + ".npy")
+ mel_path = os.path.join(out_path, "mel", wav_name + ".npy")
+ y = ap.load_wav(path)
+ mel = ap.melspectrogram(y)
+ np.save(mel_path, mel)
+ if isinstance(config.mode, int):
+ quant = (
+ ap.mulaw_encode(y, qc=config.mode)
+ if config.mulaw
+ else ap.quantize(y, bits=config.mode)
+ )
+ np.save(quant_path, quant)
+
+
def find_wav_files(data_path):
- wav_paths = glob.glob(os.path.join(data_path, '**', '*.wav'), recursive=True)
+ wav_paths = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
return wav_paths
def find_feat_files(data_path):
- feat_paths = glob.glob(os.path.join(data_path, '**', '*.npy'), recursive=True)
+ feat_paths = glob.glob(os.path.join(data_path, "**", "*.npy"), recursive=True)
return feat_paths
@@ -23,8 +44,12 @@ def load_wav_data(data_path, eval_split_size):
def load_wav_feat_data(data_path, feat_path, eval_split_size):
- wav_paths = sorted(find_wav_files(data_path))
- feat_paths = sorted(find_feat_files(feat_path))
+ wav_paths = find_wav_files(data_path)
+ feat_paths = find_feat_files(feat_path)
+
+ wav_paths.sort(key=lambda x: Path(x).stem)
+ feat_paths.sort(key=lambda x: Path(x).stem)
+
assert len(wav_paths) == len(feat_paths)
for wav, feat in zip(wav_paths, feat_paths):
wav_name = Path(wav).stem
diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py
new file mode 100644
index 00000000..9c1ded96
--- /dev/null
+++ b/TTS/vocoder/datasets/wavernn_dataset.py
@@ -0,0 +1,115 @@
+import torch
+import numpy as np
+from torch.utils.data import Dataset
+
+
+class WaveRNNDataset(Dataset):
+ """
+    WaveRNN Dataset searches for all the wav files under root path
+ and converts them to acoustic features on the fly.
+ """
+
+ def __init__(self,
+ ap,
+ items,
+ seq_len,
+ hop_len,
+ pad,
+ mode,
+ mulaw,
+ is_training=True,
+ verbose=False,
+ ):
+
+ self.ap = ap
+ self.compute_feat = not isinstance(items[0], (tuple, list))
+ self.item_list = items
+ self.seq_len = seq_len
+ self.hop_len = hop_len
+ self.pad = pad
+ self.mode = mode
+ self.mulaw = mulaw
+ self.is_training = is_training
+ self.verbose = verbose
+
+ def __len__(self):
+ return len(self.item_list)
+
+ def __getitem__(self, index):
+ item = self.load_item(index)
+ return item
+
+ def load_item(self, index):
+ """
+ load (audio, feat) couple if feature_path is set
+ else compute it on the fly
+ """
+ if self.compute_feat:
+
+ wavpath = self.item_list[index]
+ audio = self.ap.load_wav(wavpath)
+ mel = self.ap.melspectrogram(audio)
+
+ if mel.shape[-1] < 5:
+ print(" [!] Instance is too short! : {}".format(wavpath))
+                # replace the item and reload audio from the new path
+                self.item_list[index] = self.item_list[index + 1]
+                wavpath = self.item_list[index]
+                audio = self.ap.load_wav(wavpath)
+                mel = self.ap.melspectrogram(audio)
+ if self.mode in ["gauss", "mold"]:
+ x_input = audio
+ elif isinstance(self.mode, int):
+ x_input = (self.ap.mulaw_encode(audio, qc=self.mode)
+ if self.mulaw else self.ap.quantize(audio, bits=self.mode))
+ else:
+ raise RuntimeError("Unknown dataset mode - ", self.mode)
+
+ else:
+
+ wavpath, feat_path = self.item_list[index]
+ mel = np.load(feat_path.replace("/quant/", "/mel/"))
+
+ if mel.shape[-1] < 5:
+ print(" [!] Instance is too short! : {}".format(wavpath))
+                # items are (wav_path, feat_path) tuples, so unpack the new entry
+                self.item_list[index] = self.item_list[index + 1]
+                wavpath, feat_path = self.item_list[index]
+                mel = np.load(feat_path.replace("/quant/", "/mel/"))
+ if self.mode in ["gauss", "mold"]:
+ x_input = self.ap.load_wav(wavpath)
+ elif isinstance(self.mode, int):
+ x_input = np.load(feat_path.replace("/mel/", "/quant/"))
+ else:
+ raise RuntimeError("Unknown dataset mode - ", self.mode)
+
+ return mel, x_input
+
+ def collate(self, batch):
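+        # a batch entry needs seq_len // hop_len conditioning frames plus
+        # 2 * pad extra frames of context for the MelResNet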
+ mel_win = self.seq_len // self.hop_len + 2 * self.pad
+ max_offsets = [x[0].shape[-1] -
+ (mel_win + 2 * self.pad) for x in batch]
+ mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
+ sig_offsets = [(offset + self.pad) *
+ self.hop_len for offset in mel_offsets]
+
+ mels = [
+ x[0][:, mel_offsets[i]: mel_offsets[i] + mel_win]
+ for i, x in enumerate(batch)
+ ]
+
+ coarse = [
+ x[1][sig_offsets[i]: sig_offsets[i] + self.seq_len + 1]
+ for i, x in enumerate(batch)
+ ]
+
+ mels = np.stack(mels).astype(np.float32)
+ if self.mode in ["gauss", "mold"]:
+ coarse = np.stack(coarse).astype(np.float32)
+ coarse = torch.FloatTensor(coarse)
+ x_input = coarse[:, : self.seq_len]
+ elif isinstance(self.mode, int):
+ coarse = np.stack(coarse).astype(np.int64)
+ coarse = torch.LongTensor(coarse)
+ x_input = (2 * coarse[:, : self.seq_len].float() /
+ (2 ** self.mode - 1.0) - 1.0)
+ y_coarse = coarse[:, 1:]
+ mels = torch.FloatTensor(mels)
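+        # x_input: (B, seq_len) network input, y_coarse: (B, seq_len) targets
+        # shifted one sample ahead, mels: (B, num_mels, mel_win) conditioning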
+ return x_input, mels, y_coarse
diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py
new file mode 100644
index 00000000..f771175c
--- /dev/null
+++ b/TTS/vocoder/models/wavernn.py
@@ -0,0 +1,496 @@
+import sys
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+import time
+
+# FIXME: only mulaw_decode is needed from AudioProcessor; avoid importing the class as "ap"
+from TTS.utils.audio import AudioProcessor as ap
+from TTS.vocoder.utils.distribution import (
+ sample_from_gaussian,
+ sample_from_discretized_mix_logistic,
+)
+
+
+def stream(string, variables):
+ sys.stdout.write(f"\r{string}" % variables)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, dims):
+ super().__init__()
+ self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
+ self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
+ self.batch_norm1 = nn.BatchNorm1d(dims)
+ self.batch_norm2 = nn.BatchNorm1d(dims)
+
+ def forward(self, x):
+ residual = x
+ x = self.conv1(x)
+ x = self.batch_norm1(x)
+ x = F.relu(x)
+ x = self.conv2(x)
+ x = self.batch_norm2(x)
+ return x + residual
+
+
+class MelResNet(nn.Module):
+ def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad):
+ super().__init__()
+ k_size = pad * 2 + 1
+ self.conv_in = nn.Conv1d(
+ in_dims, compute_dims, kernel_size=k_size, bias=False)
+ self.batch_norm = nn.BatchNorm1d(compute_dims)
+ self.layers = nn.ModuleList()
+ for _ in range(num_res_blocks):
+ self.layers.append(ResBlock(compute_dims))
+ self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ x = self.batch_norm(x)
+ x = F.relu(x)
+ for f in self.layers:
+ x = f(x)
+ x = self.conv_out(x)
+ return x
+
+
+class Stretch2d(nn.Module):
+ def __init__(self, x_scale, y_scale):
+ super().__init__()
+ self.x_scale = x_scale
+ self.y_scale = y_scale
+
+ def forward(self, x):
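+        # nearest-neighbour upsampling: repeat each element y_scale times in
+        # height and x_scale times in width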
+ b, c, h, w = x.size()
+ x = x.unsqueeze(-1).unsqueeze(3)
+ x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
+ return x.view(b, c, h * self.y_scale, w * self.x_scale)
+
+
+class UpsampleNetwork(nn.Module):
+ def __init__(
+ self,
+ feat_dims,
+ upsample_scales,
+ compute_dims,
+ num_res_blocks,
+ res_out_dims,
+ pad,
+ use_aux_net,
+ ):
+ super().__init__()
+ self.total_scale = np.cumproduct(upsample_scales)[-1]
+ self.indent = pad * self.total_scale
+ self.use_aux_net = use_aux_net
+ if use_aux_net:
+ self.resnet = MelResNet(
+ num_res_blocks, feat_dims, compute_dims, res_out_dims, pad
+ )
+ self.resnet_stretch = Stretch2d(self.total_scale, 1)
+ self.up_layers = nn.ModuleList()
+ for scale in upsample_scales:
+ k_size = (1, scale * 2 + 1)
+ padding = (0, scale)
+ stretch = Stretch2d(scale, 1)
+ conv = nn.Conv2d(1, 1, kernel_size=k_size,
+ padding=padding, bias=False)
+ conv.weight.data.fill_(1.0 / k_size[1])
+ self.up_layers.append(stretch)
+ self.up_layers.append(conv)
+
+ def forward(self, m):
+ if self.use_aux_net:
+ aux = self.resnet(m).unsqueeze(1)
+ aux = self.resnet_stretch(aux)
+ aux = aux.squeeze(1)
+ aux = aux.transpose(1, 2)
+ else:
+ aux = None
+ m = m.unsqueeze(1)
+ for f in self.up_layers:
+ m = f(m)
+ m = m.squeeze(1)[:, :, self.indent: -self.indent]
+ return m.transpose(1, 2), aux
+
+
+class Upsample(nn.Module):
+ def __init__(
+ self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net
+ ):
+ super().__init__()
+ self.scale = scale
+ self.pad = pad
+ self.indent = pad * scale
+ self.use_aux_net = use_aux_net
+ self.resnet = MelResNet(num_res_blocks, feat_dims,
+ compute_dims, res_out_dims, pad)
+
+ def forward(self, m):
+ if self.use_aux_net:
+ aux = self.resnet(m)
+ aux = torch.nn.functional.interpolate(
+ aux, scale_factor=self.scale, mode="linear", align_corners=True
+ )
+ aux = aux.transpose(1, 2)
+ else:
+ aux = None
+ m = torch.nn.functional.interpolate(
+ m, scale_factor=self.scale, mode="linear", align_corners=True
+ )
+ m = m[:, :, self.indent: -self.indent]
+ m = m * 0.045 # empirically found
+
+ return m.transpose(1, 2), aux
+
+
+class WaveRNN(nn.Module):
+ def __init__(self,
+ rnn_dims,
+ fc_dims,
+ mode,
+ mulaw,
+ pad,
+ use_aux_net,
+ use_upsample_net,
+ upsample_factors,
+ feat_dims,
+ compute_dims,
+ res_out_dims,
+ num_res_blocks,
+ hop_length,
+ sample_rate,
+ ):
+ super().__init__()
+ self.mode = mode
+ self.mulaw = mulaw
+ self.pad = pad
+ self.use_upsample_net = use_upsample_net
+ self.use_aux_net = use_aux_net
+ if isinstance(self.mode, int):
+ self.n_classes = 2 ** self.mode
+ elif self.mode == "mold":
+ self.n_classes = 3 * 10
+ elif self.mode == "gauss":
+ self.n_classes = 2
+ else:
+ raise RuntimeError("Unknown model mode value - ", self.mode)
+
+ self.rnn_dims = rnn_dims
+ self.aux_dims = res_out_dims // 4
+ self.hop_length = hop_length
+ self.sample_rate = sample_rate
+
+ if self.use_upsample_net:
+ assert (
+ np.cumproduct(upsample_factors)[-1] == self.hop_length
+ ), " [!] upsample scales needs to be equal to hop_length"
+ self.upsample = UpsampleNetwork(
+ feat_dims,
+ upsample_factors,
+ compute_dims,
+ num_res_blocks,
+ res_out_dims,
+ pad,
+ use_aux_net,
+ )
+ else:
+ self.upsample = Upsample(
+ hop_length,
+ pad,
+ num_res_blocks,
+ feat_dims,
+ compute_dims,
+ res_out_dims,
+ use_aux_net,
+ )
+ if self.use_aux_net:
+ self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
+ self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
+ self.rnn2 = nn.GRU(rnn_dims + self.aux_dims,
+ rnn_dims, batch_first=True)
+ self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
+ self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
+ self.fc3 = nn.Linear(fc_dims, self.n_classes)
+ else:
+ self.I = nn.Linear(feat_dims + 1, rnn_dims)
+ self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
+ self.rnn2 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
+ self.fc1 = nn.Linear(rnn_dims, fc_dims)
+ self.fc2 = nn.Linear(fc_dims, fc_dims)
+ self.fc3 = nn.Linear(fc_dims, self.n_classes)
+
+ def forward(self, x, mels):
+ bsize = x.size(0)
+ h1 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
+ h2 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
+ mels, aux = self.upsample(mels)
+
+ if self.use_aux_net:
+ aux_idx = [self.aux_dims * i for i in range(5)]
+ a1 = aux[:, :, aux_idx[0]: aux_idx[1]]
+ a2 = aux[:, :, aux_idx[1]: aux_idx[2]]
+ a3 = aux[:, :, aux_idx[2]: aux_idx[3]]
+ a4 = aux[:, :, aux_idx[3]: aux_idx[4]]
+
+ x = (
+ torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
+ if self.use_aux_net
+ else torch.cat([x.unsqueeze(-1), mels], dim=2)
+ )
+ x = self.I(x)
+ res = x
+ self.rnn1.flatten_parameters()
+ x, _ = self.rnn1(x, h1)
+
+ x = x + res
+ res = x
+ x = torch.cat([x, a2], dim=2) if self.use_aux_net else x
+ self.rnn2.flatten_parameters()
+ x, _ = self.rnn2(x, h2)
+
+ x = x + res
+ x = torch.cat([x, a3], dim=2) if self.use_aux_net else x
+ x = F.relu(self.fc1(x))
+
+ x = torch.cat([x, a4], dim=2) if self.use_aux_net else x
+ x = F.relu(self.fc2(x))
+ return self.fc3(x)
+
+ def generate(self, mels, batched, target, overlap, use_cuda=False):
+
+ self.eval()
+ device = 'cuda' if use_cuda else 'cpu'
+ output = []
+ start = time.time()
+ rnn1 = self.get_gru_cell(self.rnn1)
+ rnn2 = self.get_gru_cell(self.rnn2)
+
+ with torch.no_grad():
+ mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
+ wave_len = (mels.size(-1) - 1) * self.hop_length
+ mels = self.pad_tensor(mels.transpose(
+ 1, 2), pad=self.pad, side="both")
+ mels, aux = self.upsample(mels.transpose(1, 2))
+
+ if batched:
+ mels = self.fold_with_overlap(mels, target, overlap)
+ if aux is not None:
+ aux = self.fold_with_overlap(aux, target, overlap)
+
+ b_size, seq_len, _ = mels.size()
+
+ h1 = torch.zeros(b_size, self.rnn_dims).to(device)
+ h2 = torch.zeros(b_size, self.rnn_dims).to(device)
+ x = torch.zeros(b_size, 1).to(device)
+
+ if self.use_aux_net:
+ d = self.aux_dims
+ aux_split = [aux[:, :, d * i: d * (i + 1)] for i in range(4)]
+
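+            # autoregressive loop: each sampled output is fed back in as the
+            # next input x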
+ for i in range(seq_len):
+
+ m_t = mels[:, i, :]
+
+ if self.use_aux_net:
+ a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
+
+ x = (
+ torch.cat([x, m_t, a1_t], dim=1)
+ if self.use_aux_net
+ else torch.cat([x, m_t], dim=1)
+ )
+ x = self.I(x)
+ h1 = rnn1(x, h1)
+
+ x = x + h1
+ inp = torch.cat([x, a2_t], dim=1) if self.use_aux_net else x
+ h2 = rnn2(inp, h2)
+
+ x = x + h2
+ x = torch.cat([x, a3_t], dim=1) if self.use_aux_net else x
+ x = F.relu(self.fc1(x))
+
+ x = torch.cat([x, a4_t], dim=1) if self.use_aux_net else x
+ x = F.relu(self.fc2(x))
+
+ logits = self.fc3(x)
+
+ if self.mode == "mold":
+ sample = sample_from_discretized_mix_logistic(
+ logits.unsqueeze(0).transpose(1, 2)
+ )
+ output.append(sample.view(-1))
+ x = sample.transpose(0, 1).to(device)
+ elif self.mode == "gauss":
+ sample = sample_from_gaussian(
+ logits.unsqueeze(0).transpose(1, 2))
+ output.append(sample.view(-1))
+ x = sample.transpose(0, 1).to(device)
+ elif isinstance(self.mode, int):
+ posterior = F.softmax(logits, dim=1)
+ distrib = torch.distributions.Categorical(posterior)
+
+ sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0
+ output.append(sample)
+ x = sample.unsqueeze(-1)
+ else:
+ raise RuntimeError(
+ "Unknown model mode value - ", self.mode)
+
+ if i % 100 == 0:
+ self.gen_display(i, seq_len, b_size, start)
+
+ output = torch.stack(output).transpose(0, 1)
+ output = output.cpu().numpy()
+ output = output.astype(np.float64)
+
+ if batched:
+ output = self.xfade_and_unfold(output, target, overlap)
+ else:
+ output = output[0]
+
+ if self.mulaw and isinstance(self.mode, int):
+ output = ap.mulaw_decode(output, self.mode)
+
+ # Fade-out at the end to avoid signal cutting out suddenly
+ fade_out = np.linspace(1, 0, 20 * self.hop_length)
+ output = output[:wave_len]
+
+ if wave_len > len(fade_out):
+ output[-20 * self.hop_length:] *= fade_out
+
+ self.train()
+ return output
+
+ def gen_display(self, i, seq_len, b_size, start):
+ gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
+ realtime_ratio = gen_rate * 1000 / self.sample_rate
+ stream(
+ "%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ",
+ (i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
+ )
+
+ def fold_with_overlap(self, x, target, overlap):
+ """Fold the tensor with overlap for quick batched inference.
+ Overlap will be used for crossfading in xfade_and_unfold()
+ Args:
+ x (tensor) : Upsampled conditioning features.
+ shape=(1, timesteps, features)
+ target (int) : Target timesteps for each index of batch
+ overlap (int) : Timesteps for both xfade and rnn warmup
+ Return:
+ (tensor) : shape=(num_folds, target + 2 * overlap, features)
+ Details:
+ x = [[h1, h2, ... hn]]
+ Where each h is a vector of conditioning features
+ Eg: target=2, overlap=1 with x.size(1)=10
+ folded = [[h1, h2, h3, h4],
+ [h4, h5, h6, h7],
+ [h7, h8, h9, h10]]
+ """
+
+ _, total_len, features = x.size()
+
+ # Calculate variables needed
+ num_folds = (total_len - overlap) // (target + overlap)
+ extended_len = num_folds * (overlap + target) + overlap
+ remaining = total_len - extended_len
+
+ # Pad if some time steps poking out
+ if remaining != 0:
+ num_folds += 1
+ padding = target + 2 * overlap - remaining
+ x = self.pad_tensor(x, padding, side="after")
+
+ folded = torch.zeros(num_folds, target + 2 *
+ overlap, features).to(x.device)
+
+ # Get the values for the folded tensor
+ for i in range(num_folds):
+ start = i * (target + overlap)
+ end = start + target + 2 * overlap
+ folded[i] = x[:, start:end, :]
+
+ return folded
+
+ @staticmethod
+ def get_gru_cell(gru):
+ gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
+ gru_cell.weight_hh.data = gru.weight_hh_l0.data
+ gru_cell.weight_ih.data = gru.weight_ih_l0.data
+ gru_cell.bias_hh.data = gru.bias_hh_l0.data
+ gru_cell.bias_ih.data = gru.bias_ih_l0.data
+ return gru_cell
+
+ @staticmethod
+ def pad_tensor(x, pad, side="both"):
+        # NB - this is just a quick method I need right now
+ # i.e., it won't generalise to other shapes/dims
+ b, t, c = x.size()
+ total = t + 2 * pad if side == "both" else t + pad
+ padded = torch.zeros(b, total, c).to(x.device)
+ if side in ("before", "both"):
+ padded[:, pad: pad + t, :] = x
+ elif side == "after":
+ padded[:, :t, :] = x
+ return padded
+
+ @staticmethod
+ def xfade_and_unfold(y, target, overlap):
+ """Applies a crossfade and unfolds into a 1d array.
+ Args:
+            y (ndarray) : Batched sequences of audio samples
+ shape=(num_folds, target + 2 * overlap)
+ dtype=np.float64
+ overlap (int) : Timesteps for both xfade and rnn warmup
+ Return:
+            (ndarray) : audio samples in a 1d array
+ shape=(total_len)
+ dtype=np.float64
+ Details:
+ y = [[seq1],
+ [seq2],
+ [seq3]]
+ Apply a gain envelope at both ends of the sequences
+ y = [[seq1_in, seq1_target, seq1_out],
+ [seq2_in, seq2_target, seq2_out],
+ [seq3_in, seq3_target, seq3_out]]
+ Stagger and add up the groups of samples:
+ [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
+ """
+
+ num_folds, length = y.shape
+ target = length - 2 * overlap
+ total_len = num_folds * (target + overlap) + overlap
+
+ # Need some silence for the rnn warmup
+ silence_len = overlap // 2
+ fade_len = overlap - silence_len
+ silence = np.zeros((silence_len), dtype=np.float64)
+
+ # Equal power crossfade
+ t = np.linspace(-1, 1, fade_len, dtype=np.float64)
+ fade_in = np.sqrt(0.5 * (1 + t))
+ fade_out = np.sqrt(0.5 * (1 - t))
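+        # fade_in ** 2 + fade_out ** 2 == 1, so summed power stays constant
+        # over the fade region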
+
+ # Concat the silence to the fades
+ fade_in = np.concatenate([silence, fade_in])
+ fade_out = np.concatenate([fade_out, silence])
+
+ # Apply the gain to the overlap samples
+ y[:, :overlap] *= fade_in
+ y[:, -overlap:] *= fade_out
+
+ unfolded = np.zeros((total_len), dtype=np.float64)
+
+ # Loop to add up all the samples
+ for i in range(num_folds):
+ start = i * (target + overlap)
+ end = start + target + 2 * overlap
+ unfolded[start:end] += y[i]
+
+ return unfolded
diff --git a/TTS/vocoder/utils/distribution.py b/TTS/vocoder/utils/distribution.py
new file mode 100644
index 00000000..6aba5e34
--- /dev/null
+++ b/TTS/vocoder/utils/distribution.py
@@ -0,0 +1,168 @@
+import numpy as np
+import math
+import torch
+from torch.distributions.normal import Normal
+import torch.nn.functional as F
+
+
+def gaussian_loss(y_hat, y, log_std_min=-7.0):
+ assert y_hat.dim() == 3
+ assert y_hat.size(2) == 2
+ mean = y_hat[:, :, :1]
+ log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
+ # TODO: replace with pytorch dist
+ log_probs = -0.5 * (
+ -math.log(2.0 * math.pi)
+ - 2.0 * log_std
+ - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std))
+ )
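+    # the leading minus sign makes log_probs the negative log-likelihood,
+    # so its mean is returned directly as the loss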
+ return log_probs.squeeze().mean()
+
+
+def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0):
+ assert y_hat.size(2) == 2
+ mean = y_hat[:, :, :1]
+ log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
+ dist = Normal(
+ mean,
+ torch.exp(log_std),
+ )
+ sample = dist.sample()
+ sample = torch.clamp(torch.clamp(
+ sample, min=-scale_factor), max=scale_factor)
+ del dist
+ return sample
+
+
+def log_sum_exp(x):
+ """ numerically stable log_sum_exp implementation that prevents overflow """
+ # TF ordering
+ axis = len(x.size()) - 1
+ m, _ = torch.max(x, dim=axis)
+ m2, _ = torch.max(x, dim=axis, keepdim=True)
+ return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
+
+
+# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
+def discretized_mix_logistic_loss(
+ y_hat, y, num_classes=65536, log_scale_min=None, reduce=True
+):
+ if log_scale_min is None:
+ log_scale_min = float(np.log(1e-14))
+ y_hat = y_hat.permute(0, 2, 1)
+ assert y_hat.dim() == 3
+ assert y_hat.size(1) % 3 == 0
+ nr_mix = y_hat.size(1) // 3
+
+ # (B x T x C)
+ y_hat = y_hat.transpose(1, 2)
+
+ # unpack parameters. (B, T, num_mixtures) x 3
+ logit_probs = y_hat[:, :, :nr_mix]
+ means = y_hat[:, :, nr_mix: 2 * nr_mix]
+ log_scales = torch.clamp(
+ y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=log_scale_min)
+
+ # B x T x 1 -> B x T x num_mixtures
+ y = y.expand_as(means)
+
+ centered_y = y - means
+ inv_stdv = torch.exp(-log_scales)
+ plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1))
+ cdf_plus = torch.sigmoid(plus_in)
+ min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1))
+ cdf_min = torch.sigmoid(min_in)
+
+ # log probability for edge case of 0 (before scaling)
+ # equivalent: torch.log(F.sigmoid(plus_in))
+ log_cdf_plus = plus_in - F.softplus(plus_in)
+
+ # log probability for edge case of 255 (before scaling)
+ # equivalent: (1 - F.sigmoid(min_in)).log()
+ log_one_minus_cdf_min = -F.softplus(min_in)
+
+ # probability for all other cases
+ cdf_delta = cdf_plus - cdf_min
+
+ mid_in = inv_stdv * centered_y
+ # log probability in the center of the bin, to be used in extreme cases
+ # (not actually used in our code)
+ log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)
+
+ # tf equivalent
+
+ # log_probs = tf.where(x < -0.999, log_cdf_plus,
+ # tf.where(x > 0.999, log_one_minus_cdf_min,
+ # tf.where(cdf_delta > 1e-5,
+ # tf.log(tf.maximum(cdf_delta, 1e-12)),
+ # log_pdf_mid - np.log(127.5))))
+
+ # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
+ # for num_classes=65536 case? 1e-7? not sure..
+ inner_inner_cond = (cdf_delta > 1e-5).float()
+
+ inner_inner_out = inner_inner_cond * torch.log(
+ torch.clamp(cdf_delta, min=1e-12)
+ ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
+ inner_cond = (y > 0.999).float()
+ inner_out = (
+ inner_cond * log_one_minus_cdf_min +
+ (1.0 - inner_cond) * inner_inner_out
+ )
+ cond = (y < -0.999).float()
+ log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
+
+ log_probs = log_probs + F.log_softmax(logit_probs, -1)
+
+ if reduce:
+ return -torch.mean(log_sum_exp(log_probs))
+ return -log_sum_exp(log_probs).unsqueeze(-1)
+
+
+def sample_from_discretized_mix_logistic(y, log_scale_min=None):
+ """
+ Sample from discretized mixture of logistic distributions
+ Args:
+ y (Tensor): B x C x T
+ log_scale_min (float): Log scale minimum value
+ Returns:
+ Tensor: sample in range of [-1, 1].
+ """
+ if log_scale_min is None:
+ log_scale_min = float(np.log(1e-14))
+ assert y.size(1) % 3 == 0
+ nr_mix = y.size(1) // 3
+
+ # B x T x C
+ y = y.transpose(1, 2)
+ logit_probs = y[:, :, :nr_mix]
+
+    # sample mixture indicator from the softmax (Gumbel-max trick)
+ temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
+ temp = logit_probs.data - torch.log(-torch.log(temp))
+ _, argmax = temp.max(dim=-1)
+
+ # (B, T) -> (B, T, nr_mix)
+ one_hot = to_one_hot(argmax, nr_mix)
+ # select logistic parameters
+ means = torch.sum(y[:, :, nr_mix: 2 * nr_mix] * one_hot, dim=-1)
+ log_scales = torch.clamp(
+ torch.sum(y[:, :, 2 * nr_mix: 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min
+ )
+ # sample from logistic & clip to interval
+ # we don't actually round to the nearest 8bit value when sampling
+ u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
+ x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))
+
+ x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0)
+
+ return x
+
+
+def to_one_hot(tensor, n, fill_with=1.0):
+    # we perform one-hot encoding with respect to the last axis
+ one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
+ if tensor.is_cuda:
+ one_hot = one_hot.cuda()
+ one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
+ return one_hot
diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py
index 89dc68fb..f9fbba52 100644
--- a/TTS/vocoder/utils/generic_utils.py
+++ b/TTS/vocoder/utils/generic_utils.py
@@ -42,6 +42,29 @@ def to_camel(text):
return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
+def setup_wavernn(c):
+ print(" > Model: WaveRNN")
+ MyModel = importlib.import_module("TTS.vocoder.models.wavernn")
+ MyModel = getattr(MyModel, "WaveRNN")
+ model = MyModel(
+ rnn_dims=c.wavernn_model_params['rnn_dims'],
+ fc_dims=c.wavernn_model_params['fc_dims'],
+ mode=c.mode,
+ mulaw=c.mulaw,
+ pad=c.padding,
+ use_aux_net=c.wavernn_model_params['use_aux_net'],
+ use_upsample_net=c.wavernn_model_params['use_upsample_net'],
+ upsample_factors=c.wavernn_model_params['upsample_factors'],
+ feat_dims=c.audio['num_mels'],
+ compute_dims=c.wavernn_model_params['compute_dims'],
+ res_out_dims=c.wavernn_model_params['res_out_dims'],
+ num_res_blocks=c.wavernn_model_params['num_res_blocks'],
+ hop_length=c.audio["hop_length"],
+ sample_rate=c.audio["sample_rate"],
+ )
+ return model
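+# minimal usage sketch (assuming `c` is a loaded WaveRNN vocoder config):
+#   c = load_config("TTS/vocoder/configs/wavernn_config.json")
+#   model = setup_wavernn(c)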
+
+
def setup_generator(c):
print(" > Generator Model: {}".format(c.generator_model))
MyModel = importlib.import_module('TTS.vocoder.models.' +
diff --git a/tests/inputs/test_vocoder_wavernn_config.json b/tests/inputs/test_vocoder_wavernn_config.json
new file mode 100644
index 00000000..28c0f059
--- /dev/null
+++ b/tests/inputs/test_vocoder_wavernn_config.json
@@ -0,0 +1,94 @@
+{
+ "run_name": "wavernn_test",
+ "run_description": "wavernn_test training",
+
+ // AUDIO PARAMETERS
+ "audio":{
+ "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
+ "win_length": 1024, // stft window length in ms.
+ "hop_length": 256, // stft window hop-lengh in ms.
+ "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
+ "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
+
+ // Audio processing parameters
+ "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
+ "preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
+ "ref_level_db": 0, // reference level db, theoretically 20db is the sound of air.
+
+ // Silence trimming
+ "do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
+ "trim_db": 60, // threshold for timming silence. Set this according to your dataset.
+
+ // MelSpectrogram parameters
+ "num_mels": 80, // size of the mel spec frame.
+ "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
+ "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
+ "spec_gain": 20.0, // scaler value appplied after log transform of spectrogram.
+
+ // Normalization parameters
+ "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
+ "min_level_db": -100, // lower bound for normalization
+ "symmetric_norm": true, // move normalization to range [-1, 1]
+ "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+ "clip_norm": true, // clip normalized values into the range.
+ "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
+ },
+
+ // Generating / Synthesizing
+ "batched": true,
+ "target_samples": 11000, // target number of samples to be generated in each batch entry
+ "overlap_samples": 550, // number of samples for crossfading between batches
+
+ // DISTRIBUTED TRAINING
+ // "distributed":{
+ // "backend": "nccl",
+ // "url": "tcp:\/\/localhost:54321"
+ // },
+
+ // MODEL PARAMETERS
+ "use_aux_net": true,
+ "use_upsample_net": true,
+ "upsample_factors": [4, 8, 8], // this needs to correctly factorise hop_length
+ "seq_len": 1280, // has to be devideable by hop_length
+ "mode": "mold", // mold [string], gauss [string], bits [int]
+ "mulaw": false, // apply mulaw if mode is bits
+ "padding": 2, // pad the input for resnet to see wider input length
+
+ // DATASET
+ //"use_gta": true, // use computed gta features from the tts model
+ "data_path": "tests/data/ljspeech/wavs/", // path containing training wav files
+ "feature_path": null, // path containing computed features from wav files if null compute them
+
+ // TRAINING
+ "batch_size": 4, // Batch size for training. Lower values than 32 might cause hard to learn attention.
+ "epochs": 1, // total number of epochs to train.
+
+ // VALIDATION
+ "run_eval": true,
+ "test_every_epochs": 10, // Test after set number of epochs (Test every 20 epochs for example)
+
+ // OPTIMIZER
+ "grad_clip": 4, // apply gradient clipping if > 0
+ "lr_scheduler": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
+ "lr_scheduler_params": {
+ "gamma": 0.5,
+ "milestones": [200000, 400000, 600000]
+ },
+ "lr": 1e-4, // initial learning rate
+
+ // TENSORBOARD and LOGGING
+ "print_step": 25, // Number of steps to log traning on console.
+ "print_eval": false, // If True, it prints loss values for each step in eval run.
+ "save_step": 25000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
+ "checkpoint": true, // If true, it saves checkpoints per "save_step"
+ "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
+
+ // DATA LOADING
+ "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
+ "num_val_loader_workers": 4, // number of evaluation data loader processes.
+ "eval_split_size": 10, // number of samples for testing
+
+ // PATHS
+ "output_path": "tests/train_outputs/"
+}
+
diff --git a/tests/test_vocoder_datasets.py b/tests/test_vocoder_gan_datasets.py
similarity index 100%
rename from tests/test_vocoder_datasets.py
rename to tests/test_vocoder_gan_datasets.py
diff --git a/tests/test_vocoder_train.sh b/tests/test_vocoder_gan_train.sh
similarity index 57%
rename from tests/test_vocoder_train.sh
rename to tests/test_vocoder_gan_train.sh
index fa99b4bd..75773cc3 100755
--- a/tests/test_vocoder_train.sh
+++ b/tests/test_vocoder_gan_train.sh
@@ -5,11 +5,11 @@ echo "$BASEDIR"
# create run dir
mkdir $BASEDIR/train_outputs
# run training
-CUDA_VISIBLE_DEVICES="" python TTS/bin/train_vocoder.py --config_path $BASEDIR/inputs/test_vocoder_multiband_melgan_config.json
+CUDA_VISIBLE_DEVICES="" python TTS/bin/train_gan_vocoder.py --config_path $BASEDIR/inputs/test_vocoder_multiband_melgan_config.json
# find the training folder
LATEST_FOLDER=$(ls $BASEDIR/train_outputs/| sort | tail -1)
echo $LATEST_FOLDER
# continue the previous training
-CUDA_VISIBLE_DEVICES="" python TTS/bin/train_vocoder.py --continue_path $BASEDIR/train_outputs/$LATEST_FOLDER
+CUDA_VISIBLE_DEVICES="" python TTS/bin/train_gan_vocoder.py --continue_path $BASEDIR/train_outputs/$LATEST_FOLDER
# remove all the outputs
rm -rf $BASEDIR/train_outputs/$LATEST_FOLDER
diff --git a/tests/test_vocoder_wavernn.py b/tests/test_vocoder_wavernn.py
new file mode 100644
index 00000000..ccd71c56
--- /dev/null
+++ b/tests/test_vocoder_wavernn.py
@@ -0,0 +1,31 @@
+import numpy as np
+import torch
+import random
+from TTS.vocoder.models.wavernn import WaveRNN
+
+
+def test_wavernn():
+ model = WaveRNN(
+ rnn_dims=512,
+ fc_dims=512,
+ mode=10,
+ mulaw=False,
+ pad=2,
+ use_aux_net=True,
+ use_upsample_net=True,
+ upsample_factors=[4, 8, 8],
+ feat_dims=80,
+ compute_dims=128,
+ res_out_dims=128,
+ num_res_blocks=10,
+ hop_length=256,
+ sample_rate=22050,
+ )
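+    # mode=10 gives 2 ** 10 = 1024 output classes per timestep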
+ dummy_x = torch.rand((2, 1280))
+ dummy_m = torch.rand((2, 80, 9))
+ y_size = random.randrange(20, 60)
+ dummy_y = torch.rand((80, y_size))
+ output = model(dummy_x, dummy_m)
+    assert np.all(output.shape == (2, 1280, 2 ** 10)), output.shape
+ output = model.generate(dummy_y, True, 5500, 550, False)
+ assert np.all(output.shape == (256 * (y_size - 1),))
diff --git a/tests/test_vocoder_wavernn_datasets.py b/tests/test_vocoder_wavernn_datasets.py
new file mode 100644
index 00000000..a95e247a
--- /dev/null
+++ b/tests/test_vocoder_wavernn_datasets.py
@@ -0,0 +1,92 @@
+import os
+import shutil
+
+import numpy as np
+from tests import get_tests_path, get_tests_input_path, get_tests_output_path
+from torch.utils.data import DataLoader
+
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_config
+from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
+from TTS.vocoder.datasets.preprocess import load_wav_feat_data, preprocess_wav_files
+
+file_path = os.path.dirname(os.path.realpath(__file__))
+OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
+os.makedirs(OUTPATH, exist_ok=True)
+
+C = load_config(os.path.join(get_tests_input_path(),
+ "test_vocoder_wavernn_config.json"))
+
+test_data_path = os.path.join(get_tests_path(), "data/ljspeech/")
+test_mel_feat_path = os.path.join(test_data_path, "mel")
+test_quant_feat_path = os.path.join(test_data_path, "quant")
+ok_ljspeech = os.path.exists(test_data_path)
+
+
+def wavernn_dataset_case(batch_size, seq_len, hop_len, pad, mode, mulaw, num_workers):
+ """ run dataloader with given parameters and check conditions """
+ ap = AudioProcessor(**C.audio)
+
+ C.batch_size = batch_size
+ C.mode = mode
+ C.seq_len = seq_len
+ C.data_path = test_data_path
+
+ preprocess_wav_files(test_data_path, C, ap)
+ _, train_items = load_wav_feat_data(
+ test_data_path, test_mel_feat_path, 5)
+
+ dataset = WaveRNNDataset(ap=ap,
+ items=train_items,
+ seq_len=seq_len,
+ hop_len=hop_len,
+ pad=pad,
+ mode=mode,
+ mulaw=mulaw
+ )
+ # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+ loader = DataLoader(dataset,
+ shuffle=True,
+ collate_fn=dataset.collate,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_memory=True,
+ )
+
+ max_iter = 10
+ count_iter = 0
+
+ try:
+ for data in loader:
+ x_input, mels, _ = data
+ expected_feat_shape = (ap.num_mels,
+ (x_input.shape[-1] // hop_len) + (pad * 2))
+ assert np.all(
+ mels.shape[1:] == expected_feat_shape), f" [!] {mels.shape} vs {expected_feat_shape}"
+
+ assert (mels.shape[2] - pad * 2) * hop_len == x_input.shape[1]
+ count_iter += 1
+ if count_iter == max_iter:
+ break
+ finally:
+ shutil.rmtree(test_mel_feat_path)
+ shutil.rmtree(test_quant_feat_path)
+
+
+def test_parametrized_wavernn_dataset():
+ ''' test dataloader with different parameters '''
+ params = [
+ [16, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, 10, True, 0],
+ [16, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, "mold", False, 4],
+ [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, 9, False, 0],
+ [1, C.audio['hop_length'], C.audio['hop_length'], 2, 10, True, 0],
+ [1, C.audio['hop_length'], C.audio['hop_length'], 2, "mold", False, 0],
+ [1, C.audio['hop_length'] * 5, C.audio['hop_length'], 4, 10, False, 2],
+ [1, C.audio['hop_length'] * 5, C.audio['hop_length'], 2, "mold", False, 0],
+ ]
+ for param in params:
+ print(param)
+ wavernn_dataset_case(*param)
diff --git a/tests/test_vocoder_wavernn_train.sh b/tests/test_vocoder_wavernn_train.sh
new file mode 100755
index 00000000..f2e32116
--- /dev/null
+++ b/tests/test_vocoder_wavernn_train.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+
+BASEDIR=$(dirname "$0")
+echo "$BASEDIR"
+# create run dir
+mkdir $BASEDIR/train_outputs
+# run training
+CUDA_VISIBLE_DEVICES="" python TTS/bin/train_wavernn_vocoder.py --config_path $BASEDIR/inputs/test_vocoder_wavernn_config.json
+# find the training folder
+LATEST_FOLDER=$(ls $BASEDIR/train_outputs/| sort | tail -1)
+echo $LATEST_FOLDER
+# continue the previous training
+CUDA_VISIBLE_DEVICES="" python TTS/bin/train_wavernn_vocoder.py --continue_path $BASEDIR/train_outputs/$LATEST_FOLDER
+# remove all the outputs
+rm -rf $BASEDIR/train_outputs/$LATEST_FOLDER
\ No newline at end of file