Fix WaveGrad `test_run`

This commit is contained in:
Eren Gölge 2021-07-16 13:02:25 +02:00
parent 25832eb97b
commit 58cc414477
3 changed files with 26 additions and 10 deletions

View File

@ -765,8 +765,6 @@ class Trainer:
"""Run test and log the results. Test run must be defined by the model.
Model must return figures and audios to be logged by the Tensorboard."""
if hasattr(self.model, "test_run"):
if isinstance(self.model, Wavegrad):
return None # TODO: Fix inference on WaveGrad
if hasattr(self.eval_loader.dataset, "load_test_samples"):
samples = self.eval_loader.dataset.load_test_samples(1)
figures, audios = self.model.test_run(self.ap, samples, None)

View File

@ -2,6 +2,7 @@ import glob
import os
import random
from multiprocessing import Manager
from typing import List, Tuple
import numpy as np
import torch
@ -67,7 +68,19 @@ class WaveGradDataset(Dataset):
item = self.load_item(idx)
return item
def load_test_samples(self, num_samples):
def load_test_samples(self, num_samples: int) -> List[Tuple]:
"""Return test samples.
Args:
num_samples (int): Number of samples to return.
Returns:
List[Tuple]: melspectorgram and audio.
Shapes:
- melspectrogram (Tensor): :math:`[C, T]`
- audio (Tensor): :math:`[T_audio]`
"""
samples = []
return_segments = self.return_segments
self.return_segments = False

View File

@ -124,11 +124,16 @@ class Wavegrad(BaseModel):
@torch.no_grad()
def inference(self, x, y_n=None):
"""x: B x D X T"""
"""
Shapes:
x: :math:`[B, C , T]`
y_n: :math:`[B, 1, T]`
"""
if y_n is None:
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
else:
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
y_n = y_n.type_as(x)
sqrt_alpha_hat = self.noise_level.to(x)
for n in range(len(self.alpha) - 1, -1, -1):
y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
@ -267,10 +272,10 @@ class Wavegrad(BaseModel):
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
self.compute_noise_level(betas)
for sample in samples:
sample = self.format_batch(sample)
x = sample["input"]
x = x.to(next(self.parameters()).device)
y = sample["waveform"]
x = sample[0]
x = x[None, : , :].to(next(self.parameters()).device)
y = sample[1]
y = y[None, :]
# compute voice
y_pred = self.inference(x)
# compute spectrograms