Fix tests

WeberJulian 2021-07-13 23:00:34 +02:00
parent 32974dd6a9
commit 7d92b30946
4 changed files with 19 additions and 11 deletions

@@ -22,6 +22,7 @@ from torch.utils.data import DataLoader
 from TTS.config import load_config, register_config
 from TTS.tts.datasets import load_meta_data
 from TTS.tts.models import setup_model as setup_tts_model
+from TTS.vocoder.models.wavegrad import Wavegrad
 from TTS.tts.utils.text.symbols import parse_symbols
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.callbacks import TrainerCallback
@@ -764,11 +765,13 @@ class Trainer:
         """Run test and log the results. Test run must be defined by the model.
         Model must return figures and audios to be logged by the Tensorboard."""
         if hasattr(self.model, "test_run"):
-            if hasattr(self.eval_loader, "load_test_samples"):
-                samples = self.eval_loader.load_test_samples(1)
-                figures, audios = self.model.test_run(samples)
+            if isinstance(self.model, Wavegrad):
+                return None  # TODO: Fix inference on WaveGrad
+            elif hasattr(self.eval_loader.dataset, "load_test_samples"):
+                samples = self.eval_loader.dataset.load_test_samples(1)
+                figures, audios = self.model.test_run(self.ap, samples, None, self.use_cuda)
             else:
-                figures, audios = self.model.test_run(use_cuda=self.use_cuda, ap=self.ap)
+                figures, audios = self.model.test_run(self.ap, self.use_cuda)
             self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
             self.tb_logger.tb_test_figures(self.total_steps_done, figures)
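With this change the Trainer calls `test_run` positionally instead of with keyword arguments, so any model that wants test-time logging has to accept the new argument order: `ap` first, then either `samples, output, use_cuda` (vocoder path) or just `use_cuda` (tts fallback). A rough sketch of a model hooking into the fallback branch; the class name and body are made up for illustration:

    from typing import Dict, Tuple

    from TTS.tts.models.base_tts import BaseTTS  # module path assumed from the imports above


    class MyTTSModel(BaseTTS):
        def test_run(self, ap, use_cuda) -> Tuple[Dict, Dict]:
            # Synthesize a few fixed sentences and return {name: figure} / {name: audio}
            # dicts; Trainer.test_and_log forwards both to the Tensorboard logger.
            figures, audios = {}, {}
            return figures, audios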

@@ -70,7 +70,7 @@ class BaseTTS(BaseModel):
     def get_aux_input(self, **kwargs) -> Dict:
         """Prepare and return `aux_input` used by `forward()`"""
-        pass
+        return {"speaker_id": None, "style_wav": None, "d_vector": None}

     def format_batch(self, batch: Dict) -> Dict:
         """Generic batch formatting for `TTSDataset`.
@@ -200,7 +200,7 @@ class BaseTTS(BaseModel):
         )
         return loader

-    def test_run(self, use_cuda=True, ap=None) -> Tuple[Dict, Dict]:
+    def test_run(self, ap, use_cuda) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.
         You can override this for a different behaviour.
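Returning a populated dict instead of falling through with `pass` (i.e. `None`) means callers can index `aux_input` without a guard. A small hedged sketch of the intended usage; `model` stands for any `BaseTTS` subclass that does not override the method:

    aux_input = model.get_aux_input()
    # -> {"speaker_id": None, "style_wav": None, "d_vector": None}

    # Downstream code can now read the keys directly instead of first checking for None:
    speaker_id = aux_input["speaker_id"]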

@@ -261,13 +261,16 @@ class Wavegrad(BaseModel):
     def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
         return None, None

-    def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict):  # pylint: disable=unused-argument
+    def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict, use_cuda):  # pylint: disable=unused-argument
         # setup noise schedule and inference
         noise_schedule = self.config["test_noise_schedule"]
         betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
         self.compute_noise_level(betas)
         for sample in samples:
+            sample = self.format_batch(sample)
             x = sample["input"]
+            if use_cuda:
+                x = x.cuda()
             y = sample["waveform"]
             # compute voice
             y_pred = self.inference(x)
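For reference, the noise schedule set up at the top of `test_run` is just a linear ramp of betas; the values below are placeholders standing in for `config["test_noise_schedule"]`:

    import numpy as np

    # Placeholder schedule; the real values come from config["test_noise_schedule"].
    noise_schedule = {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 50}
    betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
    # betas: 50 values rising linearly from 1e-6 to 1e-2; compute_noise_level(betas)
    # presumably derives the cumulative noise levels that inference() steps through.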

@@ -322,7 +322,7 @@ class Wavernn(BaseVocoder):
         with torch.no_grad():
             if isinstance(mels, np.ndarray):
-                mels = torch.FloatTensor(mels).type_as(mels)
+                mels = torch.FloatTensor(mels)

             if mels.ndim == 2:
                 mels = mels.unsqueeze(0)
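The dropped `.type_as(mels)` was being handed the original numpy array as its argument, and `Tensor.type_as` expects a tensor, so a plain `torch.FloatTensor(mels)` conversion is enough. A tiny sketch of the conversion path with a placeholder mel shape:

    import numpy as np
    import torch

    mels = np.random.rand(80, 120).astype(np.float32)  # placeholder (n_mels, frames) spectrogram
    mels = torch.FloatTensor(mels)   # numpy -> float32 tensor, as in the fixed code
    if mels.ndim == 2:
        mels = mels.unsqueeze(0)     # add a batch dimension -> (1, 80, 120)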
@@ -571,12 +571,14 @@ class Wavernn(BaseVocoder):
     @torch.no_grad()
     def test_run(
-        self, ap: AudioProcessor, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
+        self, ap: AudioProcessor, samples: List[Dict], output: Dict, use_cuda  # pylint: disable=unused-argument
     ) -> Tuple[Dict, Dict]:
         figures = {}
         audios = {}
         for idx, sample in enumerate(samples):
-            x = sample["input"]
+            x = torch.FloatTensor(sample[0])
+            if use_cuda:
+                x = x.cuda()
             y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
             x_hat = ap.melspectrogram(y_hat)
             figures.update(
@@ -585,7 +587,7 @@ class Wavernn(BaseVocoder):
                     f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
                 }
             )
-            audios.update({f"test_{idx}/audio", y_hat})
+            audios.update({f"test_{idx}/audio": y_hat})
         return figures, audios

     @staticmethod
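The last hunk fixes a one-character bug: with a comma instead of a colon, `{f"test_{idx}/audio", y_hat}` is a set literal, not a dict, and `dict.update` cannot consume it (if `y_hat` is a numpy waveform, as the `ap.melspectrogram(y_hat)` call above suggests, even building the set fails because arrays are unhashable). A quick illustration:

    import numpy as np

    y_hat = np.zeros(10)  # placeholder waveform
    audios = {}
    # Old version: a set literal; the unhashable ndarray raises TypeError before update() runs.
    #   audios.update({"test_0/audio", y_hat})
    # Fixed version: a dict literal, which update() merges as intended.
    audios.update({"test_0/audio": y_hat})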