mirror of https://github.com/coqui-ai/TTS.git
use Adam for wavegrad instead of RAdam
parent 7bcdb7ac35
commit f79bbbbd00
@@ -6,15 +6,19 @@ import time
 import traceback
 
 import torch
+# DISTRIBUTED
+from apex.parallel import DistributedDataParallel as DDP_apex
+from torch.nn.parallel import DistributedDataParallel as DDP_th
+from torch.optim import Adam
 from torch.utils.data import DataLoader
-
+from torch.utils.data.distributed import DistributedSampler
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.console_logger import ConsoleLogger
+from TTS.utils.distribute import init_distributed
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                      create_experiment_folder, get_git_branch,
                                      remove_experiment_folder, set_init_dict)
 from TTS.utils.io import copy_config_file, load_config
-from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.training import setup_torch_training_env
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
@@ -22,13 +26,6 @@ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
 from TTS.vocoder.utils.generic_utils import plot_results, setup_generator
 from TTS.vocoder.utils.io import save_best_model, save_checkpoint
-
-# DISTRIBUTED
-from apex.parallel import DistributedDataParallel as DDP_apex
-from torch.nn.parallel import DistributedDataParallel as DDP_th
-from torch.utils.data.distributed import DistributedSampler
-from TTS.utils.distribute import init_distributed
-
 
 use_cuda, num_gpus = setup_torch_training_env(True, True)
 
 
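Note: the two hunks above only regroup the distributed imports at the top of the module; behavior is unchanged. For reference, a minimal sketch of the standard pattern these imports support, with a hypothetical placeholder dataset (in the real script the sampler wraps WaveGradDataset and the replica count and rank come from the distributed setup):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Placeholder data; num_replicas/rank are passed explicitly so this sketch
# runs without an initialized process group. In the training script they
# are derived from init_distributed instead.
dataset = TensorDataset(torch.randn(64, 80))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for (batch,) in loader:
    pass  # one training step per sharded batch would go here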
@@ -329,7 +326,7 @@ def main(args): # pylint: disable=redefined-outer-name
     model = setup_generator(c)
 
     # setup optimizers
-    optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
+    optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)
 
     # DISTRIBUTED
     if c.apex_amp_level is not None:
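The functional change is this last hunk: WaveGrad now trains with stock torch.optim.Adam instead of the bundled RAdam implementation, with the same learning rate and zero weight decay. A minimal sketch of the new optimizer setup, assuming a generator model and a config c exposing lr (the stand-ins below are hypothetical, not the script's real objects):

import torch
from torch.optim import Adam

model = torch.nn.Linear(80, 80)  # stand-in for setup_generator(c)

class Config:  # stand-in for the loaded config; only lr is used here
    lr = 1e-4

c = Config()

# Same constructor arguments as before; only the optimizer class changed.
optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)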