mirror of https://github.com/coqui-ai/TTS.git
solve pickling models after module name change
This commit is contained in:
parent df19428ec6
commit 540d811dd5
@@ -1,10 +1,17 @@
 import os
 import torch
 import datetime
+import pickle as pickle_tts
 
+from TTS.utils.io import RenamingUnpickler
 
+
 def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
-    state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+    try:
+        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+    except ModuleNotFoundError:
+        pickle_tts.Unpickler = RenamingUnpickler
+        state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts)
     model.load_state_dict(state['model'])
     if amp and 'amp' in state:
         amp.load_state_dict(state['amp'])
@@ -1,5 +1,13 @@
 import re
 import json
+import pickle as pickle_tts
 
+class RenamingUnpickler(pickle_tts.Unpickler):
+    """Overload default pickler to solve module renaming problem"""
+    def find_class(self, module, name):
+        if 'mozilla_voice_tts' in module:
+            module = module.replace('mozilla_voice_tts', 'TTS')
+        return super().find_class(module, name)
+
 class AttrDict(dict):
     """A custom dict which converts dict keys
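The class added above works through pickle's find_class hook: every global referenced in a pickle stream is resolved through it, so rewriting the module path there is enough to load checkpoints created before the mozilla_voice_tts to TTS rename. A minimal, self-contained sketch of the same idea (not part of the commit; renaming_load is a hypothetical helper name):

import io
import pickle


class RenamingUnpickler(pickle.Unpickler):
    """Resolve globals pickled under the old 'mozilla_voice_tts' package against 'TTS'."""
    def find_class(self, module, name):
        # Rewrite the module path before the default lookup imports it.
        if 'mozilla_voice_tts' in module:
            module = module.replace('mozilla_voice_tts', 'TTS')
        return super().find_class(module, name)


def renaming_load(data: bytes):
    # Unpickle a byte string, remapping legacy module names on the fly.
    return RenamingUnpickler(io.BytesIO(data)).load()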
@@ -1,6 +1,21 @@
 import os
 import torch
 import datetime
+import pickle as pickle_tts
 
+from TTS.utils.io import RenamingUnpickler
 
+
+def load_checkpoint(model, checkpoint_path, use_cuda=False):
+    try:
+        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+    except ModuleNotFoundError:
+        pickle_tts.Unpickler = RenamingUnpickler
+        state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts)
+    model.load_state_dict(state['model'])
+    if use_cuda:
+        model.cuda()
+    return model, state
+
+
 def save_model(model, optimizer, scheduler, model_disc, optimizer_disc,
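Both load_checkpoint variants in this commit share the same retry pattern, sketched standalone below (safe_torch_load is a hypothetical name, not part of the commit). Assigning RenamingUnpickler to pickle_tts.Unpickler before the retry works because torch.load builds its unpickler from the Unpickler attribute of whatever pickle_module it is given, so old 'mozilla_voice_tts' references are remapped to 'TTS' while the checkpoint is read.

import pickle as pickle_tts

import torch

from TTS.utils.io import RenamingUnpickler


def safe_torch_load(checkpoint_path):
    # First try a plain load; post-rename checkpoints resolve directly.
    try:
        return torch.load(checkpoint_path, map_location=torch.device('cpu'))
    except ModuleNotFoundError:
        # Pre-rename checkpoints reference 'mozilla_voice_tts'; retry with the
        # renaming unpickler so those globals are resolved against 'TTS'.
        pickle_tts.Unpickler = RenamingUnpickler
        return torch.load(checkpoint_path, map_location=torch.device('cpu'),
                          pickle_module=pickle_tts)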