diff --git a/TTS/trainer.py b/TTS/trainer.py index ec6d4417..f628d9a4 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -36,7 +36,7 @@ from TTS.utils.generic_utils import ( ) from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint from TTS.utils.logging import ConsoleLogger, TensorboardLogger -from TTS.utils.trainer_utils import * +from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.models import setup_model as setup_vocoder_model diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 58d6532e..94583147 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -144,20 +144,24 @@ class GAN(BaseVocoder): return outputs, loss_dict - def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + @staticmethod + def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: y_hat = outputs[0]["model_outputs"] y = batch["waveform"] - figures = plot_results(y_hat, y, ap, "train") + figures = plot_results(y_hat, y, ap, name) sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() - audios = {"train/audio": sample_voice} + audios = {f"{name}/audio": sample_voice} return figures, audios + def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: + return self._log("train", ap, batch, outputs) + @torch.no_grad() def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: return self.train_step(batch, criterion, optimizer_idx) def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: - return self.train_log(ap, batch, outputs) + return self._log("eval", ap, batch, outputs) def load_checkpoint( self,