From f79bbbbd00ff16b70bd087fb79b26323f7dc8358 Mon Sep 17 00:00:00 2001
From: erogol
Date: Mon, 19 Oct 2020 17:56:14 +0200
Subject: [PATCH] use Adam for wavegrad instead of RAdam

---
 TTS/bin/train_wavegrad.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/TTS/bin/train_wavegrad.py b/TTS/bin/train_wavegrad.py
index e167a4cb..04af1595 100644
--- a/TTS/bin/train_wavegrad.py
+++ b/TTS/bin/train_wavegrad.py
@@ -6,15 +6,19 @@ import time
 import traceback
 
 import torch
+# DISTRIBUTED
+from apex.parallel import DistributedDataParallel as DDP_apex
+from torch.nn.parallel import DistributedDataParallel as DDP_th
+from torch.optim import Adam
 from torch.utils.data import DataLoader
-
+from torch.utils.data.distributed import DistributedSampler
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.console_logger import ConsoleLogger
+from TTS.utils.distribute import init_distributed
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                      create_experiment_folder, get_git_branch,
                                      remove_experiment_folder, set_init_dict)
 from TTS.utils.io import copy_config_file, load_config
-from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.training import setup_torch_training_env
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
@@ -22,13 +26,6 @@ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
 from TTS.vocoder.utils.generic_utils import plot_results, setup_generator
 from TTS.vocoder.utils.io import save_best_model, save_checkpoint
 
-# DISTRIBUTED
-from apex.parallel import DistributedDataParallel as DDP_apex
-from torch.nn.parallel import DistributedDataParallel as DDP_th
-from torch.utils.data.distributed import DistributedSampler
-from TTS.utils.distribute import init_distributed
-
-
 use_cuda, num_gpus = setup_torch_training_env(True, True)
 
 
@@ -329,7 +326,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     model = setup_generator(c)
 
     # setup optimizers
-    optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
+    optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)
 
     # DISTRIBUTED
     if c.apex_amp_level is not None:
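
Note for reviewers: a minimal stand-alone sketch of the optimizer change in isolation. The call keeps the same arguments and only swaps RAdam for torch.optim.Adam; the model and learning rate below are placeholders, not the actual WaveGrad generator or the c.lr config value from train_wavegrad.py.

    # Illustrative sketch only: shows the optimizer swap this patch makes.
    import torch
    from torch.optim import Adam  # replaces TTS.utils.radam.RAdam

    model = torch.nn.Linear(80, 80)  # stand-in for setup_generator(c)
    lr = 1e-4                        # stand-in for c.lr

    # Same call signature as before; only the optimizer class changes.
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=0)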