mirror of https://github.com/coqui-ai/TTS.git
solve pickling models after module name change
This commit is contained in:
parent
df19428ec6
commit
540d811dd5
|
@ -1,10 +1,17 @@
|
|||
import os
|
||||
import torch
|
||||
import datetime
|
||||
import pickle as pickle_tts
|
||||
|
||||
from TTS.utils.io import RenamingUnpickler
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
try:
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
except ModuleNotFoundError:
|
||||
pickle_tts.Unpickler = RenamingUnpickler
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts)
|
||||
model.load_state_dict(state['model'])
|
||||
if amp and 'amp' in state:
|
||||
amp.load_state_dict(state['amp'])
|
||||
|
|
|
@ -1,5 +1,13 @@
|
|||
import re
|
||||
import json
|
||||
import pickle as pickle_tts
|
||||
|
||||
class RenamingUnpickler(pickle_tts.Unpickler):
|
||||
"""Overload default pickler to solve module renaming problem"""
|
||||
def find_class(self, module, name):
|
||||
if 'mozilla_voice_tts' in module :
|
||||
module = module.replace('mozilla_voice_tts', 'TTS')
|
||||
return super().find_class(module, name)
|
||||
|
||||
class AttrDict(dict):
|
||||
"""A custom dict which converts dict keys
|
||||
|
|
|
@ -1,6 +1,21 @@
|
|||
import os
|
||||
import torch
|
||||
import datetime
|
||||
import pickle as pickle_tts
|
||||
|
||||
from TTS.utils.io import RenamingUnpickler
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, use_cuda=False):
|
||||
try:
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
except ModuleNotFoundError:
|
||||
pickle_tts.Unpickler = RenamingUnpickler
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts)
|
||||
model.load_state_dict(state['model'])
|
||||
if use_cuda:
|
||||
model.cuda()
|
||||
return model, state
|
||||
|
||||
|
||||
def save_model(model, optimizer, scheduler, model_disc, optimizer_disc,
|
||||
|
|
Loading…
Reference in New Issue