diff --git a/TTS/trainer.py b/TTS/trainer.py
index fd316e78..12f43563 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -769,9 +769,9 @@ class Trainer:
                 return None  # TODO: Fix inference on WaveGrad
             if hasattr(self.eval_loader.dataset, "load_test_samples"):
                 samples = self.eval_loader.dataset.load_test_samples(1)
-                figures, audios = self.model.test_run(self.ap, samples, None, self.use_cuda)
+                figures, audios = self.model.test_run(self.ap, samples, None)
             else:
-                figures, audios = self.model.test_run(self.ap, self.use_cuda)
+                figures, audios = self.model.test_run(self.ap)
             self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
             self.tb_logger.tb_test_figures(self.total_steps_done, figures)
         return None
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index a30c5f02..561b76fb 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -200,7 +200,7 @@ class BaseTTS(BaseModel):
         )
         return loader
 
-    def test_run(self, ap, use_cuda) -> Tuple[Dict, Dict]:
+    def test_run(self, ap) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.
 
         You can override this for a different behaviour.
@@ -218,7 +218,7 @@ class BaseTTS(BaseModel):
                 self,
                 sen,
                 self.config,
-                use_cuda,
+                "cuda" in str(next(self.parameters()).device),
                 ap,
                 speaker_id=aux_inputs["speaker_id"],
                 d_vector=aux_inputs["d_vector"],
diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index 01b47a20..9249f81c 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -261,9 +261,7 @@ class Wavegrad(BaseModel):
     def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
         return None, None
 
-    def test_run(
-        self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict, use_cuda
-    ):  # pylint: disable=unused-argument
+    def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict):  # pylint: disable=unused-argument
         # setup noise schedule and inference
         noise_schedule = self.config["test_noise_schedule"]
         betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
@@ -271,8 +269,7 @@
         for sample in samples:
             sample = self.format_batch(sample)
             x = sample["input"]
-            if use_cuda:
-                x = x.cuda()
+            x = x.to(next(self.parameters()).device)
             y = sample["waveform"]
             # compute voice
             y_pred = self.inference(x)
diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py
index 90eee58e..c2e47120 100644
--- a/TTS/vocoder/models/wavernn.py
+++ b/TTS/vocoder/models/wavernn.py
@@ -322,7 +322,7 @@
 
         with torch.no_grad():
             if isinstance(mels, np.ndarray):
-                mels = torch.FloatTensor(mels)
+                mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))
 
             if mels.ndim == 2:
                 mels = mels.unsqueeze(0)
@@ -571,14 +571,13 @@
 
     @torch.no_grad()
     def test_run(
-        self, ap: AudioProcessor, samples: List[Dict], output: Dict, use_cuda  # pylint: disable=unused-argument
+        self, ap: AudioProcessor, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
     ) -> Tuple[Dict, Dict]:
         figures = {}
         audios = {}
         for idx, sample in enumerate(samples):
             x = torch.FloatTensor(sample[0])
-            if use_cuda:
-                x = x.cuda()
+            x = x.to(next(self.parameters()).device)
             y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
             x_hat = ap.melspectrogram(y_hat)
             figures.update(
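
Note on the pattern applied throughout this diff: rather than threading a `use_cuda` flag through every `test_run()` signature, each model derives its target device from its own parameters via `next(self.parameters()).device`, so inputs follow the model wherever its weights live ("cpu", "cuda:0", "cuda:1", ...). A minimal sketch of the idiom, assuming an illustrative `ToyVocoder` module that is not part of the TTS codebase:

import torch
from torch import nn


class ToyVocoder(nn.Module):
    """Illustrative stand-in for a vocoder model (not from the repo)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(80, 80)

    @torch.no_grad()
    def test_run(self, x: torch.Tensor) -> torch.Tensor:
        # Derive the device from the model's own parameters instead of a
        # caller-supplied `use_cuda` flag; works for "cpu", "cuda:1", etc.
        device = next(self.parameters()).device
        return self.net(x.to(device))


model = ToyVocoder()
if torch.cuda.is_available():
    model = model.cuda()
# The input is moved to wherever the model's weights are:
out = model.test_run(torch.randn(4, 80))
print(out.device)

Two caveats grounded in the diff itself: `next(self.parameters())` raises `StopIteration` on a module with no parameters, and the `base_tts.py` hunk still needs a plain boolean for the `synthesis()` call, hence the `"cuda" in str(...)` conversion there.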