diff --git a/train.py b/train.py
index 730d7389..591cbd76 100644
--- a/train.py
+++ b/train.py
@@ -20,7 +20,7 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  load_config, remove_experiment_folder,
                                  save_best_model, save_checkpoint, weight_decay,
                                  set_init_dict, copy_config_file, setup_model,
-                                 split_dataset)
+                                 split_dataset, gradual_training_scheduler)
 from utils.logger import Logger
 from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
     get_speakers
@@ -81,20 +81,6 @@ def setup_loader(ap, is_val=False, verbose=False):
     return loader
 
 
-def gradual_training_scheduler(global_step):
-    if global_step < 10000:
-        r, batch_size = 7, 32
-    elif global_step < 50000:
-        r, batch_size = 5, 32
-    elif global_step < 130000:
-        r, batch_size = 3, 32
-    elif global_step < 290000:
-        r, batch_size = 2, 16
-    else:
-        r, batch_size = 1, 16
-    return r, batch_size
-
-
 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
           ap, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
@@ -106,8 +92,10 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     avg_decoder_loss = 0
     avg_stop_loss = 0
     avg_step_time = 0
+    avg_loader_time = 0
     print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
     batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
+    end_time = time.time()
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
 
@@ -121,6 +109,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         stop_targets = data[6]
         avg_text_length = torch.mean(text_lengths.float())
         avg_spec_length = torch.mean(mel_lengths.float())
+        loader_time = time.time() - end_time
 
         if c.use_speaker_embedding:
             speaker_ids = [speaker_mapping[speaker_name]
@@ -191,17 +180,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         else:
             grad_norm_st = 0
 
-        step_time = time.time() - start_time
-        epoch_time += step_time
-
         if current_step % c.print_step == 0:
             print(
                 " | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} PostnetLoss:{:.5f} "
                 "DecoderLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
-                "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}".format(
+                "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} "
+                "LoaderTime:{:.2f} LR:{:.6f}".format(
                    num_iter, batch_n_iter, current_step, loss.item(),
                    postnet_loss.item(), decoder_loss.item(), stop_loss.item(),
-                   grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
+                   grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time,
+                   loader_time, current_lr),
                 flush=True)
 
         # aggregate losses from processes
@@ -216,6 +204,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         avg_decoder_loss += float(decoder_loss.item())
         avg_stop_loss += stop_loss if isinstance(stop_loss, float) else float(stop_loss.item())
         avg_step_time += step_time
+        avg_loader_time += loader_time
 
         # Plot Training Iter Stats
         iter_stats = {"loss_posnet": postnet_loss.item(),
@@ -254,11 +243,15 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
                                           {'TrainAudio': train_audio},
                                           c.audio["sample_rate"])
 
+        step_time = end_time - start_time
+        epoch_time += step_time
+
     avg_postnet_loss /= (num_iter + 1)
     avg_decoder_loss /= (num_iter + 1)
     avg_stop_loss /= (num_iter + 1)
     avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
     avg_step_time /= (num_iter + 1)
+    avg_loader_time /= (num_iter + 1)
 
     # print epoch stats
     print(
@@ -267,7 +260,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
-        "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
+        "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(current_step, avg_total_loss,
                                     avg_postnet_loss, avg_decoder_loss,
-                                    avg_stop_loss, epoch_time, avg_step_time),
+                                    avg_stop_loss, epoch_time, avg_step_time,
+                                    avg_loader_time),
         flush=True)
 
     # Plot Epoch Stats
@@ -281,6 +275,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if c.tb_model_param_stats:
             tb_logger.tb_model_weights(model, current_step)
 
+    end_time = time.time()
     return avg_postnet_loss, current_step
 
 
@@ -541,9 +536,10 @@ def main(args): #pylint: disable=redefined-outer-name
     current_step = 0
     for epoch in range(0, c.epochs):
         # set gradual training
-        r, c.batch_size = gradual_training_scheduler(current_step)
-        c.r = r
-        model.decoder._set_r(r)
+        if c.gradual_training is not None:
+            r, c.batch_size = gradual_training_scheduler(current_step, c)
+            c.r = r
+            model.decoder._set_r(r)
         print(" > Number of outputs per iteration:", model.decoder.r)
 
         train_loss, current_step = train(model, criterion, criterion_st,
@@ -592,7 +588,7 @@ if __name__ == '__main__':
         '--output_folder',
         type=str,
         default='',
-        help='folder name for traning outputs.'
+        help='folder name for training outputs.'
     )
 
     # DISTRUBUTED
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 64414765..8a64dbae 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -305,3 +305,10 @@ def split_dataset(items):
     else:
         return items[:eval_split_size], items[eval_split_size:]
 
+
+def gradual_training_scheduler(global_step, config):
+    new_values = None
+    for values in config.gradual_training:
+        if global_step >= values[0]:
+            new_values = values
+    return new_values[1], new_values[2]
\ No newline at end of file
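
Usage note: gradual_training_scheduler now reads its schedule from config.gradual_training, and train.py only applies it when c.gradual_training is not None. Judging by the indexing in the new function (values[0] compared against global_step, values[1] returned as r, values[2] as batch_size), each entry is [first_global_step, r, batch_size] and the last entry whose first step is <= the current global step wins. Below is a minimal sketch reproducing the removed hard-coded schedule through the new code path; the SimpleNamespace stand-in for the loaded config is an illustration only, not the project's config loader:

    from types import SimpleNamespace
    from utils.generic_utils import gradual_training_scheduler

    # Only the attribute the scheduler reads; this list is equivalent to the
    # old if/elif ladder that was removed from train.py.
    c = SimpleNamespace(gradual_training=[[0, 7, 32], [10000, 5, 32], [50000, 3, 32],
                                          [130000, 2, 16], [290000, 1, 16]])
    r, batch_size = gradual_training_scheduler(60000, c)  # -> (3, 32), as before

The schedule is expected to start at step 0: if the first entry's threshold were above the current global step, new_values would stay None and the final indexing would raise a TypeError.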