mirror of https://github.com/coqui-ai/TTS.git
Fix WaveGrad `test_run`
This commit is contained in:
parent 25832eb97b
commit 58cc414477
@@ -765,8 +765,6 @@ class Trainer:
         """Run test and log the results. Test run must be defined by the model.
         Model must return figures and audios to be logged by the Tensorboard."""
         if hasattr(self.model, "test_run"):
-            if isinstance(self.model, Wavegrad):
-                return None  # TODO: Fix inference on WaveGrad
             if hasattr(self.eval_loader.dataset, "load_test_samples"):
                 samples = self.eval_loader.dataset.load_test_samples(1)
                 figures, audios = self.model.test_run(self.ap, samples, None)
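The hunk above drops the temporary WaveGrad special case from the trainer: instead of short-circuiting, the trainer now duck-types on the dataset and hands raw test samples to the model. A minimal sketch of that flow as a standalone function (names `model`, `eval_loader`, and `ap` mirror the attributes used in the diff; this is not the actual Trainer code, and only the branch visible in the hunk is reproduced):

```python
def run_test(model, eval_loader, ap):
    """Run the model's test pass if it defines one, reusing the eval dataset."""
    if not hasattr(model, "test_run"):
        return None  # nothing to do for models without a test pass
    if hasattr(eval_loader.dataset, "load_test_samples"):
        # the dataset hands back raw (mel, audio) pairs for test synthesis
        samples = eval_loader.dataset.load_test_samples(1)
        return model.test_run(ap, samples, None)
    return None  # the diff does not show the branch without load_test_samples
```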
@@ -2,6 +2,7 @@ import glob
 import os
 import random
 from multiprocessing import Manager
+from typing import List, Tuple
 
 import numpy as np
 import torch
@@ -67,7 +68,19 @@ class WaveGradDataset(Dataset):
         item = self.load_item(idx)
         return item
 
-    def load_test_samples(self, num_samples):
+    def load_test_samples(self, num_samples: int) -> List[Tuple]:
+        """Return test samples.
+
+        Args:
+            num_samples (int): Number of samples to return.
+
+        Returns:
+            List[Tuple]: melspectrogram and audio.
+
+        Shapes:
+            - melspectrogram (Tensor): :math:`[C, T]`
+            - audio (Tensor): :math:`[T_audio]`
+        """
         samples = []
         return_segments = self.return_segments
         self.return_segments = False
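The hunk adds type hints and a docstring to `load_test_samples`; only the first lines of its body are visible here. A hedged sketch of how such a method could look on a toy dataset, following the visible pattern of switching `return_segments` off while collecting full-length items (the toy class and the loop body are assumptions, not the repository's code):

```python
from typing import List, Tuple


class ToyWaveGradDataset:
    """Toy stand-in for WaveGradDataset, only to show the pattern."""

    def __init__(self, items):
        self.items = items            # preloaded (mel, audio) pairs
        self.return_segments = True   # training mode returns fixed-size segments

    def load_item(self, idx):
        return self.items[idx]

    def load_test_samples(self, num_samples: int) -> List[Tuple]:
        """Return `num_samples` full (mel, audio) pairs for test-time synthesis."""
        samples = []
        return_segments = self.return_segments
        self.return_segments = False              # temporarily disable segmenting
        for idx in range(num_samples):
            samples.append(self.load_item(idx))
        self.return_segments = return_segments    # restore the training behaviour
        return samples
```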
@@ -124,11 +124,16 @@ class Wavegrad(BaseModel):
 
     @torch.no_grad()
     def inference(self, x, y_n=None):
-        """x: B x D X T"""
+        """
+        Shapes:
+            x: :math:`[B, C, T]`
+            y_n: :math:`[B, 1, T]`
+        """
         if y_n is None:
-            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
+            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
         else:
-            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
+            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
+        y_n = y_n.type_as(x)
         sqrt_alpha_hat = self.noise_level.to(x)
         for n in range(len(self.alpha) - 1, -1, -1):
             y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
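The inference change builds the starting noise without per-branch `.to(x)` calls and instead casts once with `type_as`, so both the `y_n is None` branch and the user-supplied branch end up matching the conditioning tensor's dtype (and device type). A small standalone illustration with dummy tensors; the hop length and mel size here are assumptions, not the model's values:

```python
import torch

hop_len = 256                       # assumed hop length
x = torch.randn(2, 80, 10)          # stand-in mel conditioning, [B, C, T]
y_n = torch.randn(x.shape[0], 1, hop_len * x.shape[-1])  # starting noise, [B, 1, T_audio]
y_n = y_n.type_as(x)                # cast to x's dtype/device type in one call

assert y_n.dtype == x.dtype
assert y_n.shape == (2, 1, hop_len * 10)
```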
@@ -267,10 +272,10 @@ class Wavegrad(BaseModel):
         betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
         self.compute_noise_level(betas)
         for sample in samples:
-            sample = self.format_batch(sample)
-            x = sample["input"]
-            x = x.to(next(self.parameters()).device)
-            y = sample["waveform"]
+            x = sample[0]
+            x = x[None, :, :].to(next(self.parameters()).device)
+            y = sample[1]
+            y = y[None, :]
             # compute voice
             y_pred = self.inference(x)
             # compute spectrograms
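In `test_run`, samples now arrive as uncollated `(mel, audio)` tuples from `load_test_samples` rather than formatted batches, so the batch dimension is added by hand before calling `inference`. A toy example of that reshaping with dummy tensors (shapes are assumptions for illustration):

```python
import torch

hop_len = 256  # assumed hop length, so T_audio = hop_len * T
sample = (torch.randn(80, 10), torch.randn(hop_len * 10))  # (mel [C, T], audio [T_audio])

x = sample[0]
x = x[None, :, :]    # [C, T] -> [1, C, T]: add the batch dimension
y = sample[1]
y = y[None, :]       # [T_audio] -> [1, T_audio]

assert x.shape == (1, 80, 10)
assert y.shape == (1, hop_len * 10)
```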