From 22d62aee5b577e6b006f174bd40b6dd719028784 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Tue, 11 Dec 2018 17:53:08 +0100
Subject: [PATCH 1/2] partial model initialization

---
 train.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/train.py b/train.py
index 54d6140f..8fe07ded 100644
--- a/train.py
+++ b/train.py
@@ -401,6 +401,16 @@ def main(args):
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
         model.load_state_dict(checkpoint['model'])
+        # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
+        # 1. filter out unnecessary keys
+        pretrained_dict = {
+            k: v
+            for k, v in checkpoint['model'].items() if k in model_dict
+        }
+        # 2. overwrite entries in the existing state dict
+        model_dict.update(pretrained_dict)
+        # 3. load the new state dict
+        model.load_state_dict(model_dict)
         if use_cuda:
             model = model.cuda()
             criterion.cuda()

From 703be0499342c650aa09543107b26f4683af7eff Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Wed, 12 Dec 2018 12:02:10 +0100
Subject: [PATCH 2/2] bug fix

---
 layers/tacotron.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/layers/tacotron.py b/layers/tacotron.py
index 3e82486f..b972ed4a 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -338,8 +338,8 @@ class Decoder(nn.Module):
             gain=torch.nn.init.calculate_gain('linear'))
 
     def _reshape_memory(self, memory):
-        B = memory.shape[0]
         if memory is not None:
+            B = memory.shape[0]
             # Grouping multiple frames if necessary
             if memory.size(-1) == self.memory_dim:
                 memory = memory.contiguous()