solve pickling models after module name change

This commit is contained in:
erogol 2020-09-11 12:03:39 +02:00
parent df19428ec6
commit 540d811dd5
3 changed files with 31 additions and 1 deletions

View File

@ -1,10 +1,17 @@
import os
import torch
import datetime
import pickle as pickle_tts
from TTS.utils.io import RenamingUnpickler
def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
try:
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
except ModuleNotFoundError:
pickle_tts.Unpickler = RenamingUnpickler
state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts)
model.load_state_dict(state['model'])
if amp and 'amp' in state:
amp.load_state_dict(state['amp'])

View File

@ -1,5 +1,13 @@
import re
import json
import pickle as pickle_tts
class RenamingUnpickler(pickle_tts.Unpickler):
"""Overload default pickler to solve module renaming problem"""
def find_class(self, module, name):
if 'mozilla_voice_tts' in module :
module = module.replace('mozilla_voice_tts', 'TTS')
return super().find_class(module, name)
class AttrDict(dict):
"""A custom dict which converts dict keys

View File

@ -1,6 +1,21 @@
import os
import torch
import datetime
import pickle as pickle_tts
from TTS.utils.io import RenamingUnpickler
def load_checkpoint(model, checkpoint_path, use_cuda=False):
try:
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
except ModuleNotFoundError:
pickle_tts.Unpickler = RenamingUnpickler
state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts)
model.load_state_dict(state['model'])
if use_cuda:
model.cuda()
return model, state
def save_model(model, optimizer, scheduler, model_disc, optimizer_disc,