From c8bbcdfd076183e2b59350e8c274502f4e2935e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 13 Aug 2021 19:39:02 +0000 Subject: [PATCH] Fix `test_run` for DDP --- TTS/trainer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/TTS/trainer.py b/TTS/trainer.py index 32e561d6..d3d66ab2 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -829,9 +829,15 @@ class Trainer: if hasattr(self.eval_loader.dataset, "load_test_samples"): samples = self.eval_loader.dataset.load_test_samples(1) - figures, audios = self.model.test_run(self.ap, samples, None) + if self.num_gpus > 1: + figures, audios = self.model.module.test_run(self.ap, samples, None) + else: + figures, audios = self.model.test_run(self.ap, samples, None) else: - figures, audios = self.model.test_run(self.ap) + if self.num_gpus > 1: + figures, audios = self.model.module.test_run(self.ap) + else: + figures, audios = self.model.test_run(self.ap) self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) self.dashboard_logger.test_figures(self.total_steps_done, figures)