mirror of https://github.com/coqui-ai/TTS.git
Remove DataParallel from the model state before saving
This commit is contained in:
parent
066ff1c6cf
commit
6e29d563de
|
@ -213,7 +213,7 @@ class Decoder(nn.Module):
|
|||
r (int): number of outputs per time step.
|
||||
eps (float): threshold for detecting the end of a sentence.
|
||||
"""
|
||||
def __init__(self, in_features, memory_dim, r, eps=0.2):
|
||||
def __init__(self, in_features, memory_dim, r, eps=0.05):
|
||||
super(Decoder, self).__init__()
|
||||
self.max_decoder_steps = 200
|
||||
self.memory_dim = memory_dim
|
||||
|
|
File diff suppressed because one or more lines are too long
Binary file not shown.
|
@ -48,12 +48,26 @@ def copy_config_file(config_file, path):
|
|||
shutil.copyfile(config_file, out_path)
|
||||
|
||||
|
||||
def _trim_model_state_dict(state_dict):
|
||||
r"""Remove 'module.' prefix from state dictionary. It is necessary as it
|
||||
is loded for the next time by model.load_state(). Otherwise, it complains
|
||||
about the torch.DataParallel()"""
|
||||
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
name = k[7:] # remove `module.`
|
||||
new_state_dict[name] = v
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, model_loss, best_loss, out_path,
|
||||
current_step, epoch):
|
||||
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
||||
print("\n | > Checkpoint saving : {}".format(checkpoint_path))
|
||||
state = {'model': model.state_dict(),
|
||||
|
||||
new_state_dict = _trim_model_state_dict(model.state_dict())
|
||||
state = {'model': new_state_dict,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'step': current_step,
|
||||
'epoch': epoch,
|
||||
|
@ -65,7 +79,8 @@ def save_checkpoint(model, optimizer, model_loss, best_loss, out_path,
|
|||
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
|
||||
current_step, epoch):
|
||||
if model_loss < best_loss:
|
||||
state = {'model': model.state_dict(),
|
||||
new_state_dict = _trim_model_state_dict(model.state_dict())
|
||||
state = {'model': new_state_dict,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'step': current_step,
|
||||
'epoch': epoch,
|
||||
|
|
Loading…
Reference in New Issue