mirror of https://github.com/coqui-ai/TTS.git
Update vocoder models
This commit is contained in:
parent 51005cdab4
commit e949e7ad58
@@ -0,0 +1,246 @@
from inspect import signature
from typing import Dict, List, Tuple

import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.utils.audio import AudioProcessor
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.models import setup_discriminator, setup_generator
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results


class GAN(BaseVocoder):
    def __init__(self, config: Coqpit):
"""Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
|
||||
It also helps mixing and matching different generator and disciminator networks easily.
|
||||

        Args:
            config (Coqpit): Model configuration.

        Examples:
            Initializing the GAN model with HifiGAN generator and discriminator.
            >>> from TTS.vocoder.configs import HifiganConfig
            >>> config = HifiganConfig()
            >>> model = GAN(config)
        """
        super().__init__()
        self.config = config
        self.model_g = setup_generator(config)
        self.model_d = setup_discriminator(config)
        self.train_disc = False  # if False, train only the generator.
        self.y_hat_g = None  # the last generator prediction to be passed onto the discriminator

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model_g.forward(x)

    def inference(self, x: torch.Tensor) -> torch.Tensor:
        return self.model_g.inference(x)

    def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
        outputs = None
        loss_dict = None

        x = batch["input"]
        y = batch["waveform"]

        if optimizer_idx not in [0, 1]:
            raise ValueError(" [!] Unexpected `optimizer_idx`.")

        if optimizer_idx == 0:
            # GENERATOR
            # generator pass
            y_hat = self.model_g(x)[:, :, : y.size(2)]
            self.y_hat_g = y_hat  # save for discriminator
            y_hat_sub = None
            y_sub = None

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat_sub = y_hat
                y_hat = self.model_g.pqmf_synthesis(y_hat)
                self.y_hat_g = y_hat  # save for discriminator
                y_sub = self.model_g.pqmf_analysis(y)

            scores_fake, feats_fake, feats_real = None, None, None
            if self.train_disc:

                if len(signature(self.model_d.forward).parameters) == 2:
                    D_out_fake = self.model_d(y_hat, x)
                else:
                    D_out_fake = self.model_d(y_hat)
                D_out_real = None

                if self.config.use_feat_match_loss:
                    with torch.no_grad():
                        D_out_real = self.model_d(y)

                # format D outputs
                if isinstance(D_out_fake, tuple):
                    scores_fake, feats_fake = D_out_fake
                    if D_out_real is None:
                        feats_real = None
                    else:
                        _, feats_real = D_out_real
                else:
                    scores_fake = D_out_fake
                    feats_fake, feats_real = None, None

            # compute losses
            loss_dict = criterion[optimizer_idx](y_hat, y, scores_fake, feats_fake, feats_real, y_hat_sub, y_sub)
            outputs = {"model_outputs": y_hat}

        if optimizer_idx == 1:
            # DISCRIMINATOR
            if self.train_disc:
                # use different samples for G and D trainings
                if self.config.diff_samples_for_G_and_D:
                    x_d = batch["input_disc"]
                    y_d = batch["waveform_disc"]
                    # use a different sample than generator
                    with torch.no_grad():
                        y_hat = self.model_g(x_d)

                    # PQMF formatting
                    if y_hat.shape[1] > 1:
                        y_hat = self.model_g.pqmf_synthesis(y_hat)
                else:
                    # use the same samples as generator
                    x_d = x.clone()
                    y_d = y.clone()
                    y_hat = self.y_hat_g

                # run D with or without cond. features
                if len(signature(self.model_d.forward).parameters) == 2:
                    D_out_fake = self.model_d(y_hat.detach().clone(), x_d)
                    D_out_real = self.model_d(y_d, x_d)
                else:
                    D_out_fake = self.model_d(y_hat.detach())
                    D_out_real = self.model_d(y_d)

                # format D outputs
                if isinstance(D_out_fake, tuple):
                    # self.model_d returns scores and features
                    scores_fake, feats_fake = D_out_fake
                    if D_out_real is None:
                        scores_real, feats_real = None, None
                    else:
                        scores_real, feats_real = D_out_real
                else:
                    # model D returns only scores
                    scores_fake = D_out_fake
                    scores_real = D_out_real

                # compute losses
                loss_dict = criterion[optimizer_idx](scores_fake, scores_real)
                outputs = {"model_outputs": y_hat}

        return outputs, loss_dict

    def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        y_hat = outputs[0]["model_outputs"]
        y = batch["waveform"]
        figures = plot_results(y_hat, y, ap, "train")
        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        audios = {"train/audio": sample_voice}
        return figures, audios

    @torch.no_grad()
    def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
        return self.train_step(batch, criterion, optimizer_idx)

    def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        return self.train_log(ap, batch, outputs)

    def load_checkpoint(
        self,
        config: Coqpit,
        checkpoint_path: str,
        eval: bool = False,  # pylint: disable=unused-argument, redefined-builtin
    ) -> None:
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        # band-aid for older than v0.0.15 GAN models
        if "model_disc" in state:
            self.model_g.load_checkpoint(config, checkpoint_path, eval)
        else:
            self.load_state_dict(state["model"])
            if eval:
                self.model_d = None
                if hasattr(self.model_g, "remove_weight_norm"):
                    self.model_g.remove_weight_norm()

    def on_train_step_start(self, trainer) -> None:
        self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator

    def get_optimizer(self):
        optimizer1 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g
        )
        optimizer2 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d
        )
        return [optimizer1, optimizer2]

    def get_lr(self):
        return [self.config.lr_gen, self.config.lr_disc]

    def get_scheduler(self, optimizer):
        scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
        scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
        return [scheduler1, scheduler2]

    @staticmethod
    def format_batch(batch):
        if isinstance(batch[0], list):
            x_G, y_G = batch[0]
            x_D, y_D = batch[1]
            return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D}
        x, y = batch
        return {"input": x, "waveform": y}

    def get_data_loader(  # pylint: disable=no-self-use
        self,
        config: Coqpit,
        ap: AudioProcessor,
        is_eval: bool,
        data_items: List,
        verbose: bool,
        num_gpus: int,
    ):
        dataset = GANDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
            is_training=not is_eval,
            return_segments=not is_eval,
            use_noise_augment=config.use_noise_augment,
            use_cache=config.use_cache,
            verbose=verbose,
        )
        dataset.shuffle_mapping()
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=1 if is_eval else config.batch_size,
            shuffle=num_gpus == 0,
            drop_last=False,
            sampler=sampler,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=False,
        )
        return loader

    def get_criterion(self):
        """Return criterions for the optimizers"""
        return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)]
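The `GAN` class above folds both players into one module: `get_optimizer()`, `get_lr()`, `get_scheduler()`, and `get_criterion()` each return a two-element list (generator first, discriminator second), and `train_step()` switches between the two branches on `optimizer_idx`. As a rough illustration of that contract, here is a minimal, hypothetical driver loop; it is a sketch rather than the project's actual Trainer, it assumes the class is importable as `TTS.vocoder.models.gan.GAN`, and it leaves out schedulers, logging, and distributed training.

# Hypothetical driver loop for the GAN wrapper; a sketch, not the actual TTS Trainer.
from types import SimpleNamespace

from TTS.vocoder.configs import HifiganConfig
from TTS.vocoder.models.gan import GAN  # assumed module path for the class above

config = HifiganConfig()
model = GAN(config)
criteria = model.get_criterion()    # [GeneratorLoss, DiscriminatorLoss]
optimizers = model.get_optimizer()  # [generator optimizer, discriminator optimizer]


def run_step(batch, global_step):
    """Run one generator update, then one discriminator update."""
    # enables the discriminator branch once `steps_to_start_discriminator` is reached
    model.on_train_step_start(SimpleNamespace(total_steps_done=global_step))
    for idx, optimizer in enumerate(optimizers):
        _, loss_dict = model.train_step(batch, criteria, idx)
        if loss_dict is None:  # discriminator branch is skipped while it is still frozen
            continue
        optimizer.zero_grad()
        loss_dict["loss"].backward()
        optimizer.step()

The same batch feeds both updates; the discriminator pass reuses the cached, detached `y_hat_g` unless `diff_samples_for_G_and_D` is enabled, in which case `format_batch` also packs a second sample pair.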
@@ -1,65 +1,105 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn.utils import weight_norm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from ..layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
from TTS.model import BaseModel
from TTS.utils.audio import AudioProcessor
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
from TTS.vocoder.utils.generic_utils import plot_results


class Wavegrad(nn.Module):
@dataclass
class WavegradArgs(Coqpit):
    in_channels: int = 80
    out_channels: int = 1
    use_weight_norm: bool = False
    y_conv_channels: int = 32
    x_conv_channels: int = 768
    dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512])
    ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128])
    upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2])
    upsample_dilations: List[List[int]] = field(
        default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]]
    )


class Wavegrad(BaseModel):
    """🐸 🌊 WaveGrad 🌊 model.
    Paper - https://arxiv.org/abs/2009.00713

    Examples:
        Initializing the model.

        >>> from TTS.vocoder.configs import WavegradConfig
        >>> config = WavegradConfig()
        >>> model = Wavegrad(config)

    Paper Abstract:
        This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the
        data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts
        from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned
        on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting
        the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in
        terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations.
        Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive
        baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations.
        Audio samples are available at this https URL.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        use_weight_norm=False,
        y_conv_channels=32,
        x_conv_channels=768,
        dblock_out_channels=[128, 128, 256, 512],
        ublock_out_channels=[512, 512, 256, 128, 128],
        upsample_factors=[5, 5, 3, 2, 2],
        upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]],
    ):
    def __init__(self, config: Coqpit):
        super().__init__()

        self.use_weight_norm = use_weight_norm
        self.hop_len = np.prod(upsample_factors)
        self.config = config
        self.use_weight_norm = config.model_params.use_weight_norm
        self.hop_len = np.prod(config.model_params.upsample_factors)
        self.noise_level = None
        self.num_steps = None
        self.beta = None
        self.alpha = None
        self.alpha_hat = None
        self.noise_level = None
        self.c1 = None
        self.c2 = None
        self.sigma = None

        # dblocks
        self.y_conv = Conv1d(1, y_conv_channels, 5, padding=2)
        self.y_conv = Conv1d(1, config.model_params.y_conv_channels, 5, padding=2)
        self.dblocks = nn.ModuleList([])
        ic = y_conv_channels
        for oc, df in zip(dblock_out_channels, reversed(upsample_factors)):
        ic = config.model_params.y_conv_channels
        for oc, df in zip(config.model_params.dblock_out_channels, reversed(config.model_params.upsample_factors)):
            self.dblocks.append(DBlock(ic, oc, df))
            ic = oc

        # film
        self.film = nn.ModuleList([])
        ic = y_conv_channels
        for oc in reversed(ublock_out_channels):
        ic = config.model_params.y_conv_channels
        for oc in reversed(config.model_params.ublock_out_channels):
            self.film.append(FiLM(ic, oc))
            ic = oc

        # ublocks
        # ublocksn
        self.ublocks = nn.ModuleList([])
        ic = x_conv_channels
        for oc, uf, ud in zip(ublock_out_channels, upsample_factors, upsample_dilations):
        ic = config.model_params.x_conv_channels
        for oc, uf, ud in zip(
            config.model_params.ublock_out_channels,
            config.model_params.upsample_factors,
            config.model_params.upsample_dilations,
        ):
            self.ublocks.append(UBlock(ic, oc, uf, ud))
            ic = oc

        self.x_conv = Conv1d(in_channels, x_conv_channels, 3, padding=1)
        self.out_conv = Conv1d(oc, out_channels, 3, padding=1)
        self.x_conv = Conv1d(config.model_params.in_channels, config.model_params.x_conv_channels, 3, padding=1)
        self.out_conv = Conv1d(oc, config.model_params.out_channels, 3, padding=1)

        if use_weight_norm:
        if config.model_params.use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x, spectrogram, noise_scale):
@@ -180,7 +220,7 @@ class Wavegrad(nn.Module):
        if eval:
            self.eval()
            assert not self.training
            if self.use_weight_norm:
            if self.config.model_params.use_weight_norm:
                self.remove_weight_norm()
            betas = np.linspace(
                config["test_noise_schedule"]["min_val"],
@@ -195,3 +235,93 @@ class Wavegrad(nn.Module):
            config["train_noise_schedule"]["num_steps"],
        )
        self.compute_noise_level(betas)

    def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
        # format data
        x = batch["input"]
        y = batch["waveform"]

        # set noise scale
        noise, x_noisy, noise_scale = self.compute_y_n(y)

        # forward pass
        noise_hat = self.forward(x_noisy, x, noise_scale)

        # compute losses
        loss = criterion(noise, noise_hat)
        return {"model_output": noise_hat}, {"loss": loss}

    def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        return None, None

    @torch.no_grad()
    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        return self.train_step(batch, criterion)

    def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        return None, None

    def test_run(self, ap: AudioProcessor, samples: List[Dict], outputs: Dict):  # pylint: disable=unused-argument
        # setup noise schedule and inference
        noise_schedule = self.config["test_noise_schedule"]
        betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
        self.compute_noise_level(betas)
        for sample in samples:
            x = sample["input"]
            y = sample["waveform"]
            # compute voice
            y_pred = self.inference(x)
            # compute spectrograms
            figures = plot_results(y_pred, y, ap, "test")
            # Sample audio
            sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy()
        return figures, {"test/audio": sample_voice}

    def get_optimizer(self):
        return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)

    def get_scheduler(self, optimizer):
        return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)

    def get_criterion(self):
        return torch.nn.L1Loss()

    @staticmethod
    def format_batch(batch: Dict) -> Dict:
        # return a whole audio segment
        m, y = batch[0], batch[1]
        y = y.unsqueeze(1)
        return {"input": m, "waveform": y}

    def get_data_loader(
        self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int
    ):
        dataset = WaveGradDataset(
            ap=ap,
            items=data_items,
            seq_len=self.config.seq_len,
            hop_len=ap.hop_length,
            pad_short=self.config.pad_short,
            conv_pad=self.config.conv_pad,
            is_training=not is_eval,
            return_segments=True,
            use_noise_augment=False,
            use_cache=config.use_cache,
            verbose=verbose,
        )
        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=num_gpus <= 1,
            drop_last=False,
            sampler=sampler,
            num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers,
            pin_memory=False,
        )
        return loader

    def on_epoch_start(self, trainer):  # pylint: disable=unused-argument
        noise_schedule = self.config["train_noise_schedule"]
        betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
        self.compute_noise_level(betas)
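WaveGrad's key knob is the noise schedule: `on_epoch_start` installs the training schedule, while `load_checkpoint` and `test_run` rebuild an inference schedule from `test_noise_schedule`, so fewer refinement steps mean faster but rougher synthesis. The snippet below is a hedged sketch of driving the refactored model at inference time; the module path, the schedule keys, the `inference()` call, and the dummy mel-spectrogram shape are assumptions rather than facts taken from this diff.

# Hypothetical inference sketch for the refactored Wavegrad model.
import numpy as np
import torch

from TTS.vocoder.configs import WavegradConfig
from TTS.vocoder.models.wavegrad import Wavegrad  # assumed module path

config = WavegradConfig()
model = Wavegrad(config)
model.eval()

# Fewer refinement steps -> faster inference; more steps -> higher fidelity.
schedule = config.test_noise_schedule  # assumed keys: min_val, max_val, num_steps
betas = np.linspace(schedule["min_val"], schedule["max_val"], schedule["num_steps"])
model.compute_noise_level(betas)

with torch.no_grad():
    mel = torch.randn(1, config.model_params.in_channels, 100)  # stand-in conditioning spectrogram
    waveform = model.inference(mel)  # iterative refinement starting from Gaussian noise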
@@ -1,13 +1,21 @@
import sys
import time
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from coqpit import Coqpit
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# fix this
from TTS.utils.audio import AudioProcessor as ap
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian

@@ -135,89 +143,145 @@ class Upsample(nn.Module):
        return m.transpose(1, 2), aux


class WaveRNN(nn.Module):
    def __init__(
        self,
        rnn_dims,
        fc_dims,
        mode,
        mulaw,
        pad,
        use_aux_net,
        use_upsample_net,
        upsample_factors,
        feat_dims,
        compute_dims,
        res_out_dims,
        num_res_blocks,
        hop_length,
        sample_rate,
    ):
@dataclass
class WavernnArgs(Coqpit):
"""🐸 WaveRNN model arguments.
|
||||
|
||||
rnn_dims (int):
|
||||
Number of hidden channels in RNN layers. Defaults to 512.
|
||||
fc_dims (int):
|
||||
Number of hidden channels in fully-conntected layers. Defaults to 512.
|
||||
compute_dims (int):
|
||||
Number of hidden channels in the feature ResNet. Defaults to 128.
|
||||
res_out_dim (int):
|
||||
Number of hidden channels in the feature ResNet output. Defaults to 128.
|
||||
num_res_blocks (int):
|
||||
Number of residual blocks in the ResNet. Defaults to 10.
|
||||
use_aux_net (bool):
|
||||
enable/disable the feature ResNet. Defaults to True.
|
||||
use_upsample_net (bool):
|
||||
enable/ disable the upsampling networl. If False, basic upsampling is used. Defaults to True.
|
||||
upsample_factors (list):
|
||||
Upsampling factors. The multiply of the values must match the `hop_length`. Defaults to ```[4, 8, 8]```.
|
||||
mode (str):
|
||||
Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single
|
||||
Gaussian Distribution and `bits` for quantized bits as the model's output.
|
||||
mulaw (bool):
|
||||
enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults
|
||||
to `True`.
|
||||
pad (int):
|
||||
Padding applied to the input feature frames against the convolution layers of the feature network.
|
||||
Defaults to 2.
|
||||
"""
|
||||

    rnn_dims: int = 512
    fc_dims: int = 512
    compute_dims: int = 128
    res_out_dims: int = 128
    num_res_blocks: int = 10
    use_aux_net: bool = True
    use_upsample_net: bool = True
    upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8])
    mode: str = "mold"  # mold [string], gauss [string], bits [int]
    mulaw: bool = True  # apply mulaw if mode is bits
    pad: int = 2
    feat_dims: int = 80


class Wavernn(BaseVocoder):
    def __init__(self, config: Coqpit):
        """🐸 WaveRNN model.
        Original paper - https://arxiv.org/abs/1802.08435
        Official implementation - https://github.com/fatchord/WaveRNN

        Args:
            config (Coqpit): Model configuration.

        Raises:
            RuntimeError: If the model `mode` is not one of the supported values.

        Examples:
            >>> from TTS.vocoder.configs import WavernnConfig
            >>> config = WavernnConfig()
            >>> model = Wavernn(config)

        Paper Abstract:
            Sequential models achieve state-of-the-art results in audio, visual and textual domains with respect to
            both estimating the data distribution and generating high-quality samples. Efficient sampling for this
            class of models has however remained an elusive problem. With a focus on text-to-speech synthesis, we
            describe a set of general techniques for reducing sampling time while maintaining high output quality.
            We first describe a single-layer recurrent neural network, the WaveRNN, with a dual softmax layer that
            matches the quality of the state-of-the-art WaveNet model. The compact form of the network makes it
            possible to generate 24kHz 16-bit audio 4x faster than real time on a GPU. Second, we apply a weight
            pruning technique to reduce the number of weights in the WaveRNN. We find that, for a constant number of
            parameters, large sparse networks perform better than small dense networks and this relationship holds for
            sparsity levels beyond 96%. The small number of weights in a Sparse WaveRNN makes it possible to sample
            high-fidelity audio on a mobile CPU in real time. Finally, we propose a new generation scheme based on
            subscaling that folds a long sequence into a batch of shorter sequences and allows one to generate multiple
            samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an
            orthogonal method for increasing sampling efficiency.
        """
        super().__init__()
        self.mode = mode
        self.mulaw = mulaw
        self.pad = pad
        self.use_upsample_net = use_upsample_net
        self.use_aux_net = use_aux_net
        if isinstance(self.mode, int):
            self.n_classes = 2 ** self.mode
        elif self.mode == "mold":

        self.args = config.model_params
        self.config = config

        if isinstance(self.args.mode, int):
            self.n_classes = 2 ** self.args.mode
        elif self.args.mode == "mold":
            self.n_classes = 3 * 10
        elif self.mode == "gauss":
        elif self.args.mode == "gauss":
            self.n_classes = 2
        else:
            raise RuntimeError("Unknown model mode value - ", self.mode)
            raise RuntimeError("Unknown model mode value - ", self.args.mode)

        self.rnn_dims = rnn_dims
        self.aux_dims = res_out_dims // 4
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.aux_dims = self.args.res_out_dims // 4

        if self.use_upsample_net:
        if self.args.use_upsample_net:
            assert (
                np.cumproduct(upsample_factors)[-1] == self.hop_length
                np.cumproduct(self.args.upsample_factors)[-1] == config.audio.hop_length
            ), " [!] upsample scales needs to be equal to hop_length"
            self.upsample = UpsampleNetwork(
                feat_dims,
                upsample_factors,
                compute_dims,
                num_res_blocks,
                res_out_dims,
                pad,
                use_aux_net,
                self.args.feat_dims,
                self.args.upsample_factors,
                self.args.compute_dims,
                self.args.num_res_blocks,
                self.args.res_out_dims,
                self.args.pad,
                self.args.use_aux_net,
            )
        else:
            self.upsample = Upsample(
                hop_length,
                pad,
                num_res_blocks,
                feat_dims,
                compute_dims,
                res_out_dims,
                use_aux_net,
                config.audio.hop_length,
                self.args.pad,
                self.args.num_res_blocks,
                self.args.feat_dims,
                self.args.compute_dims,
                self.args.res_out_dims,
                self.args.use_aux_net,
            )
        if self.use_aux_net:
            self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
            self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
            self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
            self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
            self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
            self.fc3 = nn.Linear(fc_dims, self.n_classes)
        if self.args.use_aux_net:
            self.I = nn.Linear(self.args.feat_dims + self.aux_dims + 1, self.args.rnn_dims)
            self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
            self.rnn2 = nn.GRU(self.args.rnn_dims + self.aux_dims, self.args.rnn_dims, batch_first=True)
            self.fc1 = nn.Linear(self.args.rnn_dims + self.aux_dims, self.args.fc_dims)
            self.fc2 = nn.Linear(self.args.fc_dims + self.aux_dims, self.args.fc_dims)
            self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
        else:
            self.I = nn.Linear(feat_dims + 1, rnn_dims)
            self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
            self.rnn2 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
            self.fc1 = nn.Linear(rnn_dims, fc_dims)
            self.fc2 = nn.Linear(fc_dims, fc_dims)
            self.fc3 = nn.Linear(fc_dims, self.n_classes)
            self.I = nn.Linear(self.args.feat_dims + 1, self.args.rnn_dims)
            self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
            self.rnn2 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
            self.fc1 = nn.Linear(self.args.rnn_dims, self.args.fc_dims)
            self.fc2 = nn.Linear(self.args.fc_dims, self.args.fc_dims)
            self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)

    def forward(self, x, mels):
        bsize = x.size(0)
        h1 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
        h2 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
        h1 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
        h2 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
        mels, aux = self.upsample(mels)

        if self.use_aux_net:
        if self.args.use_aux_net:
            aux_idx = [self.aux_dims * i for i in range(5)]
            a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
            a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
@@ -226,7 +290,7 @@ class WaveRNN(nn.Module):

        x = (
            torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
            if self.use_aux_net
            if self.args.use_aux_net
            else torch.cat([x.unsqueeze(-1), mels], dim=2)
        )
        x = self.I(x)
@@ -236,15 +300,15 @@ class WaveRNN(nn.Module):

        x = x + res
        res = x
        x = torch.cat([x, a2], dim=2) if self.use_aux_net else x
        x = torch.cat([x, a2], dim=2) if self.args.use_aux_net else x
        self.rnn2.flatten_parameters()
        x, _ = self.rnn2(x, h2)

        x = x + res
        x = torch.cat([x, a3], dim=2) if self.use_aux_net else x
        x = torch.cat([x, a3], dim=2) if self.args.use_aux_net else x
        x = F.relu(self.fc1(x))

        x = torch.cat([x, a4], dim=2) if self.use_aux_net else x
        x = torch.cat([x, a4], dim=2) if self.args.use_aux_net else x
        x = F.relu(self.fc2(x))
        return self.fc3(x)

@@ -262,9 +326,9 @@ class WaveRNN(nn.Module):

        if mels.ndim == 2:
            mels = mels.unsqueeze(0)
        wave_len = (mels.size(-1) - 1) * self.hop_length
        wave_len = (mels.size(-1) - 1) * self.config.audio.hop_length

        mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both")
        mels = self.pad_tensor(mels.transpose(1, 2), pad=self.args.pad, side="both")
        mels, aux = self.upsample(mels.transpose(1, 2))

        if batched:
@@ -274,11 +338,11 @@ class WaveRNN(nn.Module):

            b_size, seq_len, _ = mels.size()

        h1 = torch.zeros(b_size, self.rnn_dims).type_as(mels)
        h2 = torch.zeros(b_size, self.rnn_dims).type_as(mels)
        h1 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
        h2 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
        x = torch.zeros(b_size, 1).type_as(mels)

        if self.use_aux_net:
        if self.args.use_aux_net:
            d = self.aux_dims
            aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]

@@ -286,35 +350,35 @@ class WaveRNN(nn.Module):

            m_t = mels[:, i, :]

            if self.use_aux_net:
            if self.args.use_aux_net:
                a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)

            x = torch.cat([x, m_t, a1_t], dim=1) if self.use_aux_net else torch.cat([x, m_t], dim=1)
            x = torch.cat([x, m_t, a1_t], dim=1) if self.args.use_aux_net else torch.cat([x, m_t], dim=1)
            x = self.I(x)
            h1 = rnn1(x, h1)

            x = x + h1
            inp = torch.cat([x, a2_t], dim=1) if self.use_aux_net else x
            inp = torch.cat([x, a2_t], dim=1) if self.args.use_aux_net else x
            h2 = rnn2(inp, h2)

            x = x + h2
            x = torch.cat([x, a3_t], dim=1) if self.use_aux_net else x
            x = torch.cat([x, a3_t], dim=1) if self.args.use_aux_net else x
            x = F.relu(self.fc1(x))

            x = torch.cat([x, a4_t], dim=1) if self.use_aux_net else x
            x = torch.cat([x, a4_t], dim=1) if self.args.use_aux_net else x
            x = F.relu(self.fc2(x))

            logits = self.fc3(x)

            if self.mode == "mold":
            if self.args.mode == "mold":
                sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
                output.append(sample.view(-1))
                x = sample.transpose(0, 1).type_as(mels)
            elif self.mode == "gauss":
            elif self.args.mode == "gauss":
                sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
                output.append(sample.view(-1))
                x = sample.transpose(0, 1).type_as(mels)
            elif isinstance(self.mode, int):
            elif isinstance(self.args.mode, int):
                posterior = F.softmax(logits, dim=1)
                distrib = torch.distributions.Categorical(posterior)

@@ -322,7 +386,7 @@ class WaveRNN(nn.Module):
                output.append(sample)
                x = sample.unsqueeze(-1)
            else:
                raise RuntimeError("Unknown model mode value - ", self.mode)
                raise RuntimeError("Unknown model mode value - ", self.args.mode)

            if i % 100 == 0:
                self.gen_display(i, seq_len, b_size, start)
@@ -337,22 +401,22 @@ class WaveRNN(nn.Module):
        else:
            output = output[0]

        if self.mulaw and isinstance(self.mode, int):
            output = ap.mulaw_decode(output, self.mode)
        if self.args.mulaw and isinstance(self.args.mode, int):
            output = AudioProcessor.mulaw_decode(output, self.args.mode)

        # Fade-out at the end to avoid signal cutting out suddenly
        fade_out = np.linspace(1, 0, 20 * self.hop_length)
        fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
        output = output[:wave_len]

        if wave_len > len(fade_out):
            output[-20 * self.hop_length :] *= fade_out
            output[-20 * self.config.audio.hop_length :] *= fade_out

        self.train()
        return output

    def gen_display(self, i, seq_len, b_size, start):
        gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
        realtime_ratio = gen_rate * 1000 / self.sample_rate
        realtime_ratio = gen_rate * 1000 / self.config.audio.sample_rate
        stream(
            "%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ",
            (i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
@@ -486,3 +550,83 @@ class WaveRNN(nn.Module):
        if eval:
            self.eval()
            assert not self.training

    def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
        mels = batch["input"]
        waveform = batch["waveform"]
        waveform_coarse = batch["waveform_coarse"]

        y_hat = self.forward(waveform, mels)
        if isinstance(self.args.mode, int):
            y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
        else:
            waveform_coarse = waveform_coarse.float()
        waveform_coarse = waveform_coarse.unsqueeze(-1)
        # compute losses
        loss_dict = criterion(y_hat, waveform_coarse)
        return {"model_output": y_hat}, loss_dict

    def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
        return self.train_step(batch, criterion)

    @torch.no_grad()
    def test_run(
        self, ap: AudioProcessor, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
    ) -> Tuple[Dict, Dict]:
        figures = {}
        audios = {}
        for idx, sample in enumerate(samples):
            x = sample["input"]
            y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
            x_hat = ap.melspectrogram(y_hat)
            figures.update(
                {
                    f"test_{idx}/ground_truth": plot_spectrogram(x.T),
                    f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
                }
            )
audios.update({f"test_{idx}/audio", y_hat})
|
||||
        return figures, audios

    @staticmethod
    def format_batch(batch: Dict) -> Dict:
        waveform = batch[0]
        mels = batch[1]
        waveform_coarse = batch[2]
        return {"input": mels, "waveform": waveform, "waveform_coarse": waveform_coarse}

    def get_data_loader(  # pylint: disable=no-self-use
        self,
        config: Coqpit,
        ap: AudioProcessor,
        is_eval: bool,
        data_items: List,
        verbose: bool,
        num_gpus: int,
    ):
        dataset = WaveRNNDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad=config.model_params.pad,
            mode=config.model_params.mode,
            mulaw=config.model_params.mulaw,
            is_training=not is_eval,
            verbose=verbose,
        )
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=1 if is_eval else config.batch_size,
            shuffle=num_gpus == 0,
            collate_fn=dataset.collate,
            sampler=sampler,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=True,
        )
        return loader

    def get_criterion(self):
        # define train functions
        return WaveRNNLoss(self.args.mode)
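For WaveRNN, the topology now comes from `WavernnArgs` through `config.model_params`, and `mode` decides the output head: an integer N gives 2**N softmax classes (optionally mu-law encoded), "mold" gives the 30 mixture-of-logistics parameters, and "gauss" gives a mean/scale pair. A hedged usage sketch mirroring what `test_run` does is shown below; the module path and the stand-in mel input are assumptions, and the inference call simply copies the positional arguments used in `test_run`.

# Hypothetical usage sketch for the refactored Wavernn model.
import torch

from TTS.vocoder.configs import WavernnConfig
from TTS.vocoder.models.wavernn import Wavernn  # assumed module path

config = WavernnConfig()
model = Wavernn(config)
model.eval()

# Stand-in mel-spectrogram; shape assumed to be (feat_dims, frames).
mel = torch.randn(config.model_params.feat_dims, 50)

with torch.no_grad():
    # Same positional arguments as `test_run` above: batched folding, target and overlap samples.
    waveform = model.inference(mel, config.batched, config.target_samples, config.overlap_samples)

Note that `inference` switches the module back to train mode before returning (`self.train()`), so keep the call inside `torch.no_grad()` and call `model.eval()` again if you continue generating.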