From 58cc414477df5dfdd897c56a860bcfb308002b7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 16 Jul 2021 13:02:25 +0200 Subject: [PATCH] Fix WaveGrad `test_run` --- TTS/trainer.py | 2 -- TTS/vocoder/datasets/wavegrad_dataset.py | 15 ++++++++++++++- TTS/vocoder/models/wavegrad.py | 19 ++++++++++++------- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/TTS/trainer.py b/TTS/trainer.py index 12f43563..f3f45ebd 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -765,8 +765,6 @@ class Trainer: """Run test and log the results. Test run must be defined by the model. Model must return figures and audios to be logged by the Tensorboard.""" if hasattr(self.model, "test_run"): - if isinstance(self.model, Wavegrad): - return None # TODO: Fix inference on WaveGrad if hasattr(self.eval_loader.dataset, "load_test_samples"): samples = self.eval_loader.dataset.load_test_samples(1) figures, audios = self.model.test_run(self.ap, samples, None) diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py index d99fc417..05e0fae8 100644 --- a/TTS/vocoder/datasets/wavegrad_dataset.py +++ b/TTS/vocoder/datasets/wavegrad_dataset.py @@ -2,6 +2,7 @@ import glob import os import random from multiprocessing import Manager +from typing import List, Tuple import numpy as np import torch @@ -67,7 +68,19 @@ class WaveGradDataset(Dataset): item = self.load_item(idx) return item - def load_test_samples(self, num_samples): + def load_test_samples(self, num_samples: int) -> List[Tuple]: + """Return test samples. + + Args: + num_samples (int): Number of samples to return. + + Returns: + List[Tuple]: melspectrogram and audio. 
+ + Shapes: + - melspectrogram (Tensor): :math:`[C, T]` + - audio (Tensor): :math:`[T_audio]` + """ samples = [] return_segments = self.return_segments self.return_segments = False diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 9249f81c..22d2a015 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -124,11 +124,16 @@ class Wavegrad(BaseModel): @torch.no_grad() def inference(self, x, y_n=None): - """x: B x D X T""" + """ + Shapes: + x: :math:`[B, C, T]` + y_n: :math:`[B, 1, T]` + """ if y_n is None: - y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x) + y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1]) else: - y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x) + y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0) + y_n = y_n.type_as(x) sqrt_alpha_hat = self.noise_level.to(x) for n in range(len(self.alpha) - 1, -1, -1): y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0]))) @@ -267,10 +272,10 @@ class Wavegrad(BaseModel): betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) self.compute_noise_level(betas) for sample in samples: - sample = self.format_batch(sample) - x = sample["input"] - x = x.to(next(self.parameters()).device) - y = sample["waveform"] + x = sample[0] + x = x[None, : , :].to(next(self.parameters()).device) + y = sample[1] + y = y[None, :] # compute voice y_pred = self.inference(x) # compute spectrograms