Implement unified IO utils

Eren Gölge 2021-06-18 13:23:08 +02:00
parent c7aad884cd
commit 98298ee671
3 changed files with 121 additions and 248 deletions

View File

@@ -1,120 +0,0 @@
import datetime
import os
import pickle as pickle_tts
import torch
from TTS.utils.io import RenamingUnpickler
def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
    """Load ```TTS.tts.models``` checkpoints.

    Args:
        model (TTS.tts.models): model object to load the weights for.
        checkpoint_path (string): checkpoint file path.
        amp (apex.amp, optional): Apex amp object to load Apex-related state vars. Defaults to None.
        use_cuda (bool, optional): load model to GPU if True. Defaults to False.
        eval (bool, optional): set the model to eval mode after loading. Defaults to False.

    Returns:
        Tuple[torch.nn.Module, Dict]: the model with the loaded weights and the checkpoint state dict.
    """
    try:
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    except ModuleNotFoundError:
        pickle_tts.Unpickler = RenamingUnpickler
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
    model.load_state_dict(state["model"])
    if amp and "amp" in state:
        amp.load_state_dict(state["amp"])
    if use_cuda:
        model.cuda()
    # set model stepsize
    if hasattr(model.decoder, "r"):
        model.decoder.set_r(state["r"])
        print(" > Model r: ", state["r"])
    if eval:
        model.eval()
    return model, state

def save_model(model, optimizer, current_step, epoch, r, output_path, characters, amp_state_dict=None, **kwargs):
    """Save ```TTS.tts.models``` states with extra fields.

    Args:
        model (TTS.tts.models.Model): model object to be saved.
        optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
        current_step (int): current number of training steps.
        epoch (int): current number of training epochs.
        r (int): model reduction rate for Tacotron models.
        output_path (str): output path to save the model file.
        characters (list): list of characters used in the model.
        amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None.
    """
    if hasattr(model, "module"):
        model_state = model.module.state_dict()
    else:
        model_state = model.state_dict()
    state = {
        "model": model_state,
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
        "r": r,
        "characters": characters,
    }
    if amp_state_dict:
        state["amp"] = amp_state_dict
    state.update(kwargs)
    torch.save(state, output_path)

def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs):
    """Save a model checkpoint, intended for saving checkpoints during training.

    Args:
        model (TTS.tts.models.Model): model object to be saved.
        optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
        current_step (int): current number of training steps.
        epoch (int): current number of training epochs.
        r (int): model reduction rate for Tacotron models.
        output_folder (str): folder to save the checkpoint file into.
        characters (list): list of characters used in the model.
    """
    file_name = "checkpoint_{}.pth.tar".format(current_step)
    checkpoint_path = os.path.join(output_folder, file_name)
    print(" > CHECKPOINT : {}".format(checkpoint_path))
    save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, **kwargs)

def save_best_model(
    target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs
):
    """Save a model checkpoint, intended for saving the best model after each epoch.
    It compares the current model loss with the best loss so far and saves the
    model if the current loss is better.

    Args:
        target_loss (float): current model loss.
        best_loss (float): best loss so far.
        model (TTS.tts.models.Model): model object to be saved.
        optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
        current_step (int): current number of training steps.
        epoch (int): current number of training epochs.
        r (int): model reduction rate for Tacotron models.
        output_folder (str): folder to save the best-model file into.
        characters (list): list of characters used in the model.

    Returns:
        float: updated current best loss.
    """
    if target_loss < best_loss:
        file_name = "best_model.pth.tar"
        checkpoint_path = os.path.join(output_folder, file_name)
        print(" >> BEST MODEL : {}".format(checkpoint_path))
        save_model(
            model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs
        )
        best_loss = target_loss
    return best_loss
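
For contrast with the unified helpers introduced below, here is a rough sketch of how this removed, tts-specific API was called before this commit. The module path, toy model, character set, and output folder are assumptions for illustration, not part of the diff.

```python
# Sketch of the pre-commit, tts-specific call pattern (assumes a checkout before
# this commit, where the removed module is still importable as TTS.tts.utils.io).
import os
import torch
from TTS.tts.utils.io import save_checkpoint, save_best_model

model = torch.nn.Linear(10, 10)                    # placeholder for a Tacotron-style model
optimizer = torch.optim.Adam(model.parameters())
characters = list("abcdefghijklmnopqrstuvwxyz ")   # placeholder character set
out_dir = "run_dir"                                # placeholder output folder
os.makedirs(out_dir, exist_ok=True)

# every call has to carry Tacotron-only fields (r, characters)
save_checkpoint(model, optimizer, current_step=1000, epoch=1, r=2,
                output_folder=out_dir, characters=characters)
best_loss = save_best_model(0.42, float("inf"), model, optimizer, 1000, 1, 2, out_dir, characters)
```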

View File

@@ -1,7 +1,12 @@
import datetime
import glob
import os
import pickle as pickle_tts
from shutil import copyfile

import torch
from coqpit import Coqpit


class RenamingUnpickler(pickle_tts.Unpickler):
    """Overload default pickler to solve module renaming problem"""
@@ -41,3 +46,119 @@ def copy_model_files(config, out_path, new_fields):
            config.audio.stats_path,
            copy_stats_path,
        )
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
    try:
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    except ModuleNotFoundError:
        pickle_tts.Unpickler = RenamingUnpickler
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state

def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
    if hasattr(model, "module"):
        model_state = model.module.state_dict()
    else:
        model_state = model.state_dict()
    if isinstance(optimizer, list):
        optimizer_state = [optim.state_dict() for optim in optimizer]
    else:
        optimizer_state = optimizer.state_dict() if optimizer is not None else None
    if isinstance(scaler, list):
        scaler_state = [s.state_dict() for s in scaler]
    else:
        scaler_state = scaler.state_dict() if scaler is not None else None
    if isinstance(config, Coqpit):
        config = config.to_dict()
    state = {
        "config": config,
        "model": model_state,
        "optimizer": optimizer_state,
        "scaler": scaler_state,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    state.update(kwargs)
    torch.save(state, output_path)

def save_checkpoint(
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    output_folder,
    **kwargs,
):
    file_name = "checkpoint_{}.pth.tar".format(current_step)
    checkpoint_path = os.path.join(output_folder, file_name)
    print("\n > CHECKPOINT : {}".format(checkpoint_path))
    save_model(
        config,
        model,
        optimizer,
        scaler,
        current_step,
        epoch,
        checkpoint_path,
        **kwargs,
    )

def save_best_model(
    current_loss,
    best_loss,
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    out_path,
    keep_all_best=False,
    keep_after=10000,
    **kwargs,
):
    if current_loss < best_loss:
        best_model_name = f"best_model_{current_step}.pth.tar"
        checkpoint_path = os.path.join(out_path, best_model_name)
        print(" > BEST MODEL : {}".format(checkpoint_path))
        save_model(
            config,
            model,
            optimizer,
            scaler,
            current_step,
            epoch,
            checkpoint_path,
            model_loss=current_loss,
            **kwargs,
        )
        # only delete previous if current is saved successfully
        if not keep_all_best or (current_step < keep_after):
            model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar"))
            for model_name in model_names:
                if os.path.basename(model_name) == best_model_name:
                    continue
                os.remove(model_name)
        # create a symlink to the best model for convenience
        link_name = "best_model.pth.tar"
        link_path = os.path.join(out_path, link_name)
        if os.path.islink(link_path) or os.path.isfile(link_path):
            os.remove(link_path)
        os.symlink(best_model_name, os.path.join(out_path, link_name))
        best_loss = current_loss
    return best_loss
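
These unified helpers take the training config, model, optimizer, and AMP scaler directly, so tts and vocoder models can share the same call sites. A minimal usage sketch, assuming a checkout that includes this commit; the toy `nn.Linear` model, loss values, and `run_dir` folder are placeholders, not part of the diff.

```python
import os
import torch
from TTS.utils.io import load_checkpoint, save_checkpoint, save_best_model

model = torch.nn.Linear(10, 10)                    # placeholder for a TTS/vocoder model
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler(enabled=False)  # AMP scaler; disabled so CUDA is optional
config = {"run_name": "demo"}                      # plain dict works; Coqpit configs are converted to dicts
out_dir = "run_dir"
os.makedirs(out_dir, exist_ok=True)

best_loss = float("inf")
for step, epoch, loss in [(1000, 1, 0.9), (2000, 2, 0.7)]:  # stand-in training progress
    save_checkpoint(config, model, optimizer, scaler, step, epoch, out_dir)
    best_loss = save_best_model(loss, best_loss, config, model, optimizer, scaler, step, epoch, out_dir)

# the "best_model.pth.tar" symlink always points at the latest best checkpoint
model, state = load_checkpoint(model, os.path.join(out_dir, "best_model.pth.tar"), eval=True)
print(state["step"], state["epoch"])
```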

View File

@@ -1,128 +0,0 @@
import datetime
import glob
import os
import pickle as pickle_tts
import torch
from TTS.utils.io import RenamingUnpickler
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
    try:
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    except ModuleNotFoundError:
        pickle_tts.Unpickler = RenamingUnpickler
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state

def save_model(
    model, optimizer, scheduler, model_disc, optimizer_disc, scheduler_disc, current_step, epoch, output_path, **kwargs
):
    if hasattr(model, "module"):
        model_state = model.module.state_dict()
    else:
        model_state = model.state_dict()
    model_disc_state = model_disc.state_dict() if model_disc is not None else None
    optimizer_state = optimizer.state_dict() if optimizer is not None else None
    optimizer_disc_state = optimizer_disc.state_dict() if optimizer_disc is not None else None
    scheduler_state = scheduler.state_dict() if scheduler is not None else None
    scheduler_disc_state = scheduler_disc.state_dict() if scheduler_disc is not None else None
    state = {
        "model": model_state,
        "optimizer": optimizer_state,
        "scheduler": scheduler_state,
        "model_disc": model_disc_state,
        "optimizer_disc": optimizer_disc_state,
        "scheduler_disc": scheduler_disc_state,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    state.update(kwargs)
    torch.save(state, output_path)

def save_checkpoint(
    model,
    optimizer,
    scheduler,
    model_disc,
    optimizer_disc,
    scheduler_disc,
    current_step,
    epoch,
    output_folder,
    **kwargs,
):
    file_name = "checkpoint_{}.pth.tar".format(current_step)
    checkpoint_path = os.path.join(output_folder, file_name)
    print(" > CHECKPOINT : {}".format(checkpoint_path))
    save_model(
        model,
        optimizer,
        scheduler,
        model_disc,
        optimizer_disc,
        scheduler_disc,
        current_step,
        epoch,
        checkpoint_path,
        **kwargs,
    )

def save_best_model(
    current_loss,
    best_loss,
    model,
    optimizer,
    scheduler,
    model_disc,
    optimizer_disc,
    scheduler_disc,
    current_step,
    epoch,
    out_path,
    keep_all_best=False,
    keep_after=10000,
    **kwargs,
):
    if current_loss < best_loss:
        best_model_name = f"best_model_{current_step}.pth.tar"
        checkpoint_path = os.path.join(out_path, best_model_name)
        print(" > BEST MODEL : {}".format(checkpoint_path))
        save_model(
            model,
            optimizer,
            scheduler,
            model_disc,
            optimizer_disc,
            scheduler_disc,
            current_step,
            epoch,
            checkpoint_path,
            model_loss=current_loss,
            **kwargs,
        )
        # only delete previous if current is saved successfully
        if not keep_all_best or (current_step < keep_after):
            model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar"))
            for model_name in model_names:
                if os.path.basename(model_name) == best_model_name:
                    continue
                os.remove(model_name)
        # create a symlink to the best model for convenience
        link_name = "best_model.pth.tar"
        link_path = os.path.join(out_path, link_name)
        if os.path.islink(link_path) or os.path.isfile(link_path):
            os.remove(link_path)
        os.symlink(best_model_name, os.path.join(out_path, link_name))
        best_loss = current_loss
    return best_loss
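
The removed vocoder helpers mirrored the tts ones but threaded generator and discriminator objects (plus their optimizers and schedulers) separately; the unified API above replaces that. A rough sketch of the old call, with placeholder objects throughout and the module path assumed to be the removed vocoder-side io module.

```python
# Sketch of the pre-commit vocoder (GAN) call pattern; placeholders throughout,
# pre-commit checkout assumed so the removed module is still importable.
import os
import torch
from TTS.vocoder.utils.io import save_checkpoint

model = torch.nn.Linear(10, 10)        # placeholder for the GAN generator
model_disc = torch.nn.Linear(10, 10)   # placeholder for the discriminator
optimizer = torch.optim.Adam(model.parameters())
optimizer_disc = torch.optim.Adam(model_disc.parameters())
out_dir = "run_dir"
os.makedirs(out_dir, exist_ok=True)

# schedulers may be None; save_model stores None for missing components
save_checkpoint(model, optimizer, None, model_disc, optimizer_disc, None,
                current_step=1000, epoch=1, output_folder=out_dir)
```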