From 7d92b309465b58005450b1cc6459ef6e119c8453 Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Tue, 13 Jul 2021 23:00:34 +0200 Subject: [PATCH] Fix tests --- TTS/trainer.py | 11 +++++++---- TTS/tts/models/base_tts.py | 4 ++-- TTS/vocoder/models/wavegrad.py | 5 ++++- TTS/vocoder/models/wavernn.py | 10 ++++++---- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/TTS/trainer.py b/TTS/trainer.py index bbd9665a..b2494bad 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -22,6 +22,7 @@ from torch.utils.data import DataLoader from TTS.config import load_config, register_config from TTS.tts.datasets import load_meta_data from TTS.tts.models import setup_model as setup_tts_model +from TTS.vocoder.models.wavegrad import Wavegrad from TTS.tts.utils.text.symbols import parse_symbols from TTS.utils.audio import AudioProcessor from TTS.utils.callbacks import TrainerCallback @@ -764,11 +765,13 @@ class Trainer: """Run test and log the results. Test run must be defined by the model. Model must return figures and audios to be logged by the Tensorboard.""" if hasattr(self.model, "test_run"): - if hasattr(self.eval_loader, "load_test_samples"): - samples = self.eval_loader.load_test_samples(1) - figures, audios = self.model.test_run(samples) + if isinstance(self.model, Wavegrad): + return None # TODO: Fix inference on WaveGrad + elif hasattr(self.eval_loader.dataset, "load_test_samples"): + samples = self.eval_loader.dataset.load_test_samples(1) + figures, audios = self.model.test_run(self.ap, samples, None, self.use_cuda) else: - figures, audios = self.model.test_run(use_cuda=self.use_cuda, ap=self.ap) + figures, audios = self.model.test_run(self.ap, self.use_cuda) self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) self.tb_logger.tb_test_figures(self.total_steps_done, figures) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 64c0ba6f..a30c5f02 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -70,7 +70,7 @@ class BaseTTS(BaseModel): def get_aux_input(self, **kwargs) -> Dict: """Prepare and return `aux_input` used by `forward()`""" - pass + return {"speaker_id": None, "style_wav": None, "d_vector": None} def format_batch(self, batch: Dict) -> Dict: """Generic batch formatting for `TTSDataset`. @@ -200,7 +200,7 @@ class BaseTTS(BaseModel): ) return loader - def test_run(self, use_cuda=True, ap=None) -> Tuple[Dict, Dict]: + def test_run(self, ap, use_cuda) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. You can override this for a different behaviour. diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 03d5160e..7781b5f5 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -261,13 +261,16 @@ class Wavegrad(BaseModel): def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: return None, None - def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict): # pylint: disable=unused-argument + def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict, use_cuda): # pylint: disable=unused-argument # setup noise schedule and inference noise_schedule = self.config["test_noise_schedule"] betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) self.compute_noise_level(betas) for sample in samples: + sample = self.format_batch(sample) x = sample["input"] + if use_cuda: + x = x.cuda() y = sample["waveform"] # compute voice y_pred = self.inference(x) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index a5d89d5a..12a29a72 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -322,7 +322,7 @@ class Wavernn(BaseVocoder): with torch.no_grad(): if isinstance(mels, np.ndarray): - mels = torch.FloatTensor(mels).type_as(mels) + mels = torch.FloatTensor(mels) if mels.ndim == 2: mels = mels.unsqueeze(0) @@ -571,12 +571,14 @@ class Wavernn(BaseVocoder): @torch.no_grad() def test_run( - self, ap: AudioProcessor, samples: List[Dict], output: Dict # pylint: disable=unused-argument + self, ap: AudioProcessor, samples: List[Dict], output: Dict, use_cuda # pylint: disable=unused-argument ) -> Tuple[Dict, Dict]: figures = {} audios = {} for idx, sample in enumerate(samples): - x = sample["input"] + x = torch.FloatTensor(sample[0]) + if use_cuda: + x = x.cuda() y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples) x_hat = ap.melspectrogram(y_hat) figures.update( @@ -585,7 +587,7 @@ class Wavernn(BaseVocoder): f"test_{idx}/prediction": plot_spectrogram(x_hat.T), } ) - audios.update({f"test_{idx}/audio", y_hat}) + audios.update({f"test_{idx}/audio": y_hat}) return figures, audios @staticmethod