From 5b70da2e3f711221232ee596349b7f32a0b30ae3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 9 Apr 2021 19:31:28 +0200 Subject: [PATCH] restore schedulers only if training is continuing a previous training inherit nn.Module for TorchSTFT --- TTS/bin/train_vocoder_gan.py | 24 +++++++++++++----------- TTS/vocoder/layers/losses.py | 3 ++- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index 7681d660..6357bc01 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -515,17 +515,19 @@ def main(args): # pylint: disable=redefined-outer-name model_disc.load_state_dict(checkpoint['model_disc']) print(" > Restoring Discriminator Optimizer...") optimizer_disc.load_state_dict(checkpoint['optimizer_disc']) - if 'scheduler' in checkpoint and scheduler_gen is not None: - print(" > Restoring Generator LR Scheduler...") - scheduler_gen.load_state_dict(checkpoint['scheduler']) - # NOTE: Not sure if necessary - scheduler_gen.optimizer = optimizer_gen - if 'scheduler_disc' in checkpoint and scheduler_disc is not None: - print(" > Restoring Discriminator LR Scheduler...") - scheduler_disc.load_state_dict(checkpoint['scheduler_disc']) - scheduler_disc.optimizer = optimizer_disc - if c.lr_scheduler_disc == "ExponentialLR": - scheduler_disc.last_epoch = checkpoint['epoch'] + # restore schedulers if it is a continuing training. + if args.continue_path != '': + if 'scheduler' in checkpoint and scheduler_gen is not None: + print(" > Restoring Generator LR Scheduler...") + scheduler_gen.load_state_dict(checkpoint['scheduler']) + # NOTE: Not sure if necessary + scheduler_gen.optimizer = optimizer_gen + if 'scheduler_disc' in checkpoint and scheduler_disc is not None: + print(" > Restoring Discriminator LR Scheduler...") + scheduler_disc.load_state_dict(checkpoint['scheduler_disc']) + scheduler_disc.optimizer = optimizer_disc + if c.lr_scheduler_disc == "ExponentialLR": + scheduler_disc.last_epoch = checkpoint['epoch'] except RuntimeError: # restore only matching layers. print(" > Partial model initialization...") diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index bbc4bc73..516c62a6 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -4,7 +4,7 @@ from torch import nn from torch.nn import functional as F -class TorchSTFT(): +class TorchSTFT(nn.Module): # pylint: disable=abstract-method """TODO: Merge this with audio.py""" def __init__(self, n_fft, @@ -34,6 +34,7 @@ class TorchSTFT(): if use_mel: self._build_mel_basis() + @torch.no_grad() def __call__(self, x): """Compute spectrogram frames by torch based stft.