mirror of https://github.com/coqui-ai/TTS.git
Fixup `trainer.py` 🛠️
This commit is contained in:
parent
418c7d98d5
commit
5b89cb4fec
|
@ -462,11 +462,11 @@ class Trainer:
|
||||||
update_lr_scheduler = True
|
update_lr_scheduler = True
|
||||||
if self.use_amp_scaler:
|
if self.use_amp_scaler:
|
||||||
if self.use_apex:
|
if self.use_apex:
|
||||||
with amp.scale_loss(loss_dict["loss"], self.optimizer) as scaled_loss:
|
with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
amp.master_params(self.optimizer),
|
amp.master_params(optimizer),
|
||||||
self.config.grad_clip,
|
grad_clip,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# model optimizer step in mixed precision mode
|
# model optimizer step in mixed precision mode
|
||||||
|
@ -739,6 +739,7 @@ class Trainer:
|
||||||
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
|
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
|
||||||
if audios is not None:
|
if audios is not None:
|
||||||
self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
|
self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
|
||||||
|
self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
|
||||||
|
|
||||||
def test_run(self) -> None:
|
def test_run(self) -> None:
|
||||||
"""Run test and log the results. Test run must be defined by the model.
|
"""Run test and log the results. Test run must be defined by the model.
|
||||||
|
|
Loading…
Reference in New Issue