fix distributed training for train_* scripts

erogol 2020-10-16 17:53:05 +02:00
parent 193b81b273
commit a1582a0e12
7 changed files with 99 additions and 82 deletions

View File

@@ -19,13 +19,16 @@ from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
-# from distribute import (DistributedSampler, apply_gradient_allreduce,
-# init_distributed, reduce_tensor)
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator,
setup_generator)
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
+# DISTRIBUTED
+from torch.nn.parallel import DistributedDataParallel as DDP_th
+from torch.utils.data.distributed import DistributedSampler
+from TTS.utils.distribute import init_distributed
use_cuda, num_gpus = setup_torch_training_env(True, True)
@@ -45,12 +48,12 @@ def setup_loader(ap, is_val=False, verbose=False):
use_cache=c.use_cache,
verbose=verbose)
dataset.shuffle_mapping()
-# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
loader = DataLoader(dataset,
batch_size=1 if is_val else c.batch_size,
-shuffle=True,
+shuffle=False if num_gpus > 1 else True,
drop_last=False,
-sampler=None,
+sampler=sampler,
num_workers=c.num_val_loader_workers
if is_val else c.num_loader_workers,
pin_memory=False)
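Note on the sampler change above: when a DistributedSampler is attached, shuffling moves from the DataLoader to the sampler (hence shuffle=False on the loader for multi-GPU runs), and the sampler should be re-seeded once per epoch so each process draws a different permutation. A minimal sketch of that pattern, using a placeholder dataset and epoch loop rather than the script's own objects:

    # Minimal sketch of the DistributedSampler pattern (placeholder names, not the script's).
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    def make_loader(dataset, batch_size, num_gpus):
        # The sampler owns the shuffle in distributed mode, so the loader must not shuffle too.
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=sampler is None,
                          sampler=sampler,
                          drop_last=False)

    # Per-epoch reshuffle: without set_epoch() every epoch reuses the same ordering.
    # for epoch in range(num_epochs):
    #     if isinstance(loader.sampler, DistributedSampler):
    #         loader.sampler.set_epoch(epoch)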
@@ -243,41 +246,42 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict, keep_avg.avg_values)
-# plot step stats
-if global_step % 10 == 0:
-iter_stats = {
-"lr_G": current_lr_G,
-"lr_D": current_lr_D,
-"step_time": step_time
-}
-iter_stats.update(loss_dict)
-tb_logger.tb_train_iter_stats(global_step, iter_stats)
+if args.rank == 0:
+# plot step stats
+if global_step % 10 == 0:
+iter_stats = {
+"lr_G": current_lr_G,
+"lr_D": current_lr_D,
+"step_time": step_time
+}
+iter_stats.update(loss_dict)
+tb_logger.tb_train_iter_stats(global_step, iter_stats)
-# save checkpoint
-if global_step % c.save_step == 0:
-if c.checkpoint:
-# save model
-save_checkpoint(model_G,
-optimizer_G,
-scheduler_G,
-model_D,
-optimizer_D,
-scheduler_D,
-global_step,
-epoch,
-OUT_PATH,
-model_losses=loss_dict)
+# save checkpoint
+if global_step % c.save_step == 0:
+if c.checkpoint:
+# save model
+save_checkpoint(model_G,
+optimizer_G,
+scheduler_G,
+model_D,
+optimizer_D,
+scheduler_D,
+global_step,
+epoch,
+OUT_PATH,
+model_losses=loss_dict)
-# compute spectrograms
-figures = plot_results(y_hat_vis, y_G, ap, global_step,
-'train')
-tb_logger.tb_train_figures(global_step, figures)
+# compute spectrograms
+figures = plot_results(y_hat_vis, y_G, ap, global_step,
+'train')
+tb_logger.tb_train_figures(global_step, figures)
-# Sample audio
-sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
-tb_logger.tb_train_audios(global_step,
-{'train/audio': sample_voice},
-c.audio["sample_rate"])
+# Sample audio
+sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
+tb_logger.tb_train_audios(global_step,
+{'train/audio': sample_voice},
+c.audio["sample_rate"])
end_time = time.time()
# print epoch stats
@@ -286,7 +290,8 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
# Plot Training Epoch Stats
epoch_stats = {"epoch_time": epoch_time}
epoch_stats.update(keep_avg.avg_values)
-tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
+if args.rank == 0:
+tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
# TODO: plot model stats
# if c.tb_model_param_stats:
# tb_logger.tb_model_weights(model, global_step)
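The guards added above follow the usual multi-GPU convention: only rank 0 writes Tensorboard events, checkpoints and audio samples, so N processes do not produce N copies of the same files. A small self-contained illustration of the same guard (the rank handling here is an assumption for the sketch, not the script's argparse plumbing):

    # Sketch: only the master process (rank 0) logs; the other ranks stay silent.
    import os

    def get_rank():
        # Assumes the launcher exports RANK (as torch.distributed.launch does);
        # single-process runs fall back to rank 0.
        return int(os.environ.get("RANK", "0"))

    def log_step(global_step, stats, tb_writer=None):
        if get_rank() != 0:
            return
        if tb_writer is not None:
            for name, value in stats.items():
                tb_writer.add_scalar(name, value, global_step)
        else:
            print(f"step {global_step}: {stats}")

    log_step(10, {"lr_G": 1e-4, "step_time": 0.42})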
@@ -417,20 +422,21 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
if c.print_eval:
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
-# compute spectrograms
-figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
-tb_logger.tb_eval_figures(global_step, figures)
+if args.rank == 0:
+# compute spectrograms
+figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
+tb_logger.tb_eval_figures(global_step, figures)
-# Sample audio
-sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
-tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
-c.audio["sample_rate"])
+# Sample audio
+sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
+tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
+c.audio["sample_rate"])
-# synthesize a full voice
-tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
+# synthesize a full voice
+data_loader.return_segments = False
+tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
return keep_avg.avg_values
@@ -450,9 +456,9 @@ def main(args): # pylint: disable=redefined-outer-name
ap = AudioProcessor(**c.audio)
# DISTRUBUTED
-# if num_gpus > 1:
-# init_distributed(args.rank, num_gpus, args.group_id,
-# c.distributed["backend"], c.distributed["url"])
+if num_gpus > 1:
+init_distributed(args.rank, num_gpus, args.group_id,
+c.distributed["backend"], c.distributed["url"])
# setup models
model_gen = setup_generator(c)
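init_distributed comes from TTS.utils.distribute; its body is not part of this diff, but a helper like this conventionally wraps torch.distributed.init_process_group with the backend and URL taken from the config, after pinning the process to its GPU. A hedged sketch of that wiring (illustrative only, not the repository's implementation):

    # Sketch of what an init_distributed-style helper typically does (assumed, not from the diff).
    import torch
    import torch.distributed as dist

    def init_distributed_sketch(rank, num_gpus, backend="nccl", url="tcp://localhost:54321"):
        if torch.cuda.is_available():
            # One process per GPU: bind this process to its device before creating the group.
            torch.cuda.set_device(rank % torch.cuda.device_count())
        dist.init_process_group(backend=backend,
                                init_method=url,
                                world_size=num_gpus,
                                rank=rank)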
@@ -532,8 +538,9 @@ def main(args): # pylint: disable=redefined-outer-name
criterion_disc.cuda()
# DISTRUBUTED
-# if num_gpus > 1:
-# model = apply_gradient_allreduce(model)
+if num_gpus > 1:
+model_gen = DDP_th(model_gen, device_ids=[args.rank])
+model_disc = DDP_th(model_disc, device_ids=[args.rank])
num_params = count_parameters(model_gen)
print(" > Generator has {} parameters".format(num_params), flush=True)

View File

@@ -11,6 +11,7 @@ import traceback
import torch
from random import randrange
from torch.utils.data import DataLoader
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import GlowTTSLoss
@@ -34,6 +35,13 @@ from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import (NoamLR, check_update,
setup_torch_training_env)
+# DISTRIBUTED
+from apex.parallel import DistributedDataParallel as DDP_apex
+from torch.nn.parallel import DistributedDataParallel as DDP_th
+from torch.utils.data.distributed import DistributedSampler
+from TTS.utils.distribute import init_distributed, reduce_tensor
use_cuda, num_gpus = setup_torch_training_env(True, False)
def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
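reduce_tensor is imported above alongside init_distributed; its body is not shown in this commit, but such helpers conventionally average a scalar loss across processes with an all_reduce so that every rank reports the same value. A hedged sketch of that convention (assumed behaviour, not the repository's code):

    # Sketch of the usual cross-rank loss averaging helper (assumed implementation).
    import torch.distributed as dist

    def reduce_tensor_sketch(tensor, num_gpus):
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # sum the value contributed by every rank
        return rt / num_gpus                       # then average it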
@@ -481,10 +489,9 @@ def main(args): # pylint: disable=redefined-outer-name
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
criterion = GlowTTSLoss()
-if c.apex_amp_level:
+if c.apex_amp_level is not None:
# pylint: disable=import-outside-toplevel
from apex import amp
-from apex.parallel import DistributedDataParallel as DDP
model.cuda()
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
else:
@@ -523,7 +530,10 @@ def main(args): # pylint: disable=redefined-outer-name
# DISTRUBUTED
if num_gpus > 1:
-model = DDP(model)
+if c.apex_amp_level is not None:
+model = DDP_apex(model)
+else:
+model = DDP_th(model, device_ids=[args.rank])
if c.noam_schedule:
scheduler = NoamLR(optimizer,
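The branch above chooses apex's DistributedDataParallel when apex AMP is enabled (apex DDP is built to cooperate with amp.initialize) and PyTorch's native DDP otherwise. A condensed sketch of that decision with the apex import kept local, since apex may not be installed (model, optimizer and rank are placeholders; the model is assumed to already sit on its GPU):

    # Sketch: pick the DDP flavour depending on whether apex AMP is active.
    from torch.nn.parallel import DistributedDataParallel as DDP_th

    def distribute_model(model, optimizer, rank, num_gpus, apex_amp_level=None):
        if apex_amp_level is not None:
            from apex import amp                                          # optional dependency
            from apex.parallel import DistributedDataParallel as DDP_apex
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level=apex_amp_level)
            if num_gpus > 1:
                model = DDP_apex(model)          # apex DDP infers the device from the model
        elif num_gpus > 1:
            model = DDP_th(model, device_ids=[rank])  # native DDP needs the local device id
        return model, optimizer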

View File

@@ -38,8 +38,10 @@ from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
gradual_training_scheduler, set_weight_decay,
setup_torch_training_env)
use_cuda, num_gpus = setup_torch_training_env(True, False)
def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
if is_val and not c.run_eval:
loader = None

View File

@@ -4,12 +4,9 @@ import os
import sys
import time
import traceback
-from inspect import signature
import torch
from torch.utils.data import DataLoader
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
@@ -20,14 +17,18 @@ from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env
-from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
-from TTS.utils.distribute import init_distributed, reduce_tensor
-from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
-from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator,
-setup_generator)
+from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
+from TTS.vocoder.utils.generic_utils import plot_results, setup_generator
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
+# DISTRIBUTED
+from apex.parallel import DistributedDataParallel as DDP_apex
+from torch.nn.parallel import DistributedDataParallel as DDP_th
+from torch.utils.data.distributed import DistributedSampler
+from TTS.utils.distribute import init_distributed
use_cuda, num_gpus = setup_torch_training_env(True, True)
@@ -111,11 +112,6 @@ def train(model, criterion, optimizer,
else:
loss.backward()
-if amp:
-amp_opt_params = amp.master_params(optimizer)
-else:
-amp_opt_params = None
if c.clip_grad > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.clip_grad)
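For reference on the removed amp_opt_params lines: apex's documented pattern for gradient clipping under AMP is to clip amp.master_params(optimizer) rather than model.parameters(), because the FP32 master copies hold the gradients the optimizer actually steps on. A hedged sketch of that convention (not this script's exact code):

    # Sketch: clip the parameters apex AMP actually updates when AMP is enabled.
    import torch

    def clip_gradients(model, optimizer, clip_grad, amp=None):
        if clip_grad <= 0:
            return None
        params = amp.master_params(optimizer) if amp is not None else model.parameters()
        return torch.nn.utils.clip_grad_norm_(params, clip_grad)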
@@ -279,7 +275,6 @@ def evaluate(model, criterion, ap, global_step, epoch):
return keep_avg.avg_values
-# FIXME: move args definition/parsing inside of main?
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global train_data, eval_data
@@ -305,10 +300,9 @@ def main(args): # pylint: disable=redefined-outer-name
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
# DISTRIBUTED
-if c.apex_amp_level:
+if c.apex_amp_level is not None:
# pylint: disable=import-outside-toplevel
from apex import amp
-from apex.parallel import DistributedDataParallel as DDP
model.cuda()
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
else:
@@ -363,7 +357,10 @@ def main(args): # pylint: disable=redefined-outer-name
# DISTRUBUTED
if num_gpus > 1:
-model = DDP(model)
+if c.apex_amp_level is not None:
+model = DDP_apex(model)
+else:
+model = DDP_th(model, device_ids=[args.rank])
num_params = count_parameters(model)
print(" > WaveGrad has {} parameters".format(num_params), flush=True)
@@ -447,7 +444,7 @@ if __name__ == '__main__':
_ = os.path.dirname(os.path.realpath(__file__))
# DISTRIBUTED
-if c.apex_amp_level:
+if c.apex_amp_level is not None:
print(" > apex AMP level: ", c.apex_amp_level)
OUT_PATH = args.continue_path

View File

@@ -54,9 +54,10 @@
"add_blank": false, // if true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model.
// DISTRIBUTED TRAINING
+"apex_amp_level": null, // APEX amp optimization level. "O1" is currently supported.
"distributed":{
"backend": "nccl",
-"url": "tcp:\/\/localhost:54321"
+"url": "tcp:\/\/localhost:54323"
},
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.

View File

@@ -31,13 +31,13 @@
"symmetric_norm": true, // move normalization to range [-1, 1]
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
-"stats_path": "/data/rw/home/Data/LibriTTS/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
+"stats_path": "/home/erogol/Data/libritts/LibriTTS/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
},
// DISTRIBUTED TRAINING
"distributed":{
"backend": "nccl",
-"url": "tcp:\/\/localhost:54321"
+"url": "tcp:\/\/localhost:54324"
},
// MODEL PARAMETERS
@@ -83,7 +83,7 @@
},
// DATASET
-"data_path": "/data5/rw/home/Data/LibriTTS/LibriTTS/train-clean-360/",
+"data_path": "/home/erogol/Data/libritts/LibriTTS/train-clean-360/",
"feature_path": null,
"seq_len": 16384,
"pad_short": 2000,
@@ -132,7 +132,7 @@
"eval_split_size": 10,
// PATHS
-"output_path": "/data4/rw/home/Trainings/LJSpeech/"
+"output_path": "/home/erogol/Models/"
}

View File

@@ -34,10 +34,10 @@
},
// DISTRIBUTED TRAINING
-"apex_amp_level": "O1", // amp optimization level. "O1" is currentl supported.
+"apex_amp_level": null, // APEX amp optimization level. "O1" is currently supported.
"distributed":{
"backend": "nccl",
-"url": "tcp:\/\/localhost:54321"
+"url": "tcp:\/\/localhost:54322"
},
"target_loss": "avg_wavegrad_loss", // loss value to pick the best model to save after each epoch
@@ -47,7 +47,7 @@
"model_params":{
"x_conv_channels":32,
"c_conv_channels":768,
-"ublock_out_channels": [768, 512, 512, 256, 128],
+"ublock_out_channels": [512, 512, 256, 128, 128],
"dblock_out_channels": [128, 128, 256, 512],
"upsample_factors": [4, 4, 4, 2, 2],
"upsample_dilations": [