From 540d811dd52b5598a7cd21cbbcf197b0bfbeab62 Mon Sep 17 00:00:00 2001 From: erogol Date: Fri, 11 Sep 2020 12:03:39 +0200 Subject: [PATCH] solve pickling models after module name change --- TTS/tts/utils/io.py | 9 ++++++++- TTS/utils/io.py | 8 ++++++++ TTS/vocoder/utils/io.py | 15 +++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/TTS/tts/utils/io.py b/TTS/tts/utils/io.py index da5c8b27..bf5e13d8 100644 --- a/TTS/tts/utils/io.py +++ b/TTS/tts/utils/io.py @@ -1,10 +1,17 @@ import os import torch import datetime +import pickle as pickle_tts + +from TTS.utils.io import RenamingUnpickler def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False): - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + try: + state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + except ModuleNotFoundError: + pickle_tts.Unpickler = RenamingUnpickler + state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts) model.load_state_dict(state['model']) if amp and 'amp' in state: amp.load_state_dict(state['amp']) diff --git a/TTS/utils/io.py b/TTS/utils/io.py index c96703ed..c54d2e9f 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -1,5 +1,13 @@ import re import json +import pickle as pickle_tts + +class RenamingUnpickler(pickle_tts.Unpickler): + """Overload default pickler to solve module renaming problem""" + def find_class(self, module, name): + if 'mozilla_voice_tts' in module : + module = module.replace('mozilla_voice_tts', 'TTS') + return super().find_class(module, name) class AttrDict(dict): """A custom dict which converts dict keys diff --git a/TTS/vocoder/utils/io.py b/TTS/vocoder/utils/io.py index 734714e0..640334f1 100644 --- a/TTS/vocoder/utils/io.py +++ b/TTS/vocoder/utils/io.py @@ -1,6 +1,21 @@ import os import torch import datetime +import pickle as pickle_tts + +from TTS.utils.io import RenamingUnpickler + + +def load_checkpoint(model, checkpoint_path, use_cuda=False): + try: + state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + except ModuleNotFoundError: + pickle_tts.Unpickler = RenamingUnpickler + state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts) + model.load_state_dict(state['model']) + if use_cuda: + model.cuda() + return model, state def save_model(model, optimizer, scheduler, model_disc, optimizer_disc,