From 6f739ea07ae543f5d80f4a145b6126f402b6c7b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 21 Jun 2021 16:51:28 +0200 Subject: [PATCH] =?UTF-8?q?Fix=20`eval=5Flog`=20for=20`gan.py`=20?= =?UTF-8?q?=F0=9F=9B=A0=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TTS/trainer.py | 2 +- TTS/vocoder/models/gan.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/TTS/trainer.py b/TTS/trainer.py index ec6d4417..f628d9a4 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -36,7 +36,7 @@ from TTS.utils.generic_utils import ( ) from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint from TTS.utils.logging import ConsoleLogger, TensorboardLogger -from TTS.utils.trainer_utils import * +from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.models import setup_model as setup_vocoder_model diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 58d6532e..94583147 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -144,20 +144,24 @@ class GAN(BaseVocoder): return outputs, loss_dict - def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + @staticmethod + def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: y_hat = outputs[0]["model_outputs"] y = batch["waveform"] - figures = plot_results(y_hat, y, ap, "train") + figures = plot_results(y_hat, y, ap, name) sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() - audios = {"train/audio": sample_voice} + audios = {f"{name}/audio": sample_voice} return figures, audios + def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + return self._log("train", ap, batch, outputs) + @torch.no_grad() def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: return self.train_step(batch, criterion, optimizer_idx) def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: - return self.train_log(ap, batch, outputs) + return self._log("eval", ap, batch, outputs) def load_checkpoint( self,