From 20a677c6238384cd842c706bb0142f0751dcfc37 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Sun, 20 Feb 2022 11:36:27 +0100
Subject: [PATCH] Update test_run in wavernn and wavegrad

---
 TTS/vocoder/models/wavegrad.py | 7 ++++---
 TTS/vocoder/models/wavernn.py  | 9 +++++----
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index 9d6e431c..58fc8762 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -270,12 +270,13 @@ class Wavegrad(BaseVocoder):
     ) -> None:
         pass
 
-    def test_run(self, assets: Dict, samples: List[Dict], outputs: Dict):  # pylint: disable=unused-argument
+    def test(self, assets: Dict, test_loader: "DataLoader", outputs=None):  # pylint: disable=unused-argument
         # setup noise schedule and inference
         ap = assets["audio_processor"]
         noise_schedule = self.config["test_noise_schedule"]
         betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
         self.compute_noise_level(betas)
+        samples = test_loader.dataset.load_test_samples(1)
         for sample in samples:
             x = sample[0]
             x = x[None, :, :].to(next(self.parameters()).device)
@@ -307,12 +308,12 @@ class Wavegrad(BaseVocoder):
         return {"input": m, "waveform": y}
 
     def get_data_loader(
-        self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
+        self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int
     ):
         ap = assets["audio_processor"]
         dataset = WaveGradDataset(
             ap=ap,
-            items=data_items,
+            items=samples,
             seq_len=self.config.seq_len,
             hop_len=ap.hop_length,
             pad_short=self.config.pad_short,
diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py
index 68f9b2c8..6686db45 100644
--- a/TTS/vocoder/models/wavernn.py
+++ b/TTS/vocoder/models/wavernn.py
@@ -568,12 +568,13 @@ class Wavernn(BaseVocoder):
         return self.train_step(batch, criterion)
 
     @torch.no_grad()
-    def test_run(
-        self, assets: Dict, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
+    def test(
+        self, assets: Dict, test_loader: "DataLoader", output: Dict  # pylint: disable=unused-argument
     ) -> Tuple[Dict, Dict]:
         ap = assets["audio_processor"]
         figures = {}
         audios = {}
+        samples = test_loader.dataset.load_test_samples(1)
         for idx, sample in enumerate(samples):
             x = torch.FloatTensor(sample[0])
             x = x.to(next(self.parameters()).device)
@@ -600,14 +601,14 @@ class Wavernn(BaseVocoder):
         config: Coqpit,
         assets: Dict,
         is_eval: True,
-        data_items: List,
+        samples: List,
         verbose: bool,
         num_gpus: int,
     ):
         ap = assets["audio_processor"]
         dataset = WaveRNNDataset(
             ap=ap,
-            items=data_items,
+            items=samples,
             seq_len=config.seq_len,
             hop_len=ap.hop_length,
             pad=config.model_args.pad,
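
Note on the interface this change assumes: the renamed test() hook now receives the DataLoader itself
and pulls its inputs via test_loader.dataset.load_test_samples(1), so whatever dataset backs the test
loader must expose that method. Below is a minimal, hypothetical sketch of that contract; the class
name DummyVocoderDataset, the item layout, and the usage lines are placeholders, while the
load_test_samples(num_samples) call, sample[0] being used as the model input, and the
{"audio_processor": ap} assets key come from the diff above.

    from torch.utils.data import DataLoader, Dataset

    class DummyVocoderDataset(Dataset):
        """Hypothetical dataset satisfying the interface test() relies on."""

        def __init__(self, items):
            self.items = items  # list of (feature, audio) pairs

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

        def load_test_samples(self, num_samples):
            # test() calls this with 1 and iterates over the returned samples,
            # using sample[0] as the model input.
            return self.items[:num_samples]

    # Usage sketch (hypothetical): the trainer passes the loader straight to the model.
    # loader = DataLoader(DummyVocoderDataset(items), batch_size=1)
    # model.test({"audio_processor": ap}, loader, None)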