diff --git a/train.py b/train.py
index 23ac7b93..bfe00aee 100644
--- a/train.py
+++ b/train.py
@@ -16,11 +16,12 @@ from tensorboardX import SummaryWriter
 from utils.generic_utils import (
     synthesis, remove_experiment_folder, create_experiment_folder,
     save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
-    check_update, get_commit_hash)
+    check_update, get_commit_hash, sequence_mask)
 from utils.visual import plot_alignment, plot_spectrogram
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
 from utils.audio import AudioProcessor
+from torch.optim.lr_scheduler import StepLR
 
 torch.manual_seed(1)
 torch.set_num_threads(4)
@@ -28,7 +29,7 @@ use_cuda = torch.cuda.is_available()
 
 
 def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
-          ap, epoch):
+          scheduler, ap, epoch):
     model = model.train()
     epoch_time = 0
     avg_linear_loss = 0
@@ -58,15 +59,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
                        epoch * len(data_loader) + 1
 
         # setup lr
-        current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
-        current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)
-
-        for params_group in optimizer.param_groups:
-            params_group['lr'] = current_lr
-
-        for params_group in optimizer_st.param_groups:
-            params_group['lr'] = current_lr_st
-
+        scheduler.step()
         optimizer.zero_grad()
         optimizer_st.zero_grad()
 
@@ -79,9 +72,12 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
             linear_input = linear_input.cuda()
             stop_targets = stop_targets.cuda()
 
+        # compute mask for padding
+        mask = sequence_mask(text_lengths)
+
         # forward pass
-        mel_output, linear_output, alignments, stop_tokens =\
-            model.forward(text_input, mel_input, text_lengths)
+        mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
+            model, (text_input, mel_input, mask))
 
         # loss computation
         stop_loss = criterion_st(stop_tokens, stop_targets)
@@ -94,7 +90,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
 
         # backpass and check the grad norm for spec losses
         loss.backward(retain_graph=True)
-        grad_norm, skip_flag = check_update(model, 0.5, 100)
+        grad_norm, skip_flag = check_update(model, 1)
         if skip_flag:
             optimizer.zero_grad()
             print(" | > Iteration skipped!!", flush=True)
@@ -103,8 +99,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
 
         # backpass and check the grad norm for stop loss
         stop_loss.backward()
-        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet,
-                                               0.5, 100)
+        grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
         if skip_flag:
             optimizer_st.zero_grad()
             print(" | | > Iteration skipped fro stopnet!!")
@@ -115,18 +110,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
 
         epoch_time += step_time
 
         if current_step % c.print_step == 0:
-            print(" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
-                  "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
-                  "GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter,
-                                                             batch_n_iter,
-                                                             current_step,
-                                                             loss.item(),
-                                                             linear_loss.item(),
-                                                             mel_loss.item(),
-                                                             stop_loss.item(),
-                                                             grad_norm.item(),
-                                                             grad_norm_st.item(),
-                                                             step_time), flush=True)
+            print(
+                " | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
+                "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
+                "GradNormST:{:.5f} StepTime:{:.2f}".format(
+                    num_iter, batch_n_iter, current_step, loss.item(),
+                    linear_loss.item(), mel_loss.item(), stop_loss.item(),
+                    grad_norm.item(), grad_norm_st.item(), step_time),
+                flush=True)
         avg_linear_loss += linear_loss.item()
         avg_mel_loss += mel_loss.item()
@@ -184,16 +175,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     avg_step_time /= (num_iter + 1)
 
     # print epoch stats
-    print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
-          "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
-          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
-          "AvgStepTime:{:.2f}".format(current_step,
-                                      avg_total_loss,
-                                      avg_linear_loss,
-                                      avg_mel_loss,
-                                      avg_stop_loss,
-                                      epoch_time,
-                                      avg_step_time), flush=True)
+    print(
+        " | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
+        "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
+        "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
+        "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
+                                    avg_linear_loss, avg_mel_loss,
+                                    avg_stop_loss, epoch_time, avg_step_time),
+        flush=True)
 
     # Plot Training Epoch Stats
     tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
@@ -266,8 +255,10 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
             if num_iter % c.print_step == 0:
                 print(
                     " | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
-                    "StopLoss: {:.5f} ".format(loss.item(), linear_loss.item(),
-                                               mel_loss.item(), stop_loss.item()),
+                    "StopLoss: {:.5f} ".format(loss.item(),
+                                               linear_loss.item(),
+                                               mel_loss.item(),
+                                               stop_loss.item()),
                     flush=True)
 
             avg_linear_loss += linear_loss.item()
@@ -322,11 +313,11 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
         ap.griffin_lim_iters = 60
         for idx, test_sentence in enumerate(test_sentences):
             try:
-                wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
-                                                         use_cuda, c.text_cleaner)
-                wav_name = 'TestSentences/{}'.format(idx)
-                tb.add_audio(
-                    wav_name, wav, current_step, sample_rate=c.sample_rate)
+                wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
+                                                         use_cuda, c.text_cleaner)
+                wav_name = 'TestSentences/{}'.format(idx)
+                tb.add_audio(
+                    wav_name, wav, current_step, sample_rate=c.sample_rate)
             except:
                 print(" !! Error as creating Test Sentence -", idx)
                 pass
@@ -405,7 +396,7 @@ def main(args):
         checkpoint = torch.load(args.restore_path)
         model.load_state_dict(checkpoint['model'])
         if use_cuda:
-            model = nn.DataParallel(model.cuda())
+            model = model.cuda()
             criterion.cuda()
             criterion_st.cuda()
         optimizer.load_state_dict(checkpoint['optimizer'])
@@ -423,10 +414,11 @@ def main(args):
         args.restore_step = 0
         print("\n > Starting a new training", flush=True)
 
     if use_cuda:
-        model = nn.DataParallel(model.cuda())
+        model = model.cuda()
         criterion.cuda()
         criterion_st.cuda()
 
+    scheduler = StepLR(optimizer, step_size=c.decay_step, gamma=c.lr_decay)
     num_params = count_parameters(model)
     print(" | > Model has {} parameters".format(num_params), flush=True)
@@ -439,7 +431,7 @@ def main(args):
     for epoch in range(0, c.epochs):
         train_loss, current_step = train(model, criterion, criterion_st,
                                          train_loader, optimizer, optimizer_st,
-                                         ap, epoch)
+                                         scheduler, ap, epoch)
         val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
                             current_step)
         print(
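
Reviewer note: the central change in this patch replaces the manual warmup-based
lr_decay() loop with torch.optim.lr_scheduler.StepLR, stepped once per training
iteration. A minimal, self-contained sketch of that schedule follows; the
step_size and gamma values are illustrative placeholders standing in for the
config values c.decay_step and c.lr_decay, and the toy Linear model stands in
for Tacotron. (The patch calls scheduler.step() at the top of each iteration,
which matched the convention of PyTorch versions current at the time; PyTorch
1.1+ expects it after optimizer.step(), as in the sketch.)

    import torch
    from torch.optim.lr_scheduler import StepLR

    # Stand-ins: a toy model and optimizer in place of Tacotron and its setup.
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.002)

    # step_size=3 / gamma=0.5 are placeholders for c.decay_step / c.lr_decay;
    # the lr is multiplied by gamma after every `step_size` scheduler steps.
    scheduler = StepLR(optimizer, step_size=3, gamma=0.5)

    for step in range(9):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 8)).pow(2).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()  # piecewise-constant decay; no warmup phase
        print(step, optimizer.param_groups[0]['lr'])

Unlike the removed lr_decay() schedule, StepLR has no warmup phase, so
c.decay_step and c.lr_decay should be tuned for step decay rather than reusing
the old c.warmup_steps value.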