mirror of https://github.com/coqui-ai/TTS.git
Fix WaveGrad `test_run`
This commit is contained in:
parent
25832eb97b
commit
58cc414477
|
@ -765,8 +765,6 @@ class Trainer:
|
||||||
"""Run test and log the results. Test run must be defined by the model.
|
"""Run test and log the results. Test run must be defined by the model.
|
||||||
Model must return figures and audios to be logged by the Tensorboard."""
|
Model must return figures and audios to be logged by the Tensorboard."""
|
||||||
if hasattr(self.model, "test_run"):
|
if hasattr(self.model, "test_run"):
|
||||||
if isinstance(self.model, Wavegrad):
|
|
||||||
return None # TODO: Fix inference on WaveGrad
|
|
||||||
if hasattr(self.eval_loader.dataset, "load_test_samples"):
|
if hasattr(self.eval_loader.dataset, "load_test_samples"):
|
||||||
samples = self.eval_loader.dataset.load_test_samples(1)
|
samples = self.eval_loader.dataset.load_test_samples(1)
|
||||||
figures, audios = self.model.test_run(self.ap, samples, None)
|
figures, audios = self.model.test_run(self.ap, samples, None)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import glob
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from multiprocessing import Manager
|
from multiprocessing import Manager
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -67,7 +68,19 @@ class WaveGradDataset(Dataset):
|
||||||
item = self.load_item(idx)
|
item = self.load_item(idx)
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def load_test_samples(self, num_samples):
|
def load_test_samples(self, num_samples: int) -> List[Tuple]:
|
||||||
|
"""Return test samples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_samples (int): Number of samples to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Tuple]: melspectorgram and audio.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- melspectrogram (Tensor): :math:`[C, T]`
|
||||||
|
- audio (Tensor): :math:`[T_audio]`
|
||||||
|
"""
|
||||||
samples = []
|
samples = []
|
||||||
return_segments = self.return_segments
|
return_segments = self.return_segments
|
||||||
self.return_segments = False
|
self.return_segments = False
|
||||||
|
|
|
@ -124,11 +124,16 @@ class Wavegrad(BaseModel):
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(self, x, y_n=None):
|
def inference(self, x, y_n=None):
|
||||||
"""x: B x D X T"""
|
"""
|
||||||
|
Shapes:
|
||||||
|
x: :math:`[B, C , T]`
|
||||||
|
y_n: :math:`[B, 1, T]`
|
||||||
|
"""
|
||||||
if y_n is None:
|
if y_n is None:
|
||||||
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
|
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
|
||||||
else:
|
else:
|
||||||
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
|
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
|
||||||
|
y_n = y_n.type_as(x)
|
||||||
sqrt_alpha_hat = self.noise_level.to(x)
|
sqrt_alpha_hat = self.noise_level.to(x)
|
||||||
for n in range(len(self.alpha) - 1, -1, -1):
|
for n in range(len(self.alpha) - 1, -1, -1):
|
||||||
y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
|
y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
|
||||||
|
@ -267,10 +272,10 @@ class Wavegrad(BaseModel):
|
||||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||||
self.compute_noise_level(betas)
|
self.compute_noise_level(betas)
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
sample = self.format_batch(sample)
|
x = sample[0]
|
||||||
x = sample["input"]
|
x = x[None, : , :].to(next(self.parameters()).device)
|
||||||
x = x.to(next(self.parameters()).device)
|
y = sample[1]
|
||||||
y = sample["waveform"]
|
y = y[None, :]
|
||||||
# compute voice
|
# compute voice
|
||||||
y_pred = self.inference(x)
|
y_pred = self.inference(x)
|
||||||
# compute spectrograms
|
# compute spectrograms
|
||||||
|
|
Loading…
Reference in New Issue