restore schedulers only if continuing a previous training run

inherit nn.Module for TorchSTFT
Eren Gölge 2021-04-09 19:31:28 +02:00
parent 2c71c6d8cd
commit 5b70da2e3f
2 changed files with 15 additions and 12 deletions

View File

@@ -515,17 +515,19 @@ def main(args):  # pylint: disable=redefined-outer-name
             model_disc.load_state_dict(checkpoint['model_disc'])
             print(" > Restoring Discriminator Optimizer...")
             optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
-            if 'scheduler' in checkpoint and scheduler_gen is not None:
-                print(" > Restoring Generator LR Scheduler...")
-                scheduler_gen.load_state_dict(checkpoint['scheduler'])
-                # NOTE: Not sure if necessary
-                scheduler_gen.optimizer = optimizer_gen
-            if 'scheduler_disc' in checkpoint and scheduler_disc is not None:
-                print(" > Restoring Discriminator LR Scheduler...")
-                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
-                scheduler_disc.optimizer = optimizer_disc
-                if c.lr_scheduler_disc == "ExponentialLR":
-                    scheduler_disc.last_epoch = checkpoint['epoch']
+            # restore schedulers if it is a continuing training.
+            if args.continue_path != '':
+                if 'scheduler' in checkpoint and scheduler_gen is not None:
+                    print(" > Restoring Generator LR Scheduler...")
+                    scheduler_gen.load_state_dict(checkpoint['scheduler'])
+                    # NOTE: Not sure if necessary
+                    scheduler_gen.optimizer = optimizer_gen
+                if 'scheduler_disc' in checkpoint and scheduler_disc is not None:
+                    print(" > Restoring Discriminator LR Scheduler...")
+                    scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
+                    scheduler_disc.optimizer = optimizer_disc
+                    if c.lr_scheduler_disc == "ExponentialLR":
+                        scheduler_disc.last_epoch = checkpoint['epoch']
         except RuntimeError:
             # restore only matching layers.
             print(" > Partial model initialization...")
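
The effect of the new guard is that scheduler state is only reloaded when args.continue_path points at a previous run; plain fine-tuning from a restore path starts the schedulers fresh. Below is a minimal, self-contained sketch of that behaviour, not taken from the repository: the model, optimizer, checkpoint, and continue_path variable are made up for illustration, and only the 'scheduler'/'epoch' keys mirror the code above.

from torch import nn, optim

# Hypothetical stand-ins for the real training objects.
model = nn.Linear(10, 10)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

continue_path = ""  # would hold the previous run folder when continuing
checkpoint = {"scheduler": scheduler.state_dict(), "epoch": 42}

if continue_path != "":
    # Continuing the same run: pick up the LR schedule where it stopped.
    scheduler.load_state_dict(checkpoint["scheduler"])
    # Re-attach the freshly built optimizer to the restored scheduler.
    scheduler.optimizer = optimizer
    # ExponentialLR decays once per step(); realign it with the saved epoch.
    scheduler.last_epoch = checkpoint["epoch"]
# Otherwise the scheduler starts from its initial state.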

View File

@@ -4,7 +4,7 @@ from torch import nn
 from torch.nn import functional as F
 
 
-class TorchSTFT():
+class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
     """TODO: Merge this with audio.py"""
     def __init__(self,
                  n_fft,
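
The practical reason to make TorchSTFT an nn.Module (together with the @torch.no_grad() added to __call__ in the next hunk) is easiest to see in isolation. The sketch below is illustrative only: TinySTFT is a made-up stand-in, not the class changed here. As a module, the analysis window can be registered as a buffer so it follows .to(device) along with any parent module, while the decorator keeps spectrogram extraction out of the autograd graph; overriding __call__ instead of forward is also why the abstract-method pylint warning is silenced.

import torch
from torch import nn

class TinySTFT(nn.Module):  # pylint: disable=abstract-method
    """Made-up illustrative module, not the real TorchSTFT."""
    def __init__(self, n_fft=1024, hop_length=256, win_length=1024):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        # Buffers are moved by .to()/.cuda() but are not trainable parameters.
        self.register_buffer("window", torch.hann_window(win_length))

    @torch.no_grad()  # spectrogram extraction needs no gradients
    def __call__(self, x):
        o = torch.stft(x,
                       self.n_fft,
                       hop_length=self.hop_length,
                       win_length=self.win_length,
                       window=self.window,
                       return_complex=True)
        # Magnitude spectrogram, shape (batch, n_fft // 2 + 1, frames).
        return o.abs()

stft = TinySTFT()
spec = stft(torch.randn(1, 22050))  # the window buffer would follow stft.to("cuda") as well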
@@ -34,6 +34,7 @@ class TorchSTFT():
         if use_mel:
             self._build_mel_basis()
 
+    @torch.no_grad()
     def __call__(self, x):
         """Compute spectrogram frames by torch based stft.