mirror of https://github.com/coqui-ai/TTS.git
refactor partial reinit script as a function. Allow user to select layers to reinit in finetuning
parent 06a7aeb26d
commit d8908692c5
@@ -29,6 +29,8 @@
         "url": "tcp:\/\/localhost:54321"
     },
 
+    "reinit_layers": ["model.decoder.attention_layer"],
+
     "model": "Tacotron2",          // one of the model in models/
     "grad_clip": 0.02,             // upper limit for gradients for clipping.
     "epochs": 1000,                // total number of epochs to train.
@@ -29,6 +29,8 @@
         "url": "tcp:\/\/localhost:54321"
     },
 
+    "reinit_layers": ["model.decoder.attention_layer"],
+
     "model": "Tacotron2",          // one of the model in models/
     "grad_clip": 1,                // upper limit for gradients for clipping.
     "epochs": 1000,                // total number of epochs to train.
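Not part of the commit: the "reinit_layers" entry added above lists layer-name substrings. During a partial restore, any pretrained weight whose state_dict key contains one of these names is skipped, so that layer keeps its fresh initialization and is trained from scratch while finetuning. A minimal Python sketch of that matching rule (the key names below are illustrative, not taken from the repository):

# Keys whose name contains a reinit_layers entry are dropped from the
# pretrained weights and therefore stay re-initialized.
reinit_layers = ["decoder.attention_layer"]
state_dict_keys = [
    "encoder.lstm.weight_ih_l0",
    "decoder.attention_layer.query_layer.weight",
]
kept = [
    k for k in state_dict_keys
    if not any(name in k for name in reinit_layers)
]
print(kept)  # only the encoder key survives; the attention layer is re-initialized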
train.py (22 changed lines)
@@ -22,7 +22,8 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  create_experiment_folder, get_commit_hash,
                                  load_config, lr_decay,
                                  remove_experiment_folder, save_best_model,
-                                 save_checkpoint, sequence_mask, weight_decay)
+                                 save_checkpoint, sequence_mask, weight_decay,
+                                 set_init_dict)
 from utils.logger import Logger
 from utils.synthesis import synthesis
 from utils.text.symbols import phonemes, symbols
@@ -396,24 +397,9 @@ def main(args):
         print(" > Partial model initialization.")
         partial_init_flag = True
         model_dict = model.state_dict()
-        # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
-        # 1. filter out unnecessary keys
-        pretrained_dict = {
-            k: v
-            for k, v in checkpoint['model'].items() if k in model_dict
-        }
-        # 2. filter out different size layers
-        pretrained_dict = {
-            k: v
-            for k, v in pretrained_dict.items()
-            if v.numel() == model_dict[k].numel()
-        }
-        # 3. overwrite entries in the existing state dict
-        model_dict.update(pretrained_dict)
-        # 4. load the new state dict
+        model_dict = set_init_dict(model_dict, checkpoint, c)
         model.load_state_dict(model_dict)
-        print(" | > {} / {} layers are initialized".format(
-            len(pretrained_dict), len(model_dict)))
+        del model_dict
     if use_cuda:
         model = model.cuda()
         criterion.cuda()
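Not part of the commit: the added del model_dict is safe because PyTorch's load_state_dict copies values into the model's existing parameters rather than keeping references to the merged dict. A small self-contained sketch (module and sizes are arbitrary):

from torch import nn

model = nn.Linear(4, 4)
merged = {k: v.clone() for k, v in model.state_dict().items()}
model.load_state_dict(merged)   # copies values into the existing parameters
del merged                      # safe to free: the model keeps its own tensors
print(model.weight.shape)       # parameters are intact after the del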
utils/generic_utils.py
@@ -195,3 +195,30 @@ def sequence_mask(sequence_length, max_len=None):
                        .expand_as(seq_range_expand))
     # B x T_max
     return seq_range_expand < seq_length_expand
+
+
+def set_init_dict(model_dict, checkpoint, c):
+    # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
+    # 1. filter out unnecessary keys
+    pretrained_dict = {
+        k: v
+        for k, v in checkpoint['model'].items() if k in model_dict
+    }
+    # 2. filter out different size layers
+    pretrained_dict = {
+        k: v
+        for k, v in pretrained_dict.items()
+        if v.numel() == model_dict[k].numel()
+    }
+    # 3. skip reinit layers
+    if c.reinit_layers is not None:
+        for reinit_layer_name in c.reinit_layers:
+            pretrained_dict = {
+                k: v
+                for k, v in pretrained_dict.items()
+                if reinit_layer_name not in k
+            }
+    # 4. overwrite entries in the existing state dict
+    model_dict.update(pretrained_dict)
+    print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
+    return model_dict
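Not part of the commit: a minimal usage sketch of the new set_init_dict helper, assuming the repository is on PYTHONPATH. The toy model, the AttrDict config stand-in, and the layer names are illustrative only; in train.py the config c and the checkpoint dict come from the actual restore path shown above.

from torch import nn
from utils.generic_utils import set_init_dict  # helper added in this commit


class ToyModel(nn.Module):
    # Illustrative stand-in for a Tacotron-style model.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder = nn.Module()
        self.decoder.attention_layer = nn.Linear(8, 4)


class AttrDict(dict):
    # Stand-in for the loaded JSON config object (attribute-style access).
    __getattr__ = dict.__getitem__


checkpoint = {'model': ToyModel().state_dict()}   # stand-in for pretrained weights
model = ToyModel()                                # new model to finetune
c = AttrDict(reinit_layers=["decoder.attention_layer"])

model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint, c)
model.load_state_dict(model_dict)
# encoder weights now come from the checkpoint, while decoder.attention_layer
# keeps its fresh initialization because its keys match a reinit_layers entry.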