Fix `eval_log` for `gan.py` 🛠️

This commit is contained in:
Eren Gölge 2021-06-21 16:51:28 +02:00
parent d700845b10
commit cfa5041db7
2 changed files with 9 additions and 5 deletions

View File

@ -36,7 +36,7 @@ from TTS.utils.generic_utils import (
)
from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.utils.trainer_utils import *
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model

View File

@ -144,20 +144,24 @@ class GAN(BaseVocoder):
return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
@staticmethod
def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
y_hat = outputs[0]["model_outputs"]
y = batch["waveform"]
figures = plot_results(y_hat, y, ap, "train")
figures = plot_results(y_hat, y, ap, name)
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
audios = {"train/audio": sample_voice}
audios = {f"{name}/audio": sample_voice}
return figures, audios
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
return self._log("train", ap, batch, outputs)
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
return self.train_log(ap, batch, outputs)
return self._log("eval", ap, batch, outputs)
def load_checkpoint(
self,