mirror of https://github.com/coqui-ai/TTS.git
Fix `eval_log` for `gan.py` 🛠️
This commit is contained in:
parent
d700845b10
commit
cfa5041db7
TTS
|
@ -36,7 +36,7 @@ from TTS.utils.generic_utils import (
|
||||||
)
|
)
|
||||||
from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint
|
from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint
|
||||||
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
|
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
|
||||||
from TTS.utils.trainer_utils import *
|
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
|
||||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||||
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
||||||
|
|
||||||
|
|
|
@ -144,20 +144,24 @@ class GAN(BaseVocoder):
|
||||||
|
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
@staticmethod
|
||||||
|
def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
||||||
y_hat = outputs[0]["model_outputs"]
|
y_hat = outputs[0]["model_outputs"]
|
||||||
y = batch["waveform"]
|
y = batch["waveform"]
|
||||||
figures = plot_results(y_hat, y, ap, "train")
|
figures = plot_results(y_hat, y, ap, name)
|
||||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||||
audios = {"train/audio": sample_voice}
|
audios = {f"{name}/audio": sample_voice}
|
||||||
return figures, audios
|
return figures, audios
|
||||||
|
|
||||||
|
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
||||||
|
return self._log("train", ap, batch, outputs)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||||
return self.train_step(batch, criterion, optimizer_idx)
|
return self.train_step(batch, criterion, optimizer_idx)
|
||||||
|
|
||||||
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
||||||
return self.train_log(ap, batch, outputs)
|
return self._log("eval", ap, batch, outputs)
|
||||||
|
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue