diff --git a/TTS/bin/train_wavegrad.py b/TTS/bin/train_wavegrad.py
new file mode 100644
index 00000000..469df638
--- /dev/null
+++ b/TTS/bin/train_wavegrad.py
@@ -0,0 +1,490 @@
+import argparse
+import glob
+import os
+import sys
+import time
+import traceback
+from inspect import signature
+
+import torch
+from torch.utils.data import DataLoader
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.console_logger import ConsoleLogger
+from TTS.utils.generic_utils import (KeepAverage, count_parameters,
+                                     create_experiment_folder, get_git_branch,
+                                     remove_experiment_folder, set_init_dict)
+from TTS.utils.io import copy_config_file, load_config
+from TTS.utils.radam import RAdam
+from TTS.utils.tensorboard_logger import TensorboardLogger
+from TTS.utils.training import setup_torch_training_env
+from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
+from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
+from TTS.utils.distribute import init_distributed, reduce_tensor
+from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
+from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator,
+                                             setup_generator)
+from TTS.vocoder.utils.io import save_best_model, save_checkpoint
+
+use_cuda, num_gpus = setup_torch_training_env(True, True)
+
+
+def setup_loader(ap, is_val=False, verbose=False):
+    if is_val and not c.run_eval:
+        loader = None
+    else:
+        dataset = WaveGradDataset(ap=ap,
+                                  items=eval_data if is_val else train_data,
+                                  seq_len=c.seq_len,
+                                  hop_len=ap.hop_length,
+                                  pad_short=c.pad_short,
+                                  conv_pad=c.conv_pad,
+                                  is_training=not is_val,
+                                  return_segments=True,
+                                  use_noise_augment=False,
+                                  use_cache=c.use_cache,
+                                  verbose=verbose)
+        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+        loader = DataLoader(dataset,
+                            batch_size=c.batch_size,
+                            shuffle=False if num_gpus > 1 else True,
+                            drop_last=False,
+                            sampler=sampler,
+                            num_workers=c.num_val_loader_workers
+                            if is_val else c.num_loader_workers,
+                            pin_memory=False)
+    return loader
+
+
+def format_data(data):
+    # return a whole audio segment
+    m, y = data
+    if use_cuda:
+        m = m.cuda(non_blocking=True)
+        y = y.cuda(non_blocking=True)
+    return m, y
+
+
+def train(model, criterion, optimizer,
+          scheduler, ap, global_step, epoch, amp):
+    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
+    model.train()
+    epoch_time = 0
+    keep_avg = KeepAverage()
+    if use_cuda:
+        batch_n_iter = int(
+            len(data_loader.dataset) / (c.batch_size * num_gpus))
+    else:
+        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
+    end_time = time.time()
+    c_logger.print_train_start()
+    for num_iter, data in enumerate(data_loader):
+        start_time = time.time()
+
+        # format data
+        m, y = format_data(data)
+        loader_time = time.time() - end_time
+
+        global_step += 1
+
+        # compute noisy input
+        if hasattr(model, 'module'):
+            y_noisy, noise_scale = model.module.compute_noisy_x(y)
+        else:
+            y_noisy, noise_scale = model.compute_noisy_x(y)
+
+        # forward pass
+        y_hat = model(y_noisy, m, noise_scale)
+
+        # compute losses
+        loss = criterion(y_noisy, y_hat)
+        loss_wavegrad_dict = {'wavegrad_loss': loss}
+
+        # backward pass with loss scaling
+        optimizer.zero_grad()
+
+        if amp is not None:
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
+
+        if amp:
+            amp_opt_params = amp.master_params(optimizer)
+        else:
+            amp_opt_params = None
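+        # NOTE: under apex AMP the fp32 "master" copies of the parameters live in
+        # amp.master_params(optimizer); gradient clipping is often applied to those
+        # rather than to model.parameters() when mixed precision is enabled. They are
+        # collected here, although the clipping call below still uses model.parameters().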
+
+        if c.clip_grad > 0:
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
+                                                       c.clip_grad)
+        optimizer.step()
+
+        # schedule update
+        if scheduler is not None:
+            scheduler.step()
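+        # NOTE: the scheduler is stepped once per iteration, not per epoch, so the
+        # MultiStepLR milestones in the config are interpreted as global step counts.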
+
+        # disconnect loss values
+        loss_dict = dict()
+        for key, value in loss_wavegrad_dict.items():
+            if isinstance(value, int):
+                loss_dict[key] = value
+            else:
+                loss_dict[key] = value.item()
+
+        # epoch/step timing
+        step_time = time.time() - start_time
+        epoch_time += step_time
+
+        # get current learning rates
+        current_lr = list(optimizer.param_groups)[0]['lr']
+
+        # update avg stats
+        update_train_values = dict()
+        for key, value in loss_dict.items():
+            update_train_values['avg_' + key] = value
+        update_train_values['avg_loader_time'] = loader_time
+        update_train_values['avg_step_time'] = step_time
+        keep_avg.update_values(update_train_values)
+
+        # print training stats
+        if global_step % c.print_step == 0:
+            log_dict = {
+                'step_time': [step_time, 2],
+                'loader_time': [loader_time, 4],
+                "current_lr": current_lr,
+                "grad_norm": grad_norm
+            }
+            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
+                                      log_dict, loss_dict, keep_avg.avg_values)
+
+        if args.rank == 0:
+            # plot step stats
+            if global_step % 10 == 0:
+                iter_stats = {
+                    "lr": current_lr,
+                    "grad_norm": grad_norm,
+                    "step_time": step_time
+                }
+                iter_stats.update(loss_dict)
+                tb_logger.tb_train_iter_stats(global_step, iter_stats)
+
+            # save checkpoint
+            if global_step % c.save_step == 0:
+                if c.checkpoint:
+                    # save model
+                    save_checkpoint(model,
+                                    optimizer,
+                                    scheduler,
+                                    None,
+                                    None,
+                                    None,
+                                    global_step,
+                                    epoch,
+                                    OUT_PATH,
+                                    model_losses=loss_dict)
+
+                # compute spectrograms
+                figures = plot_results(y_hat[0], y[0], ap, global_step, 'train')
+                tb_logger.tb_train_figures(global_step, figures)
+
+                # Sample audio
+                sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
+                tb_logger.tb_train_audios(global_step,
+                                          {'train/audio': sample_voice},
+                                          c.audio["sample_rate"])
+        end_time = time.time()
+
+    # print epoch stats
+    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
+
+    # Plot Training Epoch Stats
+    epoch_stats = {"epoch_time": epoch_time}
+    epoch_stats.update(keep_avg.avg_values)
+    if args.rank == 0:
+        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
+    # TODO: plot model stats
+    # if c.tb_model_param_stats:
+    #     tb_logger.tb_model_weights(model, global_step)
+    return keep_avg.avg_values, global_step
+
+
+@torch.no_grad()
+def evaluate(model, criterion, ap, global_step, epoch):
+    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
+    model.eval()
+    epoch_time = 0
+    keep_avg = KeepAverage()
+    end_time = time.time()
+    c_logger.print_eval_start()
+    for num_iter, data in enumerate(data_loader):
+        start_time = time.time()
+
+        # format data
+        m, y = format_data(data)
+        loader_time = time.time() - end_time
+
+        global_step += 1
+
+        # compute noisy input
+        if hasattr(model, 'module'):
+            y_noisy, noise_scale = model.module.compute_noisy_x(y)
+        else:
+            y_noisy, noise_scale = model.compute_noisy_x(y)
+
+        # forward pass
+        y_hat = model(y_noisy, m, noise_scale)
+
+        # compute losses
+        loss = criterion(y_noisy, y_hat)
+        loss_wavegrad_dict = {'wavegrad_loss': loss}
+
+        loss_dict = dict()
+        for key, value in loss_wavegrad_dict.items():
+            if isinstance(value, (int, float)):
+                loss_dict[key] = value
+            else:
+                loss_dict[key] = value.item()
+
+        step_time = time.time() - start_time
+        epoch_time += step_time
+
+        # update avg stats
+        update_eval_values = dict()
+        for key, value in loss_dict.items():
+            update_eval_values['avg_' + key] = value
+        update_eval_values['avg_loader_time'] = loader_time
+        update_eval_values['avg_step_time'] = step_time
+        keep_avg.update_values(update_eval_values)
+
+        # print eval stats
+        if c.print_eval:
+            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
+
+    if args.rank == 0:
+        # compute spectrograms
+        figures = plot_results(y_hat, y, ap, global_step, 'eval')
+        tb_logger.tb_eval_figures(global_step, figures)
+
+        # Sample audio
+        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
+        tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
+                                 c.audio["sample_rate"])
+
+        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
+
+    return keep_avg.avg_values
+
+
+# FIXME: move args definition/parsing inside of main?
+def main(args):  # pylint: disable=redefined-outer-name
+    # pylint: disable=global-variable-undefined
+    global train_data, eval_data
+    print(f" > Loading wavs from: {c.data_path}")
+    if c.feature_path is not None:
+        print(f" > Loading features from: {c.feature_path}")
+        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
+    else:
+        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
+
+    # setup audio processor
+    ap = AudioProcessor(**c.audio)
+
+    # DISTRIBUTED
+    if num_gpus > 1:
+        init_distributed(args.rank, num_gpus, args.group_id,
+                         c.distributed["backend"], c.distributed["url"])
+
+    # setup models
+    model = setup_generator(c)
+
+    # setup optimizers
+    optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
+
+    # DISTRIBUTED
+    if c.apex_amp_level:
+        # pylint: disable=import-outside-toplevel
+        from apex import amp
+        from apex.parallel import DistributedDataParallel as DDP
+        model.cuda()
+        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
+    else:
+        amp = None
+
+    # schedulers
+    scheduler = None
+    if 'lr_scheduler' in c:
+        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
+        scheduler = scheduler(optimizer, **c.lr_scheduler_params)
+
+    # setup criterion
+    criterion = torch.nn.L1Loss().cuda()
+
+    if args.restore_path:
+        checkpoint = torch.load(args.restore_path, map_location='cpu')
+        try:
+            print(" > Restoring Model...")
+            model.load_state_dict(checkpoint['model'])
+            print(" > Restoring Optimizer...")
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            if 'scheduler' in checkpoint:
+                print(" > Restoring LR Scheduler...")
+                scheduler.load_state_dict(checkpoint['scheduler'])
+                # NOTE: Not sure if necessary
+                scheduler.optimizer = optimizer
+        except RuntimeError:
+            # restore only matching layers.
+            print(" > Partial model initialization...")
+            model_dict = model.state_dict()
+            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
+            model.load_state_dict(model_dict)
+            del model_dict
+
+        # DISTRIBUTED
+        if amp and 'amp' in checkpoint:
+            amp.load_state_dict(checkpoint['amp'])
+
+        # reset lr if not continuing training.
+        for group in optimizer.param_groups:
+            group['lr'] = c.lr
+
+        print(" > Model restored from step %d" % checkpoint['step'],
+              flush=True)
+        args.restore_step = checkpoint['step']
+    else:
+        args.restore_step = 0
+
+    if use_cuda:
+        model.cuda()
+        criterion.cuda()
+
+    # DISTRIBUTED
+    if num_gpus > 1:
+        model = DDP(model)
+
+    num_params = count_parameters(model)
+    print(" > WaveGrad has {} parameters".format(num_params), flush=True)
+
+    if 'best_loss' not in locals():
+        best_loss = float('inf')
+
+    global_step = args.restore_step
+    for epoch in range(0, c.epochs):
+        c_logger.print_epoch_start(epoch, c.epochs)
+        _, global_step = train(model, criterion, optimizer,
+                               scheduler, ap, global_step,
+                               epoch, amp)
+        eval_avg_loss_dict = evaluate(model, criterion, ap,
+                                      global_step, epoch)
+        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
+        target_loss = eval_avg_loss_dict[c.target_loss]
+        best_loss = save_best_model(target_loss,
+                                    best_loss,
+                                    model,
+                                    optimizer,
+                                    scheduler,
+                                    None,
+                                    None,
+                                    None,
+                                    global_step,
+                                    epoch,
+                                    OUT_PATH,
+                                    model_losses=eval_avg_loss_dict,
+                                    amp_state_dict=amp.state_dict() if amp else None)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--continue_path',
+        type=str,
+        help='Training output folder to continue training. Use it to continue a previous run. If it is set, "config_path" is ignored.',
+        default='',
+        required='--config_path' not in sys.argv)
+    parser.add_argument(
+        '--restore_path',
+        type=str,
+        help='Model file to be restored. Use it to finetune a model.',
+        default='')
+    parser.add_argument('--config_path',
+                        type=str,
+                        help='Path to config file for training.',
+                        required='--continue_path' not in sys.argv)
+    parser.add_argument('--debug',
+                        type=bool,
+                        default=False,
+                        help='Do not verify commit integrity to run training.')
+
+    # DISTRIBUTED
+    parser.add_argument(
+        '--rank',
+        type=int,
+        default=0,
+        help='DISTRIBUTED: process rank for distributed training.')
+    parser.add_argument('--group_id',
+                        type=str,
+                        default="",
+                        help='DISTRIBUTED: process group id.')
+    args = parser.parse_args()
+
+    if args.continue_path != '':
+        args.output_path = args.continue_path
+        args.config_path = os.path.join(args.continue_path, 'config.json')
+        list_of_files = glob.glob(
+            args.continue_path +
+            "/*.pth.tar")  # list all checkpoints in the output folder
+        latest_model_file = max(list_of_files, key=os.path.getctime)
+        args.restore_path = latest_model_file
+        print(f" > Training continues for {args.restore_path}")
+
+    # setup output paths and read configs
+    c = load_config(args.config_path)
+    # check_config(c)
+    _ = os.path.dirname(os.path.realpath(__file__))
+
+    # DISTRIBUTED
+    if c.apex_amp_level:
+        print(" > apex AMP level: ", c.apex_amp_level)
+
+    OUT_PATH = args.continue_path
+    if args.continue_path == '':
+        OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
+                                            args.debug)
+
+    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
+
+    c_logger = ConsoleLogger()
+
+    if args.rank == 0:
+        os.makedirs(AUDIO_PATH, exist_ok=True)
+        new_fields = {}
+        if args.restore_path:
+            new_fields["restore_path"] = args.restore_path
+        new_fields["github_branch"] = get_git_branch()
+        copy_config_file(args.config_path,
+                         os.path.join(OUT_PATH, 'config.json'), new_fields)
+        os.chmod(AUDIO_PATH, 0o775)
+        os.chmod(OUT_PATH, 0o775)
+
+    LOG_DIR = OUT_PATH
+    tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
+
+    # write model desc to tensorboard
+    tb_logger.tb_add_text('model-description', c['run_description'], 0)
+
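+    # Example single-GPU invocation (assuming the repo root is on PYTHONPATH):
+    #   python TTS/bin/train_wavegrad.py --config_path TTS/vocoder/configs/wavegrad_libritts.json
+    # For multi-GPU runs this script is expected to be spawned once per process by an
+    # external distributed launcher that fills in --rank and --group_id.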
+    try:
+        main(args)
+    except KeyboardInterrupt:
+        remove_experiment_folder(OUT_PATH)
+        try:
+            sys.exit(0)
+        except SystemExit:
+            os._exit(0)  # pylint: disable=protected-access
+    except Exception:  # pylint: disable=broad-except
+        remove_experiment_folder(OUT_PATH)
+        traceback.print_exc()
+        sys.exit(1)
diff --git a/TTS/vocoder/configs/wavegrad_libritts.json b/TTS/vocoder/configs/wavegrad_libritts.json
new file mode 100644
index 00000000..79672c71
--- /dev/null
+++ b/TTS/vocoder/configs/wavegrad_libritts.json
@@ -0,0 +1,103 @@
+{
+    "run_name": "wavegrad-libritts",
+    "run_description": "wavegrad libritts",
+
+    "audio":{
+        "fft_size": 1024,         // number of stft frequency levels. Size of the linear spectrogram frame.
+        "win_length": 1024,       // stft window length in samples.
+        "hop_length": 256,        // stft window hop length in samples.
+        "frame_length_ms": null,  // stft window length in ms. If null, 'win_length' is used.
+        "frame_shift_ms": null,   // stft window hop length in ms. If null, 'hop_length' is used.
+
+        // Audio processing parameters
+        "sample_rate": 24000,  // DATASET-RELATED: wav sample rate. If different from the original data, it is resampled.
+        "preemphasis": 0.0,    // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis is applied.
+        "ref_level_db": 0,     // reference level db, theoretically 20db is the sound of air.
+
+        // Silence trimming
+        "do_trim_silence": true,  // enable trimming of silence from audio as it is loaded. LJSpeech (false), TWEB (false), Nancy (true)
+        "trim_db": 60,            // threshold for trimming silence. Set this according to your dataset.
+
+        // MelSpectrogram parameters
+        "num_mels": 80,      // size of the mel spec frame.
+        "mel_fmin": 50.0,    // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for your dataset!!
+        "mel_fmax": 7600.0,  // maximum freq level for mel-spec. Tune for your dataset!!
+        "spec_gain": 1.0,    // scalar value applied after the log transform of the spectrogram.
+
+        // Normalization parameters
+        "signal_norm": true,     // normalize spec values. Mean-Var normalization if 'stats_path' is defined, otherwise range normalization defined by the other params.
+        "min_level_db": -100,    // lower bound for normalization
+        "symmetric_norm": true,  // move normalization to range [-1, 1]
+        "max_norm": 4.0,         // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+        "clip_norm": true,       // clip normalized values into the range.
+        "stats_path": "/home/erogol/Data/libritts/LibriTTS/scale_stats.npy"  // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and the other normalization params are ignored
+    },
+
+    // DISTRIBUTED TRAINING
+    "apex_amp_level": "O1",  // amp optimization level. Only "O1" is currently supported.
+    "distributed":{
+        "backend": "nccl",
+        "url": "tcp:\/\/localhost:54321"
+    },
+
+    "target_loss": "avg_wavegrad_loss",  // loss value used to pick the best model to save after each epoch
+
+    // MODEL PARAMETERS
+    "generator_model": "wavegrad",
+    "model_params":{
+        "x_conv_channels": 32,
+        "c_conv_channels": 768,
+        "ublock_out_channels": [768, 512, 512, 256, 128],
+        "dblock_out_channels": [128, 128, 256, 512],
+        "upsample_factors": [4, 4, 4, 2, 2],
+        "upsample_dilations": [
+            [1, 2, 1, 2],
+            [1, 2, 1, 2],
+            [1, 2, 4, 8],
+            [1, 2, 4, 8],
+            [1, 2, 4, 8]]
+    },
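+    // NOTE: the product of 'upsample_factors' (4 * 4 * 4 * 2 * 2 = 256) must match
+    // 'audio.hop_length', so that one mel frame is upsampled to exactly one hop of
+    // audio samples.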
+
+    // DATASET
+    "data_path": "/home/erogol/Data/libritts/LibriTTS/train-clean-360/",  // root data path. It finds all wav files recursively from there.
+    "feature_path": null,  // set if you use precomputed features
+    "seq_len": 6144,       // 24 * hop_length
+    "pad_short": 2000,     // additional padding for short wavs
+    "conv_pad": 0,         // additional padding against convolutions applied to spectrograms
+    "use_noise_augment": false,  // add noise to the audio signal for augmentation
+    "use_cache": true,           // use an in-memory cache to keep the computed features. This might cause OOM.
+
+    "reinit_layers": [],  // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
+
+    // TRAINING
+    "batch_size": 64,  // Batch size for training.
+
+    // VALIDATION
+    "run_eval": true,  // enable/disable the evaluation run
+
+    // OPTIMIZER
+    "epochs": 10000,  // total number of epochs to train.
+    "clip_grad": 1,   // Generator gradient clipping threshold. Apply gradient clipping if > 0
+    "lr_scheduler": "MultiStepLR",  // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
+    "lr_scheduler_params": {
+        "gamma": 0.5,
+        "milestones": [100000, 200000, 300000, 400000, 500000, 600000]
+    },
+    "lr": 1e-4,  // Initial learning rate. If Noam decay is active, maximum learning rate.
+
+    // TENSORBOARD and LOGGING
+    "print_step": 25,     // Number of steps between logging training stats on the console.
+    "print_eval": false,  // If true, it prints loss values for each step in the eval run.
+    "save_step": 10000,   // Number of training steps between plotting training stats on TB and saving model checkpoints.
+    "checkpoint": true,   // If true, it saves checkpoints per "save_step"
+    "tb_model_param_stats": false,  // If true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
+
+    // DATA LOADING
+    "num_loader_workers": 4,      // number of training data loader processes. Don't set it too big. 4-8 are good values.
+    "num_val_loader_workers": 4,  // number of evaluation data loader processes.
+    "eval_split_size": 10,
+
+    // PATHS
+    "output_path": "/home/erogol/Models/LJSpeech/"
+}
+
diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py
new file mode 100644
index 00000000..4a70c252
--- /dev/null
+++ b/TTS/vocoder/datasets/wavegrad_dataset.py
@@ -0,0 +1,113 @@
+import os
+import glob
+import torch
+import random
+import numpy as np
+from torch.utils.data import Dataset
+from multiprocessing import Manager
+
+
+class WaveGradDataset(Dataset):
+    """
+    WaveGrad Dataset searches for all the wav files under the root path,
+    converts them to acoustic features on the fly and returns
+    random segments of (audio, feature) pairs.
+    """
+    def __init__(self,
+                 ap,
+                 items,
+                 seq_len,
+                 hop_len,
+                 pad_short,
+                 conv_pad=2,
+                 is_training=True,
+                 return_segments=True,
+                 use_noise_augment=False,
+                 use_cache=False,
+                 verbose=False):
+
+        self.ap = ap
+        self.item_list = items
+        self.compute_feat = not isinstance(items[0], (tuple, list))
+        self.seq_len = seq_len
+        self.hop_len = hop_len
+        self.pad_short = pad_short
+        self.conv_pad = conv_pad
+        self.is_training = is_training
+        self.return_segments = return_segments
+        self.use_cache = use_cache
+        self.use_noise_augment = use_noise_augment
+        self.verbose = verbose
+
+        assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
+        self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
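+        # NOTE: number of mel frames corresponding to a 'seq_len'-sample segment,
+        # plus 2 * conv_pad extra frames to compensate for convolutional padding on
+        # the spectrogram. With the LibriTTS config (seq_len=6144, hop_len=256,
+        # conv_pad=0) this is 24 frames.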
+
+        # cache acoustic features
+        if use_cache:
+            self.create_feature_cache()
+
+    def create_feature_cache(self):
+        self.manager = Manager()
+        self.cache = self.manager.list()
+        self.cache += [None for _ in range(len(self.item_list))]
+
+    @staticmethod
+    def find_wav_files(path):
+        return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)
+
+    def __len__(self):
+        return len(self.item_list)
+
+    def __getitem__(self, idx):
+        item = self.load_item(idx)
+        return item
+
+    def load_item(self, idx):
+        """ load an (audio, feat) pair """
+        if self.compute_feat:
+            # compute features from wav
+            wavpath = self.item_list[idx]
+            # print(wavpath)
+
+            if self.use_cache and self.cache[idx] is not None:
+                audio, mel = self.cache[idx]
+            else:
+                audio = self.ap.load_wav(wavpath)
+
+                if len(audio) < self.seq_len + self.pad_short:
+                    audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)),
+                                   mode='constant', constant_values=0.0)
+
+                mel = self.ap.melspectrogram(audio)
+        else:
+            # load precomputed features
+            wavpath, feat_path = self.item_list[idx]
+
+            if self.use_cache and self.cache[idx] is not None:
+                audio, mel = self.cache[idx]
+            else:
+                audio = self.ap.load_wav(wavpath)
+                mel = np.load(feat_path)
+
+        # correct the audio length wrt padding applied in stft
+        audio = np.pad(audio, (0, self.hop_len), mode="edge")
+        audio = audio[:mel.shape[-1] * self.hop_len]
+        assert mel.shape[-1] * self.hop_len == audio.shape[-1], f' [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}'
+
+        audio = torch.from_numpy(audio).float().unsqueeze(0)
+        mel = torch.from_numpy(mel).float().squeeze(0)
+
+        if self.return_segments:
+            max_mel_start = mel.shape[1] - self.feat_frame_len
+            mel_start = random.randint(0, max_mel_start)
+            mel_end = mel_start + self.feat_frame_len
+            mel = mel[:, mel_start:mel_end]
+
+            audio_start = mel_start * self.hop_len
+            audio = audio[:, audio_start:audio_start +
+                          self.seq_len]
+
+        if self.use_noise_augment and self.is_training and self.return_segments:
+            audio = audio + (1 / 32768) * torch.randn_like(audio)
+        return (mel, audio)
diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py
new file mode 100644
index 00000000..69bca0a8
--- /dev/null
+++ b/TTS/vocoder/layers/wavegrad.py
@@ -0,0 +1,150 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class NoiseLevelEncoding(nn.Module):
+    """Noise level encoding that applies the same
+    encoding vector to all time steps. It differs
+    from the original implementation."""
+    def __init__(self, n_channels):
+        super().__init__()
+        self.n_channels = n_channels
+        self.length = n_channels // 2
+        assert n_channels % 2 == 0
+
+        enc = self.init_encoding(self.length)
+        self.register_buffer('enc', enc)
+
+    def forward(self, x, noise_level):
+        """
+        Shapes:
+            x: B x C x T
+            noise_level: B
+        """
+        return (x + self.encoding(noise_level)[:, :, None])
+
+    def init_encoding(self, length):
+        div_by = torch.arange(length) / length
+        enc = torch.exp(-math.log(1e4) * div_by.unsqueeze(0))
+        return enc
+
+    def encoding(self, noise_level):
+        encoding = noise_level.unsqueeze(1) * self.enc
+        encoding = torch.cat(
+            [torch.sin(encoding), torch.cos(encoding)], dim=-1)
+        return encoding
+
+
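+# NOTE: the FiLM blocks below consume the DBlock features of the noisy waveform
+# together with the continuous noise level. NoiseLevelEncoding above embeds that
+# noise level with sin/cos features (similar in spirit to a Transformer positional
+# encoding, but driven by the noise level instead of the time index) and adds it to
+# every time step before the (shift, scale) pair is predicted.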
+class FiLM(nn.Module):
+    """Feature-wise Linear Modulation. It combines information from
+    both the noisy waveform and the input mel-spectrogram. The FiLM module
+    produces scale and bias vectors given its inputs, which are
+    used in a UBlock for feature-wise affine transformation."""
+
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.encoding = NoiseLevelEncoding(in_channels)
+        self.conv_in = nn.Conv1d(in_channels, in_channels, 3, padding=1)
+        self.conv_out = nn.Conv1d(in_channels, out_channels * 2, 3, padding=1)
+        self._init_parameters()
+
+    def _init_parameters(self):
+        nn.init.orthogonal_(self.conv_in.weight)
+        nn.init.orthogonal_(self.conv_out.weight)
+
+    def forward(self, x, noise_scale):
+        x = self.conv_in(x)
+        x = F.leaky_relu(x, 0.2)
+        x = self.encoding(x, noise_scale)
+        shift, scale = torch.chunk(self.conv_out(x), 2, dim=1)
+        return shift, scale
+
+
+@torch.jit.script
+def shif_and_scale(x, scale, shift):
+    o = shift + scale * x
+    return o
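+# NOTE: `shif_and_scale` is the FiLM application itself: each UBlock receives a
+# (shift, scale) pair predicted from the noisy waveform and applies the feature-wise
+# affine transform `shift + scale * x` to its intermediate activations.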
+
+
+class UBlock(nn.Module):
+    def __init__(self, in_channels, hid_channels, upsample_factor, dilations):
+        super().__init__()
+        assert len(dilations) == 4
+
+        self.upsample_factor = upsample_factor
+        self.shortcut_conv = nn.Conv1d(in_channels, hid_channels, 1)
+        self.main_block1 = nn.ModuleList([
+            nn.Conv1d(in_channels,
+                      hid_channels,
+                      3,
+                      dilation=dilations[0],
+                      padding=dilations[0]),
+            nn.Conv1d(hid_channels,
+                      hid_channels,
+                      3,
+                      dilation=dilations[1],
+                      padding=dilations[1])
+        ])
+        self.main_block2 = nn.ModuleList([
+            nn.Conv1d(hid_channels,
+                      hid_channels,
+                      3,
+                      dilation=dilations[2],
+                      padding=dilations[2]),
+            nn.Conv1d(hid_channels,
+                      hid_channels,
+                      3,
+                      dilation=dilations[3],
+                      padding=dilations[3])
+        ])
+
+    def forward(self, x, shift, scale):
+        upsample_size = x.shape[-1] * self.upsample_factor
+        x = F.interpolate(x, size=upsample_size)
+        res = self.shortcut_conv(x)
+
+        o = F.leaky_relu(x, 0.2)
+        o = self.main_block1[0](o)
+        o = shif_and_scale(o, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.main_block1[1](o)
+
+        o = o + res
+        res = o
+
+        o = shif_and_scale(o, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.main_block2[0](o)
+        o = shif_and_scale(o, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.main_block2[1](o)
+
+        o = o + res
+        return o
+
+
+class DBlock(nn.Module):
+    def __init__(self, in_channels, hid_channels, downsample_factor):
+        super().__init__()
+        self.downsample_factor = downsample_factor
+        self.res_conv = nn.Conv1d(in_channels, hid_channels, 1)
+        self.main_convs = nn.ModuleList([
+            nn.Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1),
+            nn.Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2),
+            nn.Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4),
+        ])
+
+    def forward(self, x):
+        size = x.shape[-1] // self.downsample_factor
+
+        res = self.res_conv(x)
+        res = F.interpolate(res, size=size)
+
+        o = F.interpolate(x, size=size)
+        for layer in self.main_convs:
+            o = F.leaky_relu(o, 0.2)
+            o = layer(o)
+
+        return o + res
diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
new file mode 100644
index 00000000..6405bea8
--- /dev/null
+++ b/TTS/vocoder/models/wavegrad.py
@@ -0,0 +1,131 @@
+import numpy as np
+import torch
+from torch import nn
+
+from ..layers.wavegrad import DBlock, FiLM, UBlock
+
+
+class Wavegrad(nn.Module):
+    # pylint: disable=dangerous-default-value
+    def __init__(self,
+                 in_channels=80,
+                 out_channels=1,
+                 x_conv_channels=32,
+                 c_conv_channels=768,
+                 dblock_out_channels=[128, 128, 256, 512],
+                 ublock_out_channels=[512, 512, 256, 128, 128],
+                 upsample_factors=[5, 5, 3, 2, 2],
+                 upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8],
+                                     [1, 2, 4, 8], [1, 2, 4, 8]]):
+        super().__init__()
+
+        assert len(upsample_factors) == len(upsample_dilations)
+        assert len(upsample_factors) == len(ublock_out_channels)
+
+        # inference time noise schedule params
+        self.S = 1000
+        beta, alpha, alpha_cum, noise_level = self._setup_noise_level()
+        self.register_buffer('beta', beta)
+        self.register_buffer('alpha', alpha)
+        self.register_buffer('alpha_cum', alpha_cum)
+        self.register_buffer('noise_level', noise_level)
+
+        # setup up-down sampling parameters
+        self.hop_length = np.prod(upsample_factors)
+        self.upsample_factors = upsample_factors
+        self.downsample_factors = upsample_factors[::-1][:-1]
+
+        ### define DBlocks, FiLM layers ###
+        self.dblocks = nn.ModuleList([
+            nn.Conv1d(out_channels, x_conv_channels, 5, padding=2),
+        ])
+        ic = x_conv_channels
+        self.films = nn.ModuleList([])
+        for oc, df in zip(dblock_out_channels, self.downsample_factors):
+            # print('dblock(', ic, ', ', oc, ', ', df, ")")
+            layer = DBlock(ic, oc, df)
+            self.dblocks.append(layer)
+
+            # print('film(', ic, ', ', oc, ")")
+            layer = FiLM(ic, oc)
+            self.films.append(layer)
+            ic = oc
+        # last FiLM block
+        # print('film(', ic, ', ', dblock_out_channels[-1], ")")
+        self.films.append(FiLM(ic, dblock_out_channels[-1]))
+
+        ### define UBlocks ###
+        self.c_conv = nn.Conv1d(in_channels, c_conv_channels, 3, padding=1)
+        self.ublocks = nn.ModuleList([])
+        ic = c_conv_channels
+        for idx, (oc, uf) in enumerate(zip(ublock_out_channels, self.upsample_factors)):
+            # print('ublock(', ic, ', ', oc, ', ', uf, ")")
+            layer = UBlock(ic, oc, uf, upsample_dilations[idx])
+            self.ublocks.append(layer)
+            ic = oc
+
+        # define last layer
+        # print(ic, 'last_conv--', out_channels)
+        self.last_conv = nn.Conv1d(ic, out_channels, 3, padding=1)
+
+    def _setup_noise_level(self, noise_schedule=None):
+        """compute noise schedule parameters"""
+        if noise_schedule is None:
+            beta = np.linspace(1e-6, 0.01, self.S)
+        else:
+            beta = noise_schedule
+        alpha = 1 - beta
+        alpha_cum = np.cumprod(alpha)
+        noise_level = np.concatenate([[1.0], alpha_cum ** 0.5], axis=0)
+
+        beta = torch.from_numpy(beta)
+        alpha = torch.from_numpy(alpha)
+        alpha_cum = torch.from_numpy(alpha_cum)
+        noise_level = torch.from_numpy(noise_level.astype(np.float32))
+        return beta, alpha, alpha_cum, noise_level
+
+    def compute_noisy_x(self, x):
+        B = x.shape[0]
+        if len(x.shape) == 3:
+            x = x.squeeze(1)
+        s = torch.randint(1, self.S + 1, [B]).to(x).long()
+        l_a, l_b = self.noise_level[s - 1], self.noise_level[s]
+        noise_scale = l_a + torch.rand(B).to(x) * (l_b - l_a)
+        noise_scale = noise_scale.unsqueeze(1)
+        noise = torch.randn_like(x)
+        noisy_x = noise_scale * x + (1.0 - noise_scale**2)**0.5 * noise
+        return noisy_x.unsqueeze(1), noise_scale[:, 0]
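+    # NOTE: `compute_noisy_x` implements the forward diffusion step used for training:
+    # `noise_level` holds sqrt(cumprod(1 - beta)) (prepended with 1.0), a random step s
+    # is drawn per sample, a continuous scale a is drawn uniformly between
+    # noise_level[s - 1] and noise_level[s], and the target waveform is corrupted as
+    #     noisy_x = a * x + sqrt(1 - a^2) * noise,   noise ~ N(0, I),
+    # with a also returned so the network can be conditioned on it.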
+
+    def forward(self, x, c, noise_scale):
+        assert len(c.shape) == 3  # B, C, T
+        assert len(x.shape) == 3  # B, 1, T
+        o = x
+        shift_and_scales = []
+        for film, dblock in zip(self.films, self.dblocks):
+            o = dblock(o)
+            shift_and_scales.append(film(o, noise_scale))
+
+        o = self.c_conv(c)
+        for ublock, (film_shift, film_scale) in zip(self.ublocks,
+                                                    reversed(shift_and_scales)):
+            o = ublock(o, film_shift, film_scale)
+        o = self.last_conv(o)
+        return o
+
+    def inference(self, c):
+        with torch.no_grad():
+            x = torch.randn(c.shape[0], self.hop_length * c.shape[-1]).to(c)
+            # `alpha_cum` is already a registered tensor buffer, so take the square
+            # root directly instead of going through `torch.from_numpy`.
+            noise_scale = (self.alpha_cum**0.5).float().unsqueeze(1).to(c)
+            for n in range(len(self.alpha) - 1, -1, -1):
+                c1 = 1 / self.alpha[n]**0.5
+                c2 = (1 - self.alpha[n]) / (1 - self.alpha_cum[n])**0.5
+                # the model expects a [B, 1, T] input, so add the channel dim here.
+                x = c1 * (x -
+                          c2 * self.forward(x.unsqueeze(1), c, noise_scale[n]).squeeze(1))
+                if n > 0:
+                    noise = torch.randn_like(x)
+                    sigma = ((1.0 - self.alpha_cum[n - 1]) /
+                             (1.0 - self.alpha_cum[n]) * self.beta[n])**0.5
+                    x += sigma * noise
+            x = torch.clamp(x, -1.0, 1.0)
+            return x
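+# NOTE: `inference` runs the reverse (denoising) process: starting from Gaussian noise
+# of length hop_length * num_mel_frames, each of the S = 1000 schedule steps applies
+#     x <- (x - (1 - alpha_n) / sqrt(1 - alpha_cum_n) * model(x, c, sqrt(alpha_cum_n))) / sqrt(alpha_n)
+# and, except at the last step, adds fresh noise scaled by sigma_n. That means 1000
+# forward passes per utterance with the default linear beta schedule. With the
+# LibriTTS config above (hop_length 256), a conditioning spectrogram of shape
+# [B, 80, T] yields audio of shape [B, 256 * T].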