mirror of https://github.com/coqui-ai/TTS.git
Fix `eval_log` for `gan.py` 🛠️
This commit is contained in:
parent
d700845b10
commit
cfa5041db7
|
@ -36,7 +36,7 @@ from TTS.utils.generic_utils import (
|
|||
)
|
||||
from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint
|
||||
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
|
||||
from TTS.utils.trainer_utils import *
|
||||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
||||
|
||||
|
|
|
@ -144,20 +144,24 @@ class GAN(BaseVocoder):
|
|||
|
||||
return outputs, loss_dict
|
||||
|
||||
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
||||
@staticmethod
|
||||
def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
||||
y_hat = outputs[0]["model_outputs"]
|
||||
y = batch["waveform"]
|
||||
figures = plot_results(y_hat, y, ap, "train")
|
||||
figures = plot_results(y_hat, y, ap, name)
|
||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||
audios = {"train/audio": sample_voice}
|
||||
audios = {f"{name}/audio": sample_voice}
|
||||
return figures, audios
|
||||
|
||||
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
||||
return self._log("train", ap, batch, outputs)
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||
return self.train_step(batch, criterion, optimizer_idx)
|
||||
|
||||
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
||||
return self.train_log(ap, batch, outputs)
|
||||
return self._log("eval", ap, batch, outputs)
|
||||
|
||||
def load_checkpoint(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue