Fix WaveGrad `test_run`

This commit is contained in:
Eren Gölge 2021-07-16 13:02:25 +02:00
parent 25832eb97b
commit 58cc414477
3 changed files with 26 additions and 10 deletions

View File

@ -765,8 +765,6 @@ class Trainer:
"""Run test and log the results. Test run must be defined by the model. """Run test and log the results. Test run must be defined by the model.
Model must return figures and audios to be logged by the Tensorboard.""" Model must return figures and audios to be logged by the Tensorboard."""
if hasattr(self.model, "test_run"): if hasattr(self.model, "test_run"):
if isinstance(self.model, Wavegrad):
return None # TODO: Fix inference on WaveGrad
if hasattr(self.eval_loader.dataset, "load_test_samples"): if hasattr(self.eval_loader.dataset, "load_test_samples"):
samples = self.eval_loader.dataset.load_test_samples(1) samples = self.eval_loader.dataset.load_test_samples(1)
figures, audios = self.model.test_run(self.ap, samples, None) figures, audios = self.model.test_run(self.ap, samples, None)

View File

@ -2,6 +2,7 @@ import glob
import os import os
import random import random
from multiprocessing import Manager from multiprocessing import Manager
from typing import List, Tuple
import numpy as np import numpy as np
import torch import torch
@ -67,7 +68,19 @@ class WaveGradDataset(Dataset):
item = self.load_item(idx) item = self.load_item(idx)
return item return item
def load_test_samples(self, num_samples): def load_test_samples(self, num_samples: int) -> List[Tuple]:
"""Return test samples.
Args:
num_samples (int): Number of samples to return.
Returns:
List[Tuple]: melspectorgram and audio.
Shapes:
- melspectrogram (Tensor): :math:`[C, T]`
- audio (Tensor): :math:`[T_audio]`
"""
samples = [] samples = []
return_segments = self.return_segments return_segments = self.return_segments
self.return_segments = False self.return_segments = False

View File

@ -124,11 +124,16 @@ class Wavegrad(BaseModel):
@torch.no_grad() @torch.no_grad()
def inference(self, x, y_n=None): def inference(self, x, y_n=None):
"""x: B x D X T""" """
Shapes:
x: :math:`[B, C , T]`
y_n: :math:`[B, 1, T]`
"""
if y_n is None: if y_n is None:
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x) y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
else: else:
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x) y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
y_n = y_n.type_as(x)
sqrt_alpha_hat = self.noise_level.to(x) sqrt_alpha_hat = self.noise_level.to(x)
for n in range(len(self.alpha) - 1, -1, -1): for n in range(len(self.alpha) - 1, -1, -1):
y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0]))) y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
@ -267,10 +272,10 @@ class Wavegrad(BaseModel):
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
self.compute_noise_level(betas) self.compute_noise_level(betas)
for sample in samples: for sample in samples:
sample = self.format_batch(sample) x = sample[0]
x = sample["input"] x = x[None, : , :].to(next(self.parameters()).device)
x = x.to(next(self.parameters()).device) y = sample[1]
y = sample["waveform"] y = y[None, :]
# compute voice # compute voice
y_pred = self.inference(x) y_pred = self.inference(x)
# compute spectrograms # compute spectrograms