mirror of https://github.com/coqui-ai/TTS.git
Fix test sentences synthesis
This commit is contained in:
parent
93a74cbb71
commit
32974dd6a9
|
@ -764,11 +764,11 @@ class Trainer:
|
||||||
"""Run test and log the results. Test run must be defined by the model.
|
"""Run test and log the results. Test run must be defined by the model.
|
||||||
Model must return figures and audios to be logged by the Tensorboard."""
|
Model must return figures and audios to be logged by the Tensorboard."""
|
||||||
if hasattr(self.model, "test_run"):
|
if hasattr(self.model, "test_run"):
|
||||||
if hasattr(self.eval_loader.load_test_samples):
|
if hasattr(self.eval_loader, "load_test_samples"):
|
||||||
samples = self.eval_loader.load_test_samples(1)
|
samples = self.eval_loader.load_test_samples(1)
|
||||||
figures, audios = self.model.test_run(samples)
|
figures, audios = self.model.test_run(samples)
|
||||||
else:
|
else:
|
||||||
figures, audios = self.model.test_run()
|
figures, audios = self.model.test_run(use_cuda=self.use_cuda, ap=self.ap)
|
||||||
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
|
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
|
||||||
self.tb_logger.tb_test_figures(self.total_steps_done, figures)
|
self.tb_logger.tb_test_figures(self.total_steps_done, figures)
|
||||||
|
|
||||||
|
@ -790,7 +790,7 @@ class Trainer:
|
||||||
self.train_epoch()
|
self.train_epoch()
|
||||||
if self.config.run_eval:
|
if self.config.run_eval:
|
||||||
self.eval_epoch()
|
self.eval_epoch()
|
||||||
if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
|
if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
|
||||||
self.test_run()
|
self.test_run()
|
||||||
self.c_logger.print_epoch_end(
|
self.c_logger.print_epoch_end(
|
||||||
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
|
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
|
||||||
|
|
|
@ -200,7 +200,7 @@ class BaseTTS(BaseModel):
|
||||||
)
|
)
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
def test_run(self) -> Tuple[Dict, Dict]:
|
def test_run(self, use_cuda=True, ap=None) -> Tuple[Dict, Dict]:
|
||||||
"""Generic test run for `tts` models used by `Trainer`.
|
"""Generic test run for `tts` models used by `Trainer`.
|
||||||
|
|
||||||
You can override this for a different behaviour.
|
You can override this for a different behaviour.
|
||||||
|
@ -212,14 +212,14 @@ class BaseTTS(BaseModel):
|
||||||
test_audios = {}
|
test_audios = {}
|
||||||
test_figures = {}
|
test_figures = {}
|
||||||
test_sentences = self.config.test_sentences
|
test_sentences = self.config.test_sentences
|
||||||
aux_inputs = self._get_aux_inputs()
|
aux_inputs = self.get_aux_input()
|
||||||
for idx, sen in enumerate(test_sentences):
|
for idx, sen in enumerate(test_sentences):
|
||||||
wav, alignment, model_outputs, _ = synthesis(
|
wav, alignment, model_outputs, _ = synthesis(
|
||||||
self.model,
|
self,
|
||||||
sen,
|
sen,
|
||||||
self.config,
|
self.config,
|
||||||
self.use_cuda,
|
use_cuda,
|
||||||
self.ap,
|
ap,
|
||||||
speaker_id=aux_inputs["speaker_id"],
|
speaker_id=aux_inputs["speaker_id"],
|
||||||
d_vector=aux_inputs["d_vector"],
|
d_vector=aux_inputs["d_vector"],
|
||||||
style_wav=aux_inputs["style_wav"],
|
style_wav=aux_inputs["style_wav"],
|
||||||
|
@ -229,6 +229,6 @@ class BaseTTS(BaseModel):
|
||||||
).values()
|
).values()
|
||||||
|
|
||||||
test_audios["{}-audio".format(idx)] = wav
|
test_audios["{}-audio".format(idx)] = wav
|
||||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
|
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
|
||||||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
|
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
|
||||||
return test_figures, test_audios
|
return test_figures, test_audios
|
||||||
|
|
Loading…
Reference in New Issue