From d8908692c5270743d8139dd51dd8893e5087a426 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Sat, 23 Mar 2019 17:19:40 +0100
Subject: [PATCH] refactor partial reinit script as a function. Allow user to select layers to reinit in finetuning

---
 config.json            |  2 ++
 config_cluster.json    |  2 ++
 train.py               | 22 ++++------------------
 utils/generic_utils.py | 27 +++++++++++++++++++++++++++
 4 files changed, 35 insertions(+), 18 deletions(-)

diff --git a/config.json b/config.json
index 138d3436..7b9e2719 100644
--- a/config.json
+++ b/config.json
@@ -29,6 +29,8 @@
         "url": "tcp:\/\/localhost:54321"
     },
 
+    "reinit_layers": ["model.decoder.attention_layer"],
+
     "model": "Tacotron2",   // one of the model in models/
     "grad_clip": 0.02,      // upper limit for gradients for clipping.
     "epochs": 1000,         // total number of epochs to train.
diff --git a/config_cluster.json b/config_cluster.json
index 4d330bd5..79a8e47f 100644
--- a/config_cluster.json
+++ b/config_cluster.json
@@ -29,6 +29,8 @@
         "url": "tcp:\/\/localhost:54321"
     },
 
+    "reinit_layers": ["model.decoder.attention_layer"],
+
     "model": "Tacotron2",   // one of the model in models/
     "grad_clip": 1,         // upper limit for gradients for clipping.
     "epochs": 1000,         // total number of epochs to train.
diff --git a/train.py b/train.py
index 524bacbf..09d39ead 100644
--- a/train.py
+++ b/train.py
@@ -22,7 +22,8 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  create_experiment_folder, get_commit_hash,
                                  load_config, lr_decay,
                                  remove_experiment_folder, save_best_model,
-                                 save_checkpoint, sequence_mask, weight_decay)
+                                 save_checkpoint, sequence_mask, weight_decay,
+                                 set_init_dict)
 from utils.logger import Logger
 from utils.synthesis import synthesis
 from utils.text.symbols import phonemes, symbols
@@ -396,24 +397,9 @@ def main(args):
             print(" > Partial model initialization.")
             partial_init_flag = True
             model_dict = model.state_dict()
-            # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
-            # 1. filter out unnecessary keys
-            pretrained_dict = {
-                k: v
-                for k, v in checkpoint['model'].items() if k in model_dict
-            }
-            # 2. filter out different size layers
-            pretrained_dict = {
-                k: v
-                for k, v in pretrained_dict.items()
-                if v.numel() == model_dict[k].numel()
-            }
-            # 3. overwrite entries in the existing state dict
-            model_dict.update(pretrained_dict)
-            # 4. load the new state dict
+            model_dict = set_init_dict(model_dict, checkpoint, c)
             model.load_state_dict(model_dict)
-            print(" | > {} / {} layers are initialized".format(
-                len(pretrained_dict), len(model_dict)))
+            del model_dict
         if use_cuda:
             model = model.cuda()
             criterion.cuda()
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index e159eed0..3afc15f3 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -195,3 +195,30 @@ def sequence_mask(sequence_length, max_len=None):
                          .expand_as(seq_range_expand))
     # B x T_max
     return seq_range_expand < seq_length_expand
+
+
+def set_init_dict(model_dict, checkpoint, c):
+    # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
+    # 1. filter out unnecessary keys
+    pretrained_dict = {
+        k: v
+        for k, v in checkpoint['model'].items() if k in model_dict
+    }
+    # 2. filter out different size layers
+    pretrained_dict = {
+        k: v
+        for k, v in pretrained_dict.items()
+        if v.numel() == model_dict[k].numel()
+    }
+    # 3. skip reinit layers
+    if c.reinit_layers is not None:
+        for reinit_layer_name in c.reinit_layers:
+            pretrained_dict = {
+                k: v
+                for k, v in pretrained_dict.items()
+                if reinit_layer_name not in k
+            }
+    # 4. overwrite entries in the existing state dict
+    model_dict.update(pretrained_dict)
+    print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
+    return model_dict
\ No newline at end of file
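
For reference, a minimal, self-contained usage sketch (not part of the patch) of how the new set_init_dict helper is meant to be called when fine-tuning. The TinyModel class and the SimpleNamespace stand-in for the parsed config are made up for illustration, and the import assumes the snippet is run from the repository root; only set_init_dict itself, the 'model' checkpoint key, and the new "reinit_layers" config entry come from the patch.

import torch
from types import SimpleNamespace

from utils.generic_utils import set_init_dict  # helper added by this patch


class TinyModel(torch.nn.Module):
    # made-up stand-in for Tacotron2: one layer to keep, one to re-initialize
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(4, 4)
        self.attention_layer = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.attention_layer(self.encoder(x))


# pretend this checkpoint was written earlier with
# torch.save({'model': pretrained.state_dict()}, path)
pretrained = TinyModel()
checkpoint = {'model': pretrained.state_dict()}

# mimics the new "reinit_layers" entry in config.json: parameters whose names
# contain any of these substrings keep their fresh random initialization
c = SimpleNamespace(reinit_layers=["attention_layer"])

model = TinyModel()
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint, c)  # prints " | > 2 / 4 layers are initialized"
model.load_state_dict(model_dict)
del model_dict

Since the filter is a plain substring match on state-dict key names, each entry in "reinit_layers" only needs to be a substring of the parameter names that should stay at their fresh initialization; all other layers with matching sizes are copied from the checkpoint.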