mirror of https://github.com/coqui-ai/TTS.git
restore schedulers only if training is continuing a previous training
inherit nn.Module for TorchSTFT
This commit is contained in:
parent
2c71c6d8cd
commit
5b70da2e3f
|
@ -515,17 +515,19 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
model_disc.load_state_dict(checkpoint['model_disc'])
|
model_disc.load_state_dict(checkpoint['model_disc'])
|
||||||
print(" > Restoring Discriminator Optimizer...")
|
print(" > Restoring Discriminator Optimizer...")
|
||||||
optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
|
optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
|
||||||
if 'scheduler' in checkpoint and scheduler_gen is not None:
|
# restore schedulers if it is a continuing training.
|
||||||
print(" > Restoring Generator LR Scheduler...")
|
if args.continue_path != '':
|
||||||
scheduler_gen.load_state_dict(checkpoint['scheduler'])
|
if 'scheduler' in checkpoint and scheduler_gen is not None:
|
||||||
# NOTE: Not sure if necessary
|
print(" > Restoring Generator LR Scheduler...")
|
||||||
scheduler_gen.optimizer = optimizer_gen
|
scheduler_gen.load_state_dict(checkpoint['scheduler'])
|
||||||
if 'scheduler_disc' in checkpoint and scheduler_disc is not None:
|
# NOTE: Not sure if necessary
|
||||||
print(" > Restoring Discriminator LR Scheduler...")
|
scheduler_gen.optimizer = optimizer_gen
|
||||||
scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
|
if 'scheduler_disc' in checkpoint and scheduler_disc is not None:
|
||||||
scheduler_disc.optimizer = optimizer_disc
|
print(" > Restoring Discriminator LR Scheduler...")
|
||||||
if c.lr_scheduler_disc == "ExponentialLR":
|
scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
|
||||||
scheduler_disc.last_epoch = checkpoint['epoch']
|
scheduler_disc.optimizer = optimizer_disc
|
||||||
|
if c.lr_scheduler_disc == "ExponentialLR":
|
||||||
|
scheduler_disc.last_epoch = checkpoint['epoch']
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# restore only matching layers.
|
# restore only matching layers.
|
||||||
print(" > Partial model initialization...")
|
print(" > Partial model initialization...")
|
||||||
|
|
|
@ -4,7 +4,7 @@ from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
class TorchSTFT():
|
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||||
"""TODO: Merge this with audio.py"""
|
"""TODO: Merge this with audio.py"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
n_fft,
|
n_fft,
|
||||||
|
@ -34,6 +34,7 @@ class TorchSTFT():
|
||||||
if use_mel:
|
if use_mel:
|
||||||
self._build_mel_basis()
|
self._build_mel_basis()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
"""Compute spectrogram frames by torch based stft.
|
"""Compute spectrogram frames by torch based stft.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue