Update vocoder models

This commit is contained in:
Eren Gölge 2021-06-18 13:37:23 +02:00
parent 51005cdab4
commit e949e7ad58
3 changed files with 639 additions and 119 deletions

TTS/vocoder/models/gan.py Normal file

@@ -0,0 +1,246 @@
from inspect import signature
from typing import Dict, List, Tuple
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.models import setup_discriminator, setup_generator
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results
class GAN(BaseVocoder):
def __init__(self, config: Coqpit):
"""Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
It also makes it easy to mix and match different generator and discriminator networks.
Args:
config (Coqpit): Model configuration.
Examples:
Initializing the GAN model with HifiGAN generator and discriminator.
>>> from TTS.vocoder.configs import HifiganConfig
>>> config = HifiganConfig()
>>> model = GAN(config)
"""
super().__init__()
self.config = config
self.model_g = setup_generator(config)
self.model_d = setup_discriminator(config)
self.train_disc = False # if False, train only the generator.
self.y_hat_g = None # the last generator prediction to be passed onto the discriminator
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model_g.forward(x)
def inference(self, x: torch.Tensor) -> torch.Tensor:
return self.model_g.inference(x)
def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
outputs = None
loss_dict = None
x = batch["input"]
y = batch["waveform"]
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")
if optimizer_idx == 0:
# GENERATOR
# generator pass
y_hat = self.model_g(x)[:, :, : y.size(2)]
self.y_hat_g = y_hat # save for discriminator
y_hat_sub = None
y_sub = None
# PQMF formatting
if y_hat.shape[1] > 1:
y_hat_sub = y_hat
y_hat = self.model_g.pqmf_synthesis(y_hat)
self.y_hat_g = y_hat # save for discriminator
y_sub = self.model_g.pqmf_analysis(y)
scores_fake, feats_fake, feats_real = None, None, None
if self.train_disc:
if len(signature(self.model_d.forward).parameters) == 2:
D_out_fake = self.model_d(y_hat, x)
else:
D_out_fake = self.model_d(y_hat)
D_out_real = None
if self.config.use_feat_match_loss:
with torch.no_grad():
D_out_real = self.model_d(y)
# format D outputs
if isinstance(D_out_fake, tuple):
scores_fake, feats_fake = D_out_fake
if D_out_real is None:
feats_real = None
else:
_, feats_real = D_out_real
else:
scores_fake = D_out_fake
feats_fake, feats_real = None, None
# compute losses
loss_dict = criterion[optimizer_idx](y_hat, y, scores_fake, feats_fake, feats_real, y_hat_sub, y_sub)
outputs = {"model_outputs": y_hat}
if optimizer_idx == 1:
# DISCRIMINATOR
if self.train_disc:
# use different samples for G and D training
if self.config.diff_samples_for_G_and_D:
x_d = batch["input_disc"]
y_d = batch["waveform_disc"]
# use a different sample than generator
with torch.no_grad():
y_hat = self.model_g(x_d)
# PQMF formatting
if y_hat.shape[1] > 1:
y_hat = self.model_g.pqmf_synthesis(y_hat)
else:
# use the same samples as generator
x_d = x.clone()
y_d = y.clone()
y_hat = self.y_hat_g
# run D with or without cond. features
if len(signature(self.model_d.forward).parameters) == 2:
D_out_fake = self.model_d(y_hat.detach().clone(), x_d)
D_out_real = self.model_d(y_d, x_d)
else:
D_out_fake = self.model_d(y_hat.detach())
D_out_real = self.model_d(y_d)
# format D outputs
if isinstance(D_out_fake, tuple):
# self.model_d returns scores and features
scores_fake, feats_fake = D_out_fake
if D_out_real is None:
scores_real, feats_real = None, None
else:
scores_real, feats_real = D_out_real
else:
# model D returns only scores
scores_fake = D_out_fake
scores_real = D_out_real
# compute losses
loss_dict = criterion[optimizer_idx](scores_fake, scores_real)
outputs = {"model_outputs": y_hat}
return outputs, loss_dict
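# A minimal sketch of the expected calling convention, assuming a trainer that
# runs one `train_step` per optimizer and backpropagates each loss separately
# (`batch` is a formatted batch; assumes each criterion reports its total
# under the "loss" key):
#
#   criterions = model.get_criterion()   # [GeneratorLoss, DiscriminatorLoss]
#   optimizers = model.get_optimizer()   # [generator opt., discriminator opt.]
#   for idx, opt in enumerate(optimizers):
#       _, loss_dict = model.train_step(batch, criterions, idx)
#       if loss_dict is not None:        # None while the discriminator is idle
#           opt.zero_grad()
#           loss_dict["loss"].backward()
#           opt.step()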
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
y_hat = outputs[0]["model_outputs"]
y = batch["waveform"]
figures = plot_results(y_hat, y, ap, "train")
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
audios = {"train/audio": sample_voice}
return figures, audios
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
return self.train_log(ap, batch, outputs)
def load_checkpoint(
self,
config: Coqpit,
checkpoint_path: str,
eval: bool = False, # pylint: disable=unused-argument, redefined-builtin
) -> None:
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
# band-aid for older than v0.0.15 GAN models
if "model_disc" in state:
self.model_g.load_checkpoint(config, checkpoint_path, eval)
else:
self.load_state_dict(state["model"])
if eval:
self.model_d = None
if hasattr(self.model_g, "remove_weight_norm"):
self.model_g.remove_weight_norm()
def on_train_step_start(self, trainer) -> None:
self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator
def get_optimizer(self):
optimizer1 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g
)
optimizer2 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d
)
return [optimizer1, optimizer2]
def get_lr(self):
return [self.config.lr_gen, self.config.lr_disc]
def get_scheduler(self, optimizer):
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler1, scheduler2]
@staticmethod
def format_batch(batch):
if isinstance(batch[0], list):
x_G, y_G = batch[0]
x_D, y_D = batch[1]
return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D}
x, y = batch
return {"input": x, "waveform": y}
def get_data_loader( # pylint: disable=no-self-use
self,
config: Coqpit,
ap: AudioProcessor,
is_eval: bool,
data_items: List,
verbose: bool,
num_gpus: int,
):
dataset = GANDataset(
ap=ap,
items=data_items,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad_short=config.pad_short,
conv_pad=config.conv_pad,
return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
is_training=not is_eval,
return_segments=not is_eval,
use_noise_augment=config.use_noise_augment,
use_cache=config.use_cache,
verbose=verbose,
)
dataset.shuffle_mapping()
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=1 if is_eval else config.batch_size,
shuffle=num_gpus == 0,
drop_last=False,
sampler=sampler,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
return loader
def get_criterion(self):
"""Return criterions for the optimizers"""
return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)]
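# Usage sketch for the class above, mirroring the docstring example; assumes
# `HifiganConfig` defaults:
#
#   from TTS.vocoder.configs import HifiganConfig
#   config = HifiganConfig()
#   model = GAN(config)
#   optimizers = model.get_optimizer()           # [generator, discriminator]
#   schedulers = model.get_scheduler(optimizers)
#   criterions = model.get_criterion()           # [GeneratorLoss, DiscriminatorLoss]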

TTS/vocoder/models/wavegrad.py

@@ -1,65 +1,105 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn.utils import weight_norm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from ..layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
from TTS.model import BaseModel
from TTS.utils.audio import AudioProcessor
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
from TTS.vocoder.utils.generic_utils import plot_results
class Wavegrad(nn.Module):
@dataclass
class WavegradArgs(Coqpit):
in_channels: int = 80
out_channels: int = 1
use_weight_norm: bool = False
y_conv_channels: int = 32
x_conv_channels: int = 768
dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512])
ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128])
upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2])
upsample_dilations: List[List[int]] = field(
default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]]
)
class Wavegrad(BaseModel):
"""🐸 🌊 WaveGrad 🌊 model.
Paper - https://arxiv.org/abs/2009.00713
Examples:
Initializing the model.
>>> from TTS.vocoder.configs import WavegradConfig
>>> config = WavegradConfig()
>>> model = Wavegrad(config)
Paper Abstract:
This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the
data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts
from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned
on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting
the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in
terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations.
Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive
baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations.
Audio samples are available at this https URL.
"""
# pylint: disable=dangerous-default-value
def __init__(
self,
in_channels=80,
out_channels=1,
use_weight_norm=False,
y_conv_channels=32,
x_conv_channels=768,
dblock_out_channels=[128, 128, 256, 512],
ublock_out_channels=[512, 512, 256, 128, 128],
upsample_factors=[5, 5, 3, 2, 2],
upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]],
):
def __init__(self, config: Coqpit):
super().__init__()
self.use_weight_norm = use_weight_norm
self.hop_len = np.prod(upsample_factors)
self.config = config
self.use_weight_norm = config.model_params.use_weight_norm
self.hop_len = np.prod(config.model_params.upsample_factors)
self.noise_level = None
self.num_steps = None
self.beta = None
self.alpha = None
self.alpha_hat = None
self.noise_level = None
self.c1 = None
self.c2 = None
self.sigma = None
# dblocks
self.y_conv = Conv1d(1, y_conv_channels, 5, padding=2)
self.y_conv = Conv1d(1, config.model_params.y_conv_channels, 5, padding=2)
self.dblocks = nn.ModuleList([])
ic = y_conv_channels
for oc, df in zip(dblock_out_channels, reversed(upsample_factors)):
ic = config.model_params.y_conv_channels
for oc, df in zip(config.model_params.dblock_out_channels, reversed(config.model_params.upsample_factors)):
self.dblocks.append(DBlock(ic, oc, df))
ic = oc
# film
self.film = nn.ModuleList([])
ic = y_conv_channels
for oc in reversed(ublock_out_channels):
ic = config.model_params.y_conv_channels
for oc in reversed(config.model_params.ublock_out_channels):
self.film.append(FiLM(ic, oc))
ic = oc
# ublocks
self.ublocks = nn.ModuleList([])
ic = x_conv_channels
for oc, uf, ud in zip(ublock_out_channels, upsample_factors, upsample_dilations):
ic = config.model_params.x_conv_channels
for oc, uf, ud in zip(
config.model_params.ublock_out_channels,
config.model_params.upsample_factors,
config.model_params.upsample_dilations,
):
self.ublocks.append(UBlock(ic, oc, uf, ud))
ic = oc
self.x_conv = Conv1d(in_channels, x_conv_channels, 3, padding=1)
self.out_conv = Conv1d(oc, out_channels, 3, padding=1)
self.x_conv = Conv1d(config.model_params.in_channels, config.model_params.x_conv_channels, 3, padding=1)
self.out_conv = Conv1d(oc, config.model_params.out_channels, 3, padding=1)
if use_weight_norm:
if config.model_params.use_weight_norm:
self.apply_weight_norm()
def forward(self, x, spectrogram, noise_scale):
@@ -180,7 +220,7 @@ class Wavegrad(nn.Module):
if eval:
self.eval()
assert not self.training
if self.use_weight_norm:
if self.config.model_params.use_weight_norm:
self.remove_weight_norm()
betas = np.linspace(
config["test_noise_schedule"]["min_val"],
@@ -195,3 +235,93 @@
config["train_noise_schedule"]["num_steps"],
)
self.compute_noise_level(betas)
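# For reference, the quantities `compute_noise_level` is expected to derive
# from `betas` (standard diffusion schedule math; the exact buffer names are
# assumptions matching the attributes set in `__init__`):
#
#   alpha = 1.0 - betas                  # per-step signal retention
#   alpha_hat = np.cumprod(alpha)        # cumulative signal retention
#   noise_level = np.sqrt(alpha_hat)     # clean-signal scale at each step
#
# so a training sample at step t is corrupted as
#   x_noisy = noise_level[t] * y + np.sqrt(1 - alpha_hat[t]) * noise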
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
# format data
x = batch["input"]
y = batch["waveform"]
# set noise scale
noise, x_noisy, noise_scale = self.compute_y_n(y)
# forward pass
noise_hat = self.forward(x_noisy, x, noise_scale)
# compute losses
loss = criterion(noise, noise_hat)
return {"model_output": noise_hat}, {"loss": loss}
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
return None, None
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion)
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
return None, None
def test_run(self, ap: AudioProcessor, samples: List[Dict], outputs: Dict): # pylint: disable=unused-argument
# setup noise schedule and inference
noise_schedule = self.config["test_noise_schedule"]
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
self.compute_noise_level(betas)
for sample in samples:
x = sample["input"]
y = sample["waveform"]
# compute voice
y_pred = self.inference(x)
# compute spectrograms
figures = plot_results(y_pred, y, ap, "test")
# Sample audio
sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy()
return figures, {"test/audio": sample_voice}
def get_optimizer(self):
return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)
def get_scheduler(self, optimizer):
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
def get_criterion(self):
return torch.nn.L1Loss()
@staticmethod
def format_batch(batch: Dict) -> Dict:
# return a whole audio segment
m, y = batch[0], batch[1]
y = y.unsqueeze(1)
return {"input": m, "waveform": y}
def get_data_loader(
self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int
):
dataset = WaveGradDataset(
ap=ap,
items=data_items,
seq_len=self.config.seq_len,
hop_len=ap.hop_length,
pad_short=self.config.pad_short,
conv_pad=self.config.conv_pad,
is_training=not is_eval,
return_segments=True,
use_noise_augment=False,
use_cache=config.use_cache,
verbose=verbose,
)
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=self.config.batch_size,
shuffle=num_gpus <= 1,
drop_last=False,
sampler=sampler,
num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers,
pin_memory=False,
)
return loader
def on_epoch_start(self, trainer): # pylint: disable=unused-argument
noise_schedule = self.config["train_noise_schedule"]
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
self.compute_noise_level(betas)
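# Usage sketch for the class above; assumes `WavegradConfig` defaults:
#
#   from TTS.vocoder.configs import WavegradConfig
#   config = WavegradConfig()
#   model = Wavegrad(config)
#   schedule = config["test_noise_schedule"]
#   betas = np.linspace(schedule["min_val"], schedule["max_val"], schedule["num_steps"])
#   model.compute_noise_level(betas)
#   wave = model.inference(mel)          # mel: conditioning spectrogram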

TTS/vocoder/models/wavernn.py

@@ -1,13 +1,21 @@
import sys
import time
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from coqpit import Coqpit
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
# fix this
from TTS.utils.audio import AudioProcessor as ap
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian
@@ -135,89 +143,145 @@ class Upsample(nn.Module):
return m.transpose(1, 2), aux
class WaveRNN(nn.Module):
def __init__(
self,
rnn_dims,
fc_dims,
mode,
mulaw,
pad,
use_aux_net,
use_upsample_net,
upsample_factors,
feat_dims,
compute_dims,
res_out_dims,
num_res_blocks,
hop_length,
sample_rate,
):
@dataclass
class WavernnArgs(Coqpit):
"""🐸 WaveRNN model arguments.
rnn_dims (int):
Number of hidden channels in RNN layers. Defaults to 512.
fc_dims (int):
Number of hidden channels in fully-connected layers. Defaults to 512.
compute_dims (int):
Number of hidden channels in the feature ResNet. Defaults to 128.
res_out_dims (int):
Number of hidden channels in the feature ResNet output. Defaults to 128.
num_res_blocks (int):
Number of residual blocks in the ResNet. Defaults to 10.
use_aux_net (bool):
enable/disable the feature ResNet. Defaults to True.
use_upsample_net (bool):
enable/disable the upsampling network. If False, basic upsampling is used. Defaults to True.
upsample_factors (list):
Upsampling factors. The product of the values must equal `hop_length` (e.g. 4 * 8 * 8 = 256). Defaults to `[4, 8, 8]`.
mode (str):
Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single
Gaussian Distribution and `bits` for quantized bits as the model's output.
mulaw (bool):
enable/disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults
to `True`.
pad (int):
Padding applied to the input feature frames against the convolution layers of the feature network.
Defaults to 2.
"""
rnn_dims: int = 512
fc_dims: int = 512
compute_dims: int = 128
res_out_dims: int = 128
num_res_blocks: int = 10
use_aux_net: bool = True
use_upsample_net: bool = True
upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8])
mode: str = "mold" # mold [string], gauss [string], bits [int]
mulaw: bool = True # apply mulaw if mode is bits
pad: int = 2
feat_dims: int = 80
class Wavernn(BaseVocoder):
def __init__(self, config: Coqpit):
"""🐸 WaveRNN model.
Original paper - https://arxiv.org/abs/1802.08435
Official implementation - https://github.com/fatchord/WaveRNN
Args:
config (Coqpit): Model configuration.
Raises:
RuntimeError: If `config.model_params.mode` is not a recognized output mode.
Examples:
>>> from TTS.vocoder.configs import WavernnConfig
>>> config = WavernnConfig()
>>> model = Wavernn(config)
Paper Abstract:
Sequential models achieve state-of-the-art results in audio, visual and textual domains with respect to
both estimating the data distribution and generating high-quality samples. Efficient sampling for this
class of models has however remained an elusive problem. With a focus on text-to-speech synthesis, we
describe a set of general techniques for reducing sampling time while maintaining high output quality.
We first describe a single-layer recurrent neural network, the WaveRNN, with a dual softmax layer that
matches the quality of the state-of-the-art WaveNet model. The compact form of the network makes it
possible to generate 24kHz 16-bit audio 4x faster than real time on a GPU. Second, we apply a weight
pruning technique to reduce the number of weights in the WaveRNN. We find that, for a constant number of
parameters, large sparse networks perform better than small dense networks and this relationship holds for
sparsity levels beyond 96%. The small number of weights in a Sparse WaveRNN makes it possible to sample
high-fidelity audio on a mobile CPU in real time. Finally, we propose a new generation scheme based on
subscaling that folds a long sequence into a batch of shorter sequences and allows one to generate multiple
samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an
orthogonal method for increasing sampling efficiency.
"""
super().__init__()
self.mode = mode
self.mulaw = mulaw
self.pad = pad
self.use_upsample_net = use_upsample_net
self.use_aux_net = use_aux_net
if isinstance(self.mode, int):
self.n_classes = 2 ** self.mode
elif self.mode == "mold":
self.args = config.model_params
self.config = config
if isinstance(self.args.mode, int):
self.n_classes = 2 ** self.args.mode
elif self.args.mode == "mold":
self.n_classes = 3 * 10
elif self.mode == "gauss":
elif self.args.mode == "gauss":
self.n_classes = 2
else:
raise RuntimeError("Unknown model mode value - ", self.mode)
raise RuntimeError("Unknown model mode value - ", self.args.mode)
self.rnn_dims = rnn_dims
self.aux_dims = res_out_dims // 4
self.hop_length = hop_length
self.sample_rate = sample_rate
self.aux_dims = self.args.res_out_dims // 4
if self.use_upsample_net:
if self.args.use_upsample_net:
assert (
np.cumproduct(upsample_factors)[-1] == self.hop_length
np.cumproduct(self.args.upsample_factors)[-1] == config.audio.hop_length
), " [!] upsample scales needs to be equal to hop_length"
self.upsample = UpsampleNetwork(
feat_dims,
upsample_factors,
compute_dims,
num_res_blocks,
res_out_dims,
pad,
use_aux_net,
self.args.feat_dims,
self.args.upsample_factors,
self.args.compute_dims,
self.args.num_res_blocks,
self.args.res_out_dims,
self.args.pad,
self.args.use_aux_net,
)
else:
self.upsample = Upsample(
hop_length,
pad,
num_res_blocks,
feat_dims,
compute_dims,
res_out_dims,
use_aux_net,
config.audio.hop_length,
self.args.pad,
self.args.num_res_blocks,
self.args.feat_dims,
self.args.compute_dims,
self.args.res_out_dims,
self.args.use_aux_net,
)
if self.use_aux_net:
self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
self.fc3 = nn.Linear(fc_dims, self.n_classes)
if self.args.use_aux_net:
self.I = nn.Linear(self.args.feat_dims + self.aux_dims + 1, self.args.rnn_dims)
self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
self.rnn2 = nn.GRU(self.args.rnn_dims + self.aux_dims, self.args.rnn_dims, batch_first=True)
self.fc1 = nn.Linear(self.args.rnn_dims + self.aux_dims, self.args.fc_dims)
self.fc2 = nn.Linear(self.args.fc_dims + self.aux_dims, self.args.fc_dims)
self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
else:
self.I = nn.Linear(feat_dims + 1, rnn_dims)
self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
self.rnn2 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
self.fc1 = nn.Linear(rnn_dims, fc_dims)
self.fc2 = nn.Linear(fc_dims, fc_dims)
self.fc3 = nn.Linear(fc_dims, self.n_classes)
self.I = nn.Linear(self.args.feat_dims + 1, self.args.rnn_dims)
self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
self.rnn2 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
self.fc1 = nn.Linear(self.args.rnn_dims, self.args.fc_dims)
self.fc2 = nn.Linear(self.args.fc_dims, self.args.fc_dims)
self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
def forward(self, x, mels):
bsize = x.size(0)
h1 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
h2 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
h1 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
h2 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
mels, aux = self.upsample(mels)
if self.use_aux_net:
if self.args.use_aux_net:
aux_idx = [self.aux_dims * i for i in range(5)]
a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
@@ -226,7 +290,7 @@ class WaveRNN(nn.Module):
x = (
torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
if self.use_aux_net
if self.args.use_aux_net
else torch.cat([x.unsqueeze(-1), mels], dim=2)
)
x = self.I(x)
@@ -236,15 +300,15 @@ class WaveRNN(nn.Module):
x = x + res
res = x
x = torch.cat([x, a2], dim=2) if self.use_aux_net else x
x = torch.cat([x, a2], dim=2) if self.args.use_aux_net else x
self.rnn2.flatten_parameters()
x, _ = self.rnn2(x, h2)
x = x + res
x = torch.cat([x, a3], dim=2) if self.use_aux_net else x
x = torch.cat([x, a3], dim=2) if self.args.use_aux_net else x
x = F.relu(self.fc1(x))
x = torch.cat([x, a4], dim=2) if self.use_aux_net else x
x = torch.cat([x, a4], dim=2) if self.args.use_aux_net else x
x = F.relu(self.fc2(x))
return self.fc3(x)
@@ -262,9 +326,9 @@ class WaveRNN(nn.Module):
if mels.ndim == 2:
mels = mels.unsqueeze(0)
wave_len = (mels.size(-1) - 1) * self.hop_length
wave_len = (mels.size(-1) - 1) * self.config.audio.hop_length
mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both")
mels = self.pad_tensor(mels.transpose(1, 2), pad=self.args.pad, side="both")
mels, aux = self.upsample(mels.transpose(1, 2))
if batched:
@@ -274,11 +338,11 @@ class WaveRNN(nn.Module):
b_size, seq_len, _ = mels.size()
h1 = torch.zeros(b_size, self.rnn_dims).type_as(mels)
h2 = torch.zeros(b_size, self.rnn_dims).type_as(mels)
h1 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
h2 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
x = torch.zeros(b_size, 1).type_as(mels)
if self.use_aux_net:
if self.args.use_aux_net:
d = self.aux_dims
aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
@@ -286,35 +350,35 @@ class WaveRNN(nn.Module):
m_t = mels[:, i, :]
if self.use_aux_net:
if self.args.use_aux_net:
a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
x = torch.cat([x, m_t, a1_t], dim=1) if self.use_aux_net else torch.cat([x, m_t], dim=1)
x = torch.cat([x, m_t, a1_t], dim=1) if self.args.use_aux_net else torch.cat([x, m_t], dim=1)
x = self.I(x)
h1 = rnn1(x, h1)
x = x + h1
inp = torch.cat([x, a2_t], dim=1) if self.use_aux_net else x
inp = torch.cat([x, a2_t], dim=1) if self.args.use_aux_net else x
h2 = rnn2(inp, h2)
x = x + h2
x = torch.cat([x, a3_t], dim=1) if self.use_aux_net else x
x = torch.cat([x, a3_t], dim=1) if self.args.use_aux_net else x
x = F.relu(self.fc1(x))
x = torch.cat([x, a4_t], dim=1) if self.use_aux_net else x
x = torch.cat([x, a4_t], dim=1) if self.args.use_aux_net else x
x = F.relu(self.fc2(x))
logits = self.fc3(x)
if self.mode == "mold":
if self.args.mode == "mold":
sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
output.append(sample.view(-1))
x = sample.transpose(0, 1).type_as(mels)
elif self.mode == "gauss":
elif self.args.mode == "gauss":
sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
output.append(sample.view(-1))
x = sample.transpose(0, 1).type_as(mels)
elif isinstance(self.mode, int):
elif isinstance(self.args.mode, int):
posterior = F.softmax(logits, dim=1)
distrib = torch.distributions.Categorical(posterior)
@@ -322,7 +386,7 @@ class WaveRNN(nn.Module):
output.append(sample)
x = sample.unsqueeze(-1)
else:
raise RuntimeError("Unknown model mode value - ", self.mode)
raise RuntimeError("Unknown model mode value - ", self.args.mode)
if i % 100 == 0:
self.gen_display(i, seq_len, b_size, start)
@@ -337,22 +401,22 @@ class WaveRNN(nn.Module):
else:
output = output[0]
if self.mulaw and isinstance(self.mode, int):
output = ap.mulaw_decode(output, self.mode)
if self.args.mulaw and isinstance(self.args.mode, int):
output = AudioProcessor.mulaw_decode(output, self.args.mode)
# Fade-out at the end to avoid signal cutting out suddenly
fade_out = np.linspace(1, 0, 20 * self.hop_length)
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
output = output[:wave_len]
if wave_len > len(fade_out):
output[-20 * self.hop_length :] *= fade_out
output[-20 * self.config.audio.hop_length :] *= fade_out
self.train()
return output
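# For reference, the mu-law expansion assumed behind `mulaw_decode` above,
# with mu = 2 ** bits - 1 (a sketch of the standard formula, not the library
# implementation):
#
#   y = 2 * x / mu - 1                                   # classes -> [-1, 1]
#   out = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)  # expand companding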
def gen_display(self, i, seq_len, b_size, start):
gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
realtime_ratio = gen_rate * 1000 / self.sample_rate
realtime_ratio = gen_rate * 1000 / self.config.audio.sample_rate
stream(
"%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ",
(i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
@@ -486,3 +550,83 @@ class WaveRNN(nn.Module):
if eval:
self.eval()
assert not self.training
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
mels = batch["input"]
waveform = batch["waveform"]
waveform_coarse = batch["waveform_coarse"]
y_hat = self.forward(waveform, mels)
if isinstance(self.args.mode, int):
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
else:
waveform_coarse = waveform_coarse.float()
waveform_coarse = waveform_coarse.unsqueeze(-1)
# compute losses
loss_dict = criterion(y_hat, waveform_coarse)
return {"model_output": y_hat}, loss_dict
def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion)
@torch.no_grad()
def test_run(
self, ap: AudioProcessor, samples: List[Dict], output: Dict # pylint: disable=unused-argument
) -> Tuple[Dict, Dict]:
figures = {}
audios = {}
for idx, sample in enumerate(samples):
x = sample["input"]
y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
x_hat = ap.melspectrogram(y_hat)
figures.update(
{
f"test_{idx}/ground_truth": plot_spectrogram(x.T),
f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
}
)
audios.update({f"test_{idx}/audio": y_hat})
return figures, audios
@staticmethod
def format_batch(batch: Dict) -> Dict:
waveform = batch[0]
mels = batch[1]
waveform_coarse = batch[2]
return {"input": mels, "waveform": waveform, "waveform_coarse": waveform_coarse}
def get_data_loader( # pylint: disable=no-self-use
self,
config: Coqpit,
ap: AudioProcessor,
is_eval: bool,
data_items: List,
verbose: bool,
num_gpus: int,
):
dataset = WaveRNNDataset(
ap=ap,
items=data_items,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad=config.model_params.pad,
mode=config.model_params.mode,
mulaw=config.model_params.mulaw,
is_training=not is_eval,
verbose=verbose,
)
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=1 if is_eval else config.batch_size,
shuffle=num_gpus == 0,
collate_fn=dataset.collate,
sampler=sampler,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=True,
)
return loader
def get_criterion(self):
# return the training loss module
return WaveRNNLoss(self.args.mode)
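# Usage sketch for the class above; assumes `WavernnConfig` defaults:
#
#   from TTS.vocoder.configs import WavernnConfig
#   config = WavernnConfig()
#   model = Wavernn(config)
#   criterion = model.get_criterion()    # WaveRNNLoss for the configured mode
#   wave = model.inference(mel, config.batched, config.target_samples, config.overlap_samples)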