mirror of https://github.com/coqui-ai/TTS.git
use Adam for wavegras instead of RAdam
This commit is contained in:
parent
7bcdb7ac35
commit
f79bbbbd00
|
@ -6,15 +6,19 @@ import time
|
|||
import traceback
|
||||
|
||||
import torch
|
||||
# DISTRIBUTED
|
||||
from apex.parallel import DistributedDataParallel as DDP_apex
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.distribute import init_distributed
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.io import copy_config_file, load_config
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import setup_torch_training_env
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
|
@ -22,13 +26,6 @@ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
|||
from TTS.vocoder.utils.generic_utils import plot_results, setup_generator
|
||||
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
|
||||
|
||||
# DISTRIBUTED
|
||||
from apex.parallel import DistributedDataParallel as DDP_apex
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from TTS.utils.distribute import init_distributed
|
||||
|
||||
|
||||
use_cuda, num_gpus = setup_torch_training_env(True, True)
|
||||
|
||||
|
||||
|
@ -329,7 +326,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
model = setup_generator(c)
|
||||
|
||||
# setup optimizers
|
||||
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
|
||||
optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)
|
||||
|
||||
# DISTRIBUTED
|
||||
if c.apex_amp_level is not None:
|
||||
|
|
Loading…
Reference in New Issue