Fix test sentences synthesis

This commit is contained in:
WeberJulian 2021-07-13 16:04:42 +02:00
parent 93a74cbb71
commit 32974dd6a9
2 changed files with 9 additions and 9 deletions

View File

@ -764,11 +764,11 @@ class Trainer:
"""Run test and log the results. Test run must be defined by the model. """Run test and log the results. Test run must be defined by the model.
Model must return figures and audios to be logged by the Tensorboard.""" Model must return figures and audios to be logged by the Tensorboard."""
if hasattr(self.model, "test_run"): if hasattr(self.model, "test_run"):
if hasattr(self.eval_loader.load_test_samples): if hasattr(self.eval_loader, "load_test_samples"):
samples = self.eval_loader.load_test_samples(1) samples = self.eval_loader.load_test_samples(1)
figures, audios = self.model.test_run(samples) figures, audios = self.model.test_run(samples)
else: else:
figures, audios = self.model.test_run() figures, audios = self.model.test_run(use_cuda=self.use_cuda, ap=self.ap)
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
self.tb_logger.tb_test_figures(self.total_steps_done, figures) self.tb_logger.tb_test_figures(self.total_steps_done, figures)
@ -790,7 +790,7 @@ class Trainer:
self.train_epoch() self.train_epoch()
if self.config.run_eval: if self.config.run_eval:
self.eval_epoch() self.eval_epoch()
if epoch >= self.config.test_delay_epochs and self.args.rank < 0: if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
self.test_run() self.test_run()
self.c_logger.print_epoch_end( self.c_logger.print_epoch_end(
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values

View File

@ -200,7 +200,7 @@ class BaseTTS(BaseModel):
) )
return loader return loader
def test_run(self) -> Tuple[Dict, Dict]: def test_run(self, use_cuda=True, ap=None) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`. """Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour. You can override this for a different behaviour.
@ -212,14 +212,14 @@ class BaseTTS(BaseModel):
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
test_sentences = self.config.test_sentences test_sentences = self.config.test_sentences
aux_inputs = self._get_aux_inputs() aux_inputs = self.get_aux_input()
for idx, sen in enumerate(test_sentences): for idx, sen in enumerate(test_sentences):
wav, alignment, model_outputs, _ = synthesis( wav, alignment, model_outputs, _ = synthesis(
self.model, self,
sen, sen,
self.config, self.config,
self.use_cuda, use_cuda,
self.ap, ap,
speaker_id=aux_inputs["speaker_id"], speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"], d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"], style_wav=aux_inputs["style_wav"],
@ -229,6 +229,6 @@ class BaseTTS(BaseModel):
).values() ).values()
test_audios["{}-audio".format(idx)] = wav test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False) test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False) test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
return test_figures, test_audios return test_figures, test_audios