mirror of https://github.com/coqui-ai/TTS.git
move scheduler updates to the end of the epoch
commit 3fb78c004a
parent 2a872c98aa
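Previously, scheduler_G.step() and scheduler_D.step() ran after every optimizer update inside the batch loop of the GAN vocoder training script. This change removes the per-iteration calls and steps both schedulers once per epoch, after the batch loop finishes; the diff below also drops the RAdam import and a per-iteration torch.cuda.empty_cache() call.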
@@ -14,9 +14,8 @@ from TTS.utils.arguments import parse_arguments, process_args
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                      remove_experiment_folder, set_init_dict)
 from TTS.utils.io import copy_model_files, load_config
-from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.training import setup_torch_training_env
 from TTS.vocoder.datasets.gan_dataset import GANDataset
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
@@ -161,8 +160,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             torch.nn.utils.clip_grad_norm_(model_G.parameters(),
                                            c.gen_clip_grad)
         optimizer_G.step()
-        if scheduler_G is not None:
-            scheduler_G.step()
 
         loss_dict = dict()
         for key, value in loss_G_dict.items():
@@ -221,8 +218,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             torch.nn.utils.clip_grad_norm_(model_D.parameters(),
                                            c.disc_clip_grad)
         optimizer_D.step()
-        if scheduler_D is not None:
-            scheduler_D.step()
 
         for key, value in loss_D_dict.items():
             if isinstance(value, (int, float)):
@@ -293,7 +288,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
                                       {'train/audio': sample_voice},
                                       c.audio["sample_rate"])
         end_time = time.time()
-        torch.cuda.empty_cache()
+
+    if scheduler_G is not None:
+        scheduler_G.step()
+
+    if scheduler_D is not None:
+        scheduler_D.step()
 
     # print epoch stats
     c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
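For context, a minimal self-contained sketch of the pattern this diff moves to: optimizers step every batch, schedulers step once per epoch, which keeps the learning-rate decay independent of how many batches an epoch contains. The models, losses, AdamW/ExponentialLR choices, and constants below are illustrative placeholders, not values from this repo's configs.

# Sketch: step LR schedulers once per epoch, after the batch loop,
# instead of after every optimizer update. All names are placeholders.
import torch

model_G = torch.nn.Linear(10, 10)
model_D = torch.nn.Linear(10, 1)
optimizer_G = torch.optim.AdamW(model_G.parameters(), lr=1e-3)
optimizer_D = torch.optim.AdamW(model_D.parameters(), lr=1e-4)
scheduler_G = torch.optim.lr_scheduler.ExponentialLR(optimizer_G, gamma=0.999)
scheduler_D = torch.optim.lr_scheduler.ExponentialLR(optimizer_D, gamma=0.999)

for epoch in range(3):
    for _ in range(100):  # batch loop: only the optimizers step here
        x = torch.randn(8, 10)

        optimizer_G.zero_grad()
        loss_G = model_D(model_G(x)).pow(2).mean()  # stand-in generator loss
        loss_G.backward()
        torch.nn.utils.clip_grad_norm_(model_G.parameters(), 10.0)
        optimizer_G.step()

        optimizer_D.zero_grad()
        loss_D = model_D(x).pow(2).mean()  # stand-in discriminator loss
        loss_D.backward()
        optimizer_D.step()

    # as of this commit: schedulers advance once per epoch, so the decay
    # schedule no longer depends on the number of batches per epoch
    if scheduler_G is not None:
        scheduler_G.step()
    if scheduler_D is not None:
        scheduler_D.step()
    print(epoch, scheduler_G.get_last_lr(), scheduler_D.get_last_lr())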