From 420820caf4f0ad52b232408ec564b8bbb15afcab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 18 Jun 2021 13:37:23 +0200 Subject: [PATCH] Update vocoder models --- TTS/vocoder/models/gan.py | 246 +++++++++++++++++++++++++ TTS/vocoder/models/wavegrad.py | 190 ++++++++++++++++--- TTS/vocoder/models/wavernn.py | 322 ++++++++++++++++++++++++--------- 3 files changed, 639 insertions(+), 119 deletions(-) create mode 100644 TTS/vocoder/models/gan.py diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py new file mode 100644 index 00000000..58d6532e --- /dev/null +++ b/TTS/vocoder/models/gan.py @@ -0,0 +1,246 @@ +from inspect import signature +from typing import Dict, List, Tuple + +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from TTS.utils.audio import AudioProcessor +from TTS.utils.trainer_utils import get_optimizer, get_scheduler +from TTS.vocoder.datasets.gan_dataset import GANDataset +from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss +from TTS.vocoder.models import setup_discriminator, setup_generator +from TTS.vocoder.models.base_vocoder import BaseVocoder +from TTS.vocoder.utils.generic_utils import plot_results + + +class GAN(BaseVocoder): + def __init__(self, config: Coqpit): + """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer. + It also helps mixing and matching different generator and disciminator networks easily. + + Args: + config (Coqpit): Model configuration. + + Examples: + Initializing the GAN model with HifiGAN generator and discriminator. + >>> from TTS.vocoder.configs import HifiganConfig + >>> config = HifiganConfig() + >>> model = GAN(config) + """ + super().__init__() + self.config = config + self.model_g = setup_generator(config) + self.model_d = setup_discriminator(config) + self.train_disc = False # if False, train only the generator. + self.y_hat_g = None # the last generator prediction to be passed onto the discriminator + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model_g.forward(x) + + def inference(self, x: torch.Tensor) -> torch.Tensor: + return self.model_g.inference(x) + + def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]: + outputs = None + loss_dict = None + + x = batch["input"] + y = batch["waveform"] + + if optimizer_idx not in [0, 1]: + raise ValueError(" [!] 
Unexpected `optimizer_idx`.") + + if optimizer_idx == 0: + # GENERATOR + # generator pass + y_hat = self.model_g(x)[:, :, : y.size(2)] + self.y_hat_g = y_hat # save for discriminator + y_hat_sub = None + y_sub = None + + # PQMF formatting + if y_hat.shape[1] > 1: + y_hat_sub = y_hat + y_hat = self.model_g.pqmf_synthesis(y_hat) + self.y_hat_g = y_hat # save for discriminator + y_sub = self.model_g.pqmf_analysis(y) + + scores_fake, feats_fake, feats_real = None, None, None + if self.train_disc: + + if len(signature(self.model_d.forward).parameters) == 2: + D_out_fake = self.model_d(y_hat, x) + else: + D_out_fake = self.model_d(y_hat) + D_out_real = None + + if self.config.use_feat_match_loss: + with torch.no_grad(): + D_out_real = self.model_d(y) + + # format D outputs + if isinstance(D_out_fake, tuple): + scores_fake, feats_fake = D_out_fake + if D_out_real is None: + feats_real = None + else: + _, feats_real = D_out_real + else: + scores_fake = D_out_fake + feats_fake, feats_real = None, None + + # compute losses + loss_dict = criterion[optimizer_idx](y_hat, y, scores_fake, feats_fake, feats_real, y_hat_sub, y_sub) + outputs = {"model_outputs": y_hat} + + if optimizer_idx == 1: + # DISCRIMINATOR + if self.train_disc: + # use different samples for G and D trainings + if self.config.diff_samples_for_G_and_D: + x_d = batch["input_disc"] + y_d = batch["waveform_disc"] + # use a different sample than generator + with torch.no_grad(): + y_hat = self.model_g(x_d) + + # PQMF formatting + if y_hat.shape[1] > 1: + y_hat = self.model_g.pqmf_synthesis(y_hat) + else: + # use the same samples as generator + x_d = x.clone() + y_d = y.clone() + y_hat = self.y_hat_g + + # run D with or without cond. features + if len(signature(self.model_d.forward).parameters) == 2: + D_out_fake = self.model_d(y_hat.detach().clone(), x_d) + D_out_real = self.model_d(y_d, x_d) + else: + D_out_fake = self.model_d(y_hat.detach()) + D_out_real = self.model_d(y_d) + + # format D outputs + if isinstance(D_out_fake, tuple): + # self.model_d returns scores and features + scores_fake, feats_fake = D_out_fake + if D_out_real is None: + scores_real, feats_real = None, None + else: + scores_real, feats_real = D_out_real + else: + # model D returns only scores + scores_fake = D_out_fake + scores_real = D_out_real + + # compute losses + loss_dict = criterion[optimizer_idx](scores_fake, scores_real) + outputs = {"model_outputs": y_hat} + + return outputs, loss_dict + + def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + y_hat = outputs[0]["model_outputs"] + y = batch["waveform"] + figures = plot_results(y_hat, y, ap, "train") + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios = {"train/audio": sample_voice} + return figures, audios + + @torch.no_grad() + def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: + return self.train_step(batch, criterion, optimizer_idx) + + def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + return self.train_log(ap, batch, outputs) + + def load_checkpoint( + self, + config: Coqpit, + checkpoint_path: str, + eval: bool = False, # pylint: disable=unused-argument, redefined-builtin + ) -> None: + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + # band-aid for older than v0.0.15 GAN models + if "model_disc" in state: + self.model_g.load_checkpoint(config, checkpoint_path, eval) + else: + self.load_state_dict(state["model"]) + if eval: + 
self.model_d = None + if hasattr(self.model_g, "remove_weight_norm"): + self.model_g.remove_weight_norm() + + def on_train_step_start(self, trainer) -> None: + self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator + + def get_optimizer(self): + optimizer1 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g + ) + optimizer2 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d + ) + return [optimizer1, optimizer2] + + def get_lr(self): + return [self.config.lr_gen, self.config.lr_disc] + + def get_scheduler(self, optimizer): + scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) + scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) + return [scheduler1, scheduler2] + + @staticmethod + def format_batch(batch): + if isinstance(batch[0], list): + x_G, y_G = batch[0] + x_D, y_D = batch[1] + return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D} + x, y = batch + return {"input": x, "waveform": y} + + def get_data_loader( # pylint: disable=no-self-use + self, + config: Coqpit, + ap: AudioProcessor, + is_eval: True, + data_items: List, + verbose: bool, + num_gpus: int, + ): + dataset = GANDataset( + ap=ap, + items=data_items, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False, + is_training=not is_eval, + return_segments=not is_eval, + use_noise_augment=config.use_noise_augment, + use_cache=config.use_cache, + verbose=verbose, + ) + dataset.shuffle_mapping() + sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=1 if is_eval else config.batch_size, + shuffle=num_gpus == 0, + drop_last=False, + sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_criterion(self): + """Return criterions for the optimizers""" + return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)] diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 84dde957..03d5160e 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -1,65 +1,105 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Tuple + import numpy as np import torch +from coqpit import Coqpit from torch import nn from torch.nn.utils import weight_norm +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler -from ..layers.wavegrad import Conv1d, DBlock, FiLM, UBlock +from TTS.model import BaseModel +from TTS.utils.audio import AudioProcessor +from TTS.utils.trainer_utils import get_optimizer, get_scheduler +from TTS.vocoder.datasets import WaveGradDataset +from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock +from TTS.vocoder.utils.generic_utils import plot_results -class Wavegrad(nn.Module): +@dataclass +class WavegradArgs(Coqpit): + in_channels: int = 80 + out_channels: int = 1 + use_weight_norm: bool = False + y_conv_channels: int = 32 + x_conv_channels: int = 768 + dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512]) + ublock_out_channels: List[int] = field(default_factory=lambda: [512, 
512, 256, 128, 128]) + upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2]) + upsample_dilations: List[List[int]] = field( + default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]] + ) + + +class Wavegrad(BaseModel): + """🐸 🌊 WaveGrad 🌊 model. + Paper - https://arxiv.org/abs/2009.00713 + + Examples: + Initializing the model. + + >>> from TTS.vocoder.configs import WavegradConfig + >>> config = WavegradConfig() + >>> model = Wavegrad(config) + + Paper Abstract: + This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the + data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts + from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned + on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting + the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in + terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations. + Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive + baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations. + Audio samples are available at this https URL. + """ + # pylint: disable=dangerous-default-value - def __init__( - self, - in_channels=80, - out_channels=1, - use_weight_norm=False, - y_conv_channels=32, - x_conv_channels=768, - dblock_out_channels=[128, 128, 256, 512], - ublock_out_channels=[512, 512, 256, 128, 128], - upsample_factors=[5, 5, 3, 2, 2], - upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]], - ): + def __init__(self, config: Coqpit): super().__init__() - - self.use_weight_norm = use_weight_norm - self.hop_len = np.prod(upsample_factors) + self.config = config + self.use_weight_norm = config.model_params.use_weight_norm + self.hop_len = np.prod(config.model_params.upsample_factors) self.noise_level = None self.num_steps = None self.beta = None self.alpha = None self.alpha_hat = None - self.noise_level = None self.c1 = None self.c2 = None self.sigma = None # dblocks - self.y_conv = Conv1d(1, y_conv_channels, 5, padding=2) + self.y_conv = Conv1d(1, config.model_params.y_conv_channels, 5, padding=2) self.dblocks = nn.ModuleList([]) - ic = y_conv_channels - for oc, df in zip(dblock_out_channels, reversed(upsample_factors)): + ic = config.model_params.y_conv_channels + for oc, df in zip(config.model_params.dblock_out_channels, reversed(config.model_params.upsample_factors)): self.dblocks.append(DBlock(ic, oc, df)) ic = oc # film self.film = nn.ModuleList([]) - ic = y_conv_channels - for oc in reversed(ublock_out_channels): + ic = config.model_params.y_conv_channels + for oc in reversed(config.model_params.ublock_out_channels): self.film.append(FiLM(ic, oc)) ic = oc - # ublocks + # ublocksn self.ublocks = nn.ModuleList([]) - ic = x_conv_channels - for oc, uf, ud in zip(ublock_out_channels, upsample_factors, upsample_dilations): + ic = config.model_params.x_conv_channels + for oc, uf, ud in zip( + config.model_params.ublock_out_channels, + config.model_params.upsample_factors, + config.model_params.upsample_dilations, + ): self.ublocks.append(UBlock(ic, oc, uf, ud)) ic = oc - self.x_conv = Conv1d(in_channels, x_conv_channels, 3, padding=1) - 
self.out_conv = Conv1d(oc, out_channels, 3, padding=1) + self.x_conv = Conv1d(config.model_params.in_channels, config.model_params.x_conv_channels, 3, padding=1) + self.out_conv = Conv1d(oc, config.model_params.out_channels, 3, padding=1) - if use_weight_norm: + if config.model_params.use_weight_norm: self.apply_weight_norm() def forward(self, x, spectrogram, noise_scale): @@ -180,7 +220,7 @@ class Wavegrad(nn.Module): if eval: self.eval() assert not self.training - if self.use_weight_norm: + if self.config.model_params.use_weight_norm: self.remove_weight_norm() betas = np.linspace( config["test_noise_schedule"]["min_val"], @@ -195,3 +235,93 @@ class Wavegrad(nn.Module): config["train_noise_schedule"]["num_steps"], ) self.compute_noise_level(betas) + + def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]: + # format data + x = batch["input"] + y = batch["waveform"] + + # set noise scale + noise, x_noisy, noise_scale = self.compute_y_n(y) + + # forward pass + noise_hat = self.forward(x_noisy, x, noise_scale) + + # compute losses + loss = criterion(noise, noise_hat) + return {"model_output": noise_hat}, {"loss": loss} + + def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + return None, None + + @torch.no_grad() + def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: + return self.train_step(batch, criterion) + + def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + return None, None + + def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict): # pylint: disable=unused-argument + # setup noise schedule and inference + noise_schedule = self.config["test_noise_schedule"] + betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) + self.compute_noise_level(betas) + for sample in samples: + x = sample["input"] + y = sample["waveform"] + # compute voice + y_pred = self.inference(x) + # compute spectrograms + figures = plot_results(y_pred, y, ap, "test") + # Sample audio + sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy() + return figures, {"test/audio": sample_voice} + + def get_optimizer(self): + return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self) + + def get_scheduler(self, optimizer): + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer) + + def get_criterion(self): + return torch.nn.L1Loss() + + @staticmethod + def format_batch(batch: Dict) -> Dict: + # return a whole audio segment + m, y = batch[0], batch[1] + y = y.unsqueeze(1) + return {"input": m, "waveform": y} + + def get_data_loader( + self, config: Coqpit, ap: AudioProcessor, is_eval: True, data_items: List, verbose: bool, num_gpus: int + ): + dataset = WaveGradDataset( + ap=ap, + items=data_items, + seq_len=self.config.seq_len, + hop_len=ap.hop_length, + pad_short=self.config.pad_short, + conv_pad=self.config.conv_pad, + is_training=not is_eval, + return_segments=True, + use_noise_augment=False, + use_cache=config.use_cache, + verbose=verbose, + ) + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=self.config.batch_size, + shuffle=num_gpus <= 1, + drop_last=False, + sampler=sampler, + num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers, + pin_memory=False, + ) + return loader + + def on_epoch_start(self, trainer): # pylint: disable=unused-argument + 
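+        # reset betas to the training noise schedule; a preceding test run may have applied the test schedule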
noise_schedule = self.config["train_noise_schedule"] + betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) + self.compute_noise_level(betas) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 04040931..a5d89d5a 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -1,13 +1,21 @@ import sys import time +from dataclasses import dataclass, field +from typing import Dict, List, Tuple import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from coqpit import Coqpit +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler -# fix this -from TTS.utils.audio import AudioProcessor as ap +from TTS.tts.utils.visual import plot_spectrogram +from TTS.utils.audio import AudioProcessor +from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset +from TTS.vocoder.layers.losses import WaveRNNLoss +from TTS.vocoder.models.base_vocoder import BaseVocoder from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian @@ -135,89 +143,145 @@ class Upsample(nn.Module): return m.transpose(1, 2), aux -class WaveRNN(nn.Module): - def __init__( - self, - rnn_dims, - fc_dims, - mode, - mulaw, - pad, - use_aux_net, - use_upsample_net, - upsample_factors, - feat_dims, - compute_dims, - res_out_dims, - num_res_blocks, - hop_length, - sample_rate, - ): +@dataclass +class WavernnArgs(Coqpit): + """🐸 WaveRNN model arguments. + + rnn_dims (int): + Number of hidden channels in RNN layers. Defaults to 512. + fc_dims (int): + Number of hidden channels in fully-conntected layers. Defaults to 512. + compute_dims (int): + Number of hidden channels in the feature ResNet. Defaults to 128. + res_out_dim (int): + Number of hidden channels in the feature ResNet output. Defaults to 128. + num_res_blocks (int): + Number of residual blocks in the ResNet. Defaults to 10. + use_aux_net (bool): + enable/disable the feature ResNet. Defaults to True. + use_upsample_net (bool): + enable/ disable the upsampling networl. If False, basic upsampling is used. Defaults to True. + upsample_factors (list): + Upsampling factors. The multiply of the values must match the `hop_length`. Defaults to ```[4, 8, 8]```. + mode (str): + Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single + Gaussian Distribution and `bits` for quantized bits as the model's output. + mulaw (bool): + enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults + to `True`. + pad (int): + Padding applied to the input feature frames against the convolution layers of the feature network. + Defaults to 2. + """ + + rnn_dims: int = 512 + fc_dims: int = 512 + compute_dims: int = 128 + res_out_dims: int = 128 + num_res_blocks: int = 10 + use_aux_net: bool = True + use_upsample_net: bool = True + upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8]) + mode: str = "mold" # mold [string], gauss [string], bits [int] + mulaw: bool = True # apply mulaw if mode is bits + pad: int = 2 + feat_dims: int = 80 + + +class Wavernn(BaseVocoder): + def __init__(self, config: Coqpit): + """🐸 WaveRNN model. 
+ Original paper - https://arxiv.org/abs/1802.08435 + Official implementation - https://github.com/fatchord/WaveRNN + + Args: + config (Coqpit): [description] + + Raises: + RuntimeError: [description] + + Examples: + >>> from TTS.vocoder.configs import WavernnConfig + >>> config = WavernnConfig() + >>> model = Wavernn(config) + + Paper Abstract: + Sequential models achieve state-of-the-art results in audio, visual and textual domains with respect to + both estimating the data distribution and generating high-quality samples. Efficient sampling for this + class of models has however remained an elusive problem. With a focus on text-to-speech synthesis, we + describe a set of general techniques for reducing sampling time while maintaining high output quality. + We first describe a single-layer recurrent neural network, the WaveRNN, with a dual softmax layer that + matches the quality of the state-of-the-art WaveNet model. The compact form of the network makes it + possible to generate 24kHz 16-bit audio 4x faster than real time on a GPU. Second, we apply a weight + pruning technique to reduce the number of weights in the WaveRNN. We find that, for a constant number of + parameters, large sparse networks perform better than small dense networks and this relationship holds for + sparsity levels beyond 96%. The small number of weights in a Sparse WaveRNN makes it possible to sample + high-fidelity audio on a mobile CPU in real time. Finally, we propose a new generation scheme based on + subscaling that folds a long sequence into a batch of shorter sequences and allows one to generate multiple + samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an + orthogonal method for increasing sampling efficiency. + """ super().__init__() - self.mode = mode - self.mulaw = mulaw - self.pad = pad - self.use_upsample_net = use_upsample_net - self.use_aux_net = use_aux_net - if isinstance(self.mode, int): - self.n_classes = 2 ** self.mode - elif self.mode == "mold": + + self.args = config.model_params + self.config = config + + if isinstance(self.args.mode, int): + self.n_classes = 2 ** self.args.mode + elif self.args.mode == "mold": self.n_classes = 3 * 10 - elif self.mode == "gauss": + elif self.args.mode == "gauss": self.n_classes = 2 else: - raise RuntimeError("Unknown model mode value - ", self.mode) + raise RuntimeError("Unknown model mode value - ", self.args.mode) - self.rnn_dims = rnn_dims - self.aux_dims = res_out_dims // 4 - self.hop_length = hop_length - self.sample_rate = sample_rate + self.aux_dims = self.args.res_out_dims // 4 - if self.use_upsample_net: + if self.args.use_upsample_net: assert ( - np.cumproduct(upsample_factors)[-1] == self.hop_length + np.cumproduct(self.args.upsample_factors)[-1] == config.audio.hop_length ), " [!] 
upsample scales needs to be equal to hop_length" self.upsample = UpsampleNetwork( - feat_dims, - upsample_factors, - compute_dims, - num_res_blocks, - res_out_dims, - pad, - use_aux_net, + self.args.feat_dims, + self.args.upsample_factors, + self.args.compute_dims, + self.args.num_res_blocks, + self.args.res_out_dims, + self.args.pad, + self.args.use_aux_net, ) else: self.upsample = Upsample( - hop_length, - pad, - num_res_blocks, - feat_dims, - compute_dims, - res_out_dims, - use_aux_net, + config.audio.hop_length, + self.args.pad, + self.args.num_res_blocks, + self.args.feat_dims, + self.args.compute_dims, + self.args.res_out_dims, + self.args.use_aux_net, ) - if self.use_aux_net: - self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims) - self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) - self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True) - self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims) - self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims) - self.fc3 = nn.Linear(fc_dims, self.n_classes) + if self.args.use_aux_net: + self.I = nn.Linear(self.args.feat_dims + self.aux_dims + 1, self.args.rnn_dims) + self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(self.args.rnn_dims + self.aux_dims, self.args.rnn_dims, batch_first=True) + self.fc1 = nn.Linear(self.args.rnn_dims + self.aux_dims, self.args.fc_dims) + self.fc2 = nn.Linear(self.args.fc_dims + self.aux_dims, self.args.fc_dims) + self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes) else: - self.I = nn.Linear(feat_dims + 1, rnn_dims) - self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) - self.rnn2 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) - self.fc1 = nn.Linear(rnn_dims, fc_dims) - self.fc2 = nn.Linear(fc_dims, fc_dims) - self.fc3 = nn.Linear(fc_dims, self.n_classes) + self.I = nn.Linear(self.args.feat_dims + 1, self.args.rnn_dims) + self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True) + self.fc1 = nn.Linear(self.args.rnn_dims, self.args.fc_dims) + self.fc2 = nn.Linear(self.args.fc_dims, self.args.fc_dims) + self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes) def forward(self, x, mels): bsize = x.size(0) - h1 = torch.zeros(1, bsize, self.rnn_dims).to(x.device) - h2 = torch.zeros(1, bsize, self.rnn_dims).to(x.device) + h1 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device) + h2 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device) mels, aux = self.upsample(mels) - if self.use_aux_net: + if self.args.use_aux_net: aux_idx = [self.aux_dims * i for i in range(5)] a1 = aux[:, :, aux_idx[0] : aux_idx[1]] a2 = aux[:, :, aux_idx[1] : aux_idx[2]] @@ -226,7 +290,7 @@ class WaveRNN(nn.Module): x = ( torch.cat([x.unsqueeze(-1), mels, a1], dim=2) - if self.use_aux_net + if self.args.use_aux_net else torch.cat([x.unsqueeze(-1), mels], dim=2) ) x = self.I(x) @@ -236,15 +300,15 @@ class WaveRNN(nn.Module): x = x + res res = x - x = torch.cat([x, a2], dim=2) if self.use_aux_net else x + x = torch.cat([x, a2], dim=2) if self.args.use_aux_net else x self.rnn2.flatten_parameters() x, _ = self.rnn2(x, h2) x = x + res - x = torch.cat([x, a3], dim=2) if self.use_aux_net else x + x = torch.cat([x, a3], dim=2) if self.args.use_aux_net else x x = F.relu(self.fc1(x)) - x = torch.cat([x, a4], dim=2) if self.use_aux_net else x + x = torch.cat([x, a4], dim=2) if self.args.use_aux_net else x x = F.relu(self.fc2(x)) return self.fc3(x) @@ -262,9 +326,9 @@ 
class WaveRNN(nn.Module): if mels.ndim == 2: mels = mels.unsqueeze(0) - wave_len = (mels.size(-1) - 1) * self.hop_length + wave_len = (mels.size(-1) - 1) * self.config.audio.hop_length - mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both") + mels = self.pad_tensor(mels.transpose(1, 2), pad=self.args.pad, side="both") mels, aux = self.upsample(mels.transpose(1, 2)) if batched: @@ -274,11 +338,11 @@ class WaveRNN(nn.Module): b_size, seq_len, _ = mels.size() - h1 = torch.zeros(b_size, self.rnn_dims).type_as(mels) - h2 = torch.zeros(b_size, self.rnn_dims).type_as(mels) + h1 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels) + h2 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels) x = torch.zeros(b_size, 1).type_as(mels) - if self.use_aux_net: + if self.args.use_aux_net: d = self.aux_dims aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)] @@ -286,35 +350,35 @@ class WaveRNN(nn.Module): m_t = mels[:, i, :] - if self.use_aux_net: + if self.args.use_aux_net: a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split) - x = torch.cat([x, m_t, a1_t], dim=1) if self.use_aux_net else torch.cat([x, m_t], dim=1) + x = torch.cat([x, m_t, a1_t], dim=1) if self.args.use_aux_net else torch.cat([x, m_t], dim=1) x = self.I(x) h1 = rnn1(x, h1) x = x + h1 - inp = torch.cat([x, a2_t], dim=1) if self.use_aux_net else x + inp = torch.cat([x, a2_t], dim=1) if self.args.use_aux_net else x h2 = rnn2(inp, h2) x = x + h2 - x = torch.cat([x, a3_t], dim=1) if self.use_aux_net else x + x = torch.cat([x, a3_t], dim=1) if self.args.use_aux_net else x x = F.relu(self.fc1(x)) - x = torch.cat([x, a4_t], dim=1) if self.use_aux_net else x + x = torch.cat([x, a4_t], dim=1) if self.args.use_aux_net else x x = F.relu(self.fc2(x)) logits = self.fc3(x) - if self.mode == "mold": + if self.args.mode == "mold": sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2)) output.append(sample.view(-1)) x = sample.transpose(0, 1).type_as(mels) - elif self.mode == "gauss": + elif self.args.mode == "gauss": sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2)) output.append(sample.view(-1)) x = sample.transpose(0, 1).type_as(mels) - elif isinstance(self.mode, int): + elif isinstance(self.args.mode, int): posterior = F.softmax(logits, dim=1) distrib = torch.distributions.Categorical(posterior) @@ -322,7 +386,7 @@ class WaveRNN(nn.Module): output.append(sample) x = sample.unsqueeze(-1) else: - raise RuntimeError("Unknown model mode value - ", self.mode) + raise RuntimeError("Unknown model mode value - ", self.args.mode) if i % 100 == 0: self.gen_display(i, seq_len, b_size, start) @@ -337,22 +401,22 @@ class WaveRNN(nn.Module): else: output = output[0] - if self.mulaw and isinstance(self.mode, int): - output = ap.mulaw_decode(output, self.mode) + if self.args.mulaw and isinstance(self.args.mode, int): + output = AudioProcessor.mulaw_decode(output, self.args.mode) # Fade-out at the end to avoid signal cutting out suddenly - fade_out = np.linspace(1, 0, 20 * self.hop_length) + fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length) output = output[:wave_len] if wave_len > len(fade_out): - output[-20 * self.hop_length :] *= fade_out + output[-20 * self.config.audio.hop_length :] *= fade_out self.train() return output def gen_display(self, i, seq_len, b_size, start): gen_rate = (i + 1) / (time.time() - start) * b_size / 1000 - realtime_ratio = gen_rate * 1000 / self.sample_rate + realtime_ratio = gen_rate * 1000 / self.config.audio.sample_rate stream( "%i/%i -- 
batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ",
             (i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
         )

@@ -486,3 +550,83 @@ class WaveRNN(nn.Module):
         if eval:
             self.eval()
             assert not self.training
+
+    def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
+        mels = batch["input"]
+        waveform = batch["waveform"]
+        waveform_coarse = batch["waveform_coarse"]
+
+        y_hat = self.forward(waveform, mels)
+        if isinstance(self.args.mode, int):
+            y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+        else:
+            waveform_coarse = waveform_coarse.float()
+        waveform_coarse = waveform_coarse.unsqueeze(-1)
+        # compute losses
+        loss_dict = criterion(y_hat, waveform_coarse)
+        return {"model_output": y_hat}, loss_dict
+
+    def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
+        return self.train_step(batch, criterion)
+
+    @torch.no_grad()
+    def test_run(
+        self, ap: AudioProcessor, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
+    ) -> Tuple[Dict, Dict]:
+        figures = {}
+        audios = {}
+        for idx, sample in enumerate(samples):
+            x = sample["input"]
+            y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
+            x_hat = ap.melspectrogram(y_hat)
+            figures.update(
+                {
+                    f"test_{idx}/ground_truth": plot_spectrogram(x.T),
+                    f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
+                }
+            )
+            audios.update({f"test_{idx}/audio": y_hat})
+        return figures, audios
+
+    @staticmethod
+    def format_batch(batch: Dict) -> Dict:
+        waveform = batch[0]
+        mels = batch[1]
+        waveform_coarse = batch[2]
+        return {"input": mels, "waveform": waveform, "waveform_coarse": waveform_coarse}
+
+    def get_data_loader(  # pylint: disable=no-self-use
+        self,
+        config: Coqpit,
+        ap: AudioProcessor,
+        is_eval: bool,
+        data_items: List,
+        verbose: bool,
+        num_gpus: int,
+    ):
+        dataset = WaveRNNDataset(
+            ap=ap,
+            items=data_items,
+            seq_len=config.seq_len,
+            hop_len=ap.hop_length,
+            pad=config.model_params.pad,
+            mode=config.model_params.mode,
+            mulaw=config.model_params.mulaw,
+            is_training=not is_eval,
+            verbose=verbose,
+        )
+        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
+        loader = DataLoader(
+            dataset,
+            batch_size=1 if is_eval else config.batch_size,
+            shuffle=num_gpus == 0,
+            collate_fn=dataset.collate,
+            sampler=sampler,
+            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
+            pin_memory=True,
+        )
+        return loader
+
+    def get_criterion(self):
+        # define train functions
+        return WaveRNNLoss(self.args.mode)
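
Usage note: the new vocoder classes are designed to be driven by the 🐸TTS trainer through the interface added above (`format_batch`, `train_step`, `get_optimizer`, `get_criterion`). Below is a minimal sketch of one generator/discriminator update with the `GAN` wrapper outside the trainer; the `HifiganConfig` defaults (80 mel bins, hop length 256), the synthetic tensor shapes, and the `loss` key in the returned loss dicts are assumptions for illustration, not something this patch guarantees.

import torch

from TTS.vocoder.configs import HifiganConfig
from TTS.vocoder.models.gan import GAN

config = HifiganConfig()
model = GAN(config)

criteria = model.get_criterion()    # [GeneratorLoss, DiscriminatorLoss]
optimizers = model.get_optimizer()  # [generator optimizer, discriminator optimizer]

# Synthetic (mel, audio) pair shaped like a GANDataset batch; 80 mel bins and a
# hop length of 256 are illustrative values, not read from the config.
mel = torch.randn(2, 80, 32)
audio = torch.randn(2, 1, 32 * 256)
batch = model.format_batch((mel, audio))

model.train_disc = True  # normally flipped by on_train_step_start() after `steps_to_start_discriminator`
for idx in (0, 1):  # 0: generator step, 1: discriminator step
    optimizers[idx].zero_grad()
    _, loss_dict = model.train_step(batch, criteria, idx)
    loss_dict["loss"].backward()
    optimizers[idx].step()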