use Adam for wavegras instead of RAdam

This commit is contained in:
erogol 2020-10-19 17:56:14 +02:00
parent 7bcdb7ac35
commit f79bbbbd00
1 changed files with 7 additions and 10 deletions

View File

@ -6,15 +6,19 @@ import time
import traceback
import torch
# DISTRIBUTED
from apex.parallel import DistributedDataParallel as DDP_apex
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
@ -22,13 +26,6 @@ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.utils.generic_utils import plot_results, setup_generator
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
# DISTRIBUTED
from apex.parallel import DistributedDataParallel as DDP_apex
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.distribute import init_distributed
use_cuda, num_gpus = setup_torch_training_env(True, True)
@ -329,7 +326,7 @@ def main(args): # pylint: disable=redefined-outer-name
model = setup_generator(c)
# setup optimizers
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)
# DISTRIBUTED
if c.apex_amp_level is not None: