From acd96a4940887c193a84f9b600f1f9d7ae0aec86 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 18 Jun 2021 13:23:08 +0200
Subject: [PATCH] Implement unified IO utils

---
 TTS/tts/utils/io.py     | 120 -------------------------------------
 TTS/utils/io.py         | 121 +++++++++++++++++++++++++++++++++++++
 TTS/vocoder/utils/io.py | 128 ----------------------------------------
 3 files changed, 121 insertions(+), 248 deletions(-)
 delete mode 100644 TTS/tts/utils/io.py
 delete mode 100644 TTS/vocoder/utils/io.py

diff --git a/TTS/tts/utils/io.py b/TTS/tts/utils/io.py
deleted file mode 100644
index bb8432fa..00000000
--- a/TTS/tts/utils/io.py
+++ /dev/null
@@ -1,120 +0,0 @@
-import datetime
-import os
-import pickle as pickle_tts
-
-import torch
-
-from TTS.utils.io import RenamingUnpickler
-
-
-def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
-    """Load ```TTS.tts.models``` checkpoints.
-
-    Args:
-        model (TTS.tts.models): model object to load the weights for.
-        checkpoint_path (string): checkpoint file path.
-        amp (apex.amp, optional): Apex amp abject to load apex related state vars. Defaults to None.
-        use_cuda (bool, optional): load model to GPU if True. Defaults to False.
-
-    Returns:
-        [type]: [description]
-    """
-    try:
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
-    except ModuleNotFoundError:
-        pickle_tts.Unpickler = RenamingUnpickler
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
-    model.load_state_dict(state["model"])
-    if amp and "amp" in state:
-        amp.load_state_dict(state["amp"])
-    if use_cuda:
-        model.cuda()
-    # set model stepsize
-    if hasattr(model.decoder, "r"):
-        model.decoder.set_r(state["r"])
-        print(" > Model r: ", state["r"])
-    if eval:
-        model.eval()
-    return model, state
-
-
-def save_model(model, optimizer, current_step, epoch, r, output_path, characters, amp_state_dict=None, **kwargs):
-    """Save ```TTS.tts.models``` states with extra fields.
-
-    Args:
-        model (TTS.tts.models.Model): models object to be saved.
-        optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
-        current_step (int): current number of training steps.
-        epoch (int): current number of training epochs.
-        r (int): model reduction rate for Tacotron models.
-        output_path (str): output path to save the model file.
-        characters (list): list of characters used in the model.
-        amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None.
-    """
-    if hasattr(model, "module"):
-        model_state = model.module.state_dict()
-    else:
-        model_state = model.state_dict()
-    state = {
-        "model": model_state,
-        "optimizer": optimizer.state_dict() if optimizer is not None else None,
-        "step": current_step,
-        "epoch": epoch,
-        "date": datetime.date.today().strftime("%B %d, %Y"),
-        "r": r,
-        "characters": characters,
-    }
-    if amp_state_dict:
-        state["amp"] = amp_state_dict
-    state.update(kwargs)
-    torch.save(state, output_path)
-
-
-def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs):
-    """Save model checkpoint, intended for saving checkpoints at training.
-
-    Args:
-        model (TTS.tts.models.Model): models object to be saved.
-        optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
-        current_step (int): current number of training steps.
-        epoch (int): current number of training epochs.
-        r (int): model reduction rate for Tacotron models.
-        output_path (str): output path to save the model file.
-        characters (list): list of characters used in the model.
-    """
-    file_name = "checkpoint_{}.pth.tar".format(current_step)
-    checkpoint_path = os.path.join(output_folder, file_name)
-    print(" > CHECKPOINT : {}".format(checkpoint_path))
-    save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, **kwargs)
-
-
-def save_best_model(
-    target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs
-):
-    """Save model checkpoint, intended for saving the best model after each epoch.
-    It compares the current model loss with the best loss so far and saves the
-    model if the current loss is better.
-
-    Args:
-        target_loss (float): current model loss.
-        best_loss (float): best loss so far.
-        model (TTS.tts.models.Model): models object to be saved.
-        optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
-        current_step (int): current number of training steps.
-        epoch (int): current number of training epochs.
-        r (int): model reduction rate for Tacotron models.
-        output_path (str): output path to save the model file.
-        characters (list): list of characters used in the model.
-
-    Returns:
-        float: updated current best loss.
-    """
-    if target_loss < best_loss:
-        file_name = "best_model.pth.tar"
-        checkpoint_path = os.path.join(output_folder, file_name)
-        print(" >> BEST MODEL : {}".format(checkpoint_path))
-        save_model(
-            model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs
-        )
-        best_loss = target_loss
-    return best_loss
diff --git a/TTS/utils/io.py b/TTS/utils/io.py
index 62d972f1..871cff6c 100644
--- a/TTS/utils/io.py
+++ b/TTS/utils/io.py
@@ -1,7 +1,12 @@
+import datetime
+import glob
 import os
 import pickle as pickle_tts
 from shutil import copyfile
 
+import torch
+from coqpit import Coqpit
+
 
 class RenamingUnpickler(pickle_tts.Unpickler):
     """Overload default pickler to solve module renaming problem"""
@@ -41,3 +46,119 @@ def copy_model_files(config, out_path, new_fields):
             config.audio.stats_path,
             copy_stats_path,
         )
+
+
+def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
+    try:
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+    except ModuleNotFoundError:
+        pickle_tts.Unpickler = RenamingUnpickler
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
+    model.load_state_dict(state["model"])
+    if use_cuda:
+        model.cuda()
+    if eval:
+        model.eval()
+    return model, state
+
+
+def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
+    if hasattr(model, "module"):
+        model_state = model.module.state_dict()
+    else:
+        model_state = model.state_dict()
+    if isinstance(optimizer, list):
+        optimizer_state = [optim.state_dict() for optim in optimizer]
+    else:
+        optimizer_state = optimizer.state_dict() if optimizer is not None else None
+
+    if isinstance(scaler, list):
+        scaler_state = [s.state_dict() for s in scaler]
+    else:
+        scaler_state = scaler.state_dict() if scaler is not None else None
+
+    if isinstance(config, Coqpit):
+        config = config.to_dict()
+
+    state = {
+        "config": config,
+        "model": model_state,
+        "optimizer": optimizer_state,
+        "scaler": scaler_state,
+        "step": current_step,
+        "epoch": epoch,
+        "date": datetime.date.today().strftime("%B %d, %Y"),
+    }
+    state.update(kwargs)
+    torch.save(state, output_path)
+
+
+def save_checkpoint(
+    config,
+    model,
+    optimizer,
+    scaler,
+    current_step,
+    epoch,
+    output_folder,
+    **kwargs,
+):
+    file_name = "checkpoint_{}.pth.tar".format(current_step)
+    checkpoint_path = os.path.join(output_folder, file_name)
+    print("\n > CHECKPOINT : {}".format(checkpoint_path))
+    save_model(
+        config,
+        model,
+        optimizer,
+        scaler,
+        current_step,
+        epoch,
+        checkpoint_path,
+        **kwargs,
+    )
+
+
+def save_best_model(
+    current_loss,
+    best_loss,
+    config,
+    model,
+    optimizer,
+    scaler,
+    current_step,
+    epoch,
+    out_path,
+    keep_all_best=False,
+    keep_after=10000,
+    **kwargs,
+):
+    if current_loss < best_loss:
+        best_model_name = f"best_model_{current_step}.pth.tar"
+        checkpoint_path = os.path.join(out_path, best_model_name)
+        print(" > BEST MODEL : {}".format(checkpoint_path))
+        save_model(
+            config,
+            model,
+            optimizer,
+            scaler,
+            current_step,
+            epoch,
+            checkpoint_path,
+            model_loss=current_loss,
+            **kwargs,
+        )
+        # only delete previous if current is saved successfully
+        if not keep_all_best or (current_step < keep_after):
+            model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar"))
+            for model_name in model_names:
+                if os.path.basename(model_name) == best_model_name:
+                    continue
+                os.remove(model_name)
+        # create symlink to best model for convinience
+        link_name = "best_model.pth.tar"
+        link_path = os.path.join(out_path, link_name)
+        if os.path.islink(link_path) or os.path.isfile(link_path):
+            os.remove(link_path)
+        os.symlink(best_model_name, os.path.join(out_path, link_name))
+        best_loss = current_loss
+    return best_loss
diff --git a/TTS/vocoder/utils/io.py b/TTS/vocoder/utils/io.py
deleted file mode 100644
index 9c67535f..00000000
--- a/TTS/vocoder/utils/io.py
+++ /dev/null
@@ -1,128 +0,0 @@
-import datetime
-import glob
-import os
-import pickle as pickle_tts
-
-import torch
-
-from TTS.utils.io import RenamingUnpickler
-
-
-def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
-    try:
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
-    except ModuleNotFoundError:
-        pickle_tts.Unpickler = RenamingUnpickler
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
-    model.load_state_dict(state["model"])
-    if use_cuda:
-        model.cuda()
-    if eval:
-        model.eval()
-    return model, state
-
-
-def save_model(
-    model, optimizer, scheduler, model_disc, optimizer_disc, scheduler_disc, current_step, epoch, output_path, **kwargs
-):
-    if hasattr(model, "module"):
-        model_state = model.module.state_dict()
-    else:
-        model_state = model.state_dict()
-    model_disc_state = model_disc.state_dict() if model_disc is not None else None
-    optimizer_state = optimizer.state_dict() if optimizer is not None else None
-    optimizer_disc_state = optimizer_disc.state_dict() if optimizer_disc is not None else None
-    scheduler_state = scheduler.state_dict() if scheduler is not None else None
-    scheduler_disc_state = scheduler_disc.state_dict() if scheduler_disc is not None else None
-    state = {
-        "model": model_state,
-        "optimizer": optimizer_state,
-        "scheduler": scheduler_state,
-        "model_disc": model_disc_state,
-        "optimizer_disc": optimizer_disc_state,
-        "scheduler_disc": scheduler_disc_state,
-        "step": current_step,
-        "epoch": epoch,
-        "date": datetime.date.today().strftime("%B %d, %Y"),
-    }
-    state.update(kwargs)
-    torch.save(state, output_path)
-
-
-def save_checkpoint(
-    model,
-    optimizer,
-    scheduler,
-    model_disc,
-    optimizer_disc,
-    scheduler_disc,
-    current_step,
-    epoch,
-    output_folder,
-    **kwargs,
-):
-    file_name = "checkpoint_{}.pth.tar".format(current_step)
-    checkpoint_path = os.path.join(output_folder, file_name)
-    print(" > CHECKPOINT : {}".format(checkpoint_path))
-    save_model(
-        model,
-        optimizer,
-        scheduler,
-        model_disc,
-        optimizer_disc,
-        scheduler_disc,
-        current_step,
-        epoch,
-        checkpoint_path,
-        **kwargs,
-    )
-
-
-def save_best_model(
-    current_loss,
-    best_loss,
-    model,
-    optimizer,
-    scheduler,
-    model_disc,
-    optimizer_disc,
-    scheduler_disc,
-    current_step,
-    epoch,
-    out_path,
-    keep_all_best=False,
-    keep_after=10000,
-    **kwargs,
-):
-    if current_loss < best_loss:
-        best_model_name = f"best_model_{current_step}.pth.tar"
-        checkpoint_path = os.path.join(out_path, best_model_name)
-        print(" > BEST MODEL : {}".format(checkpoint_path))
-        save_model(
-            model,
-            optimizer,
-            scheduler,
-            model_disc,
-            optimizer_disc,
-            scheduler_disc,
-            current_step,
-            epoch,
-            checkpoint_path,
-            model_loss=current_loss,
-            **kwargs,
-        )
-        # only delete previous if current is saved successfully
-        if not keep_all_best or (current_step < keep_after):
-            model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar"))
-            for model_name in model_names:
-                if os.path.basename(model_name) == best_model_name:
-                    continue
-                os.remove(model_name)
-        # create symlink to best model for convinience
-        link_name = "best_model.pth.tar"
-        link_path = os.path.join(out_path, link_name)
-        if os.path.islink(link_path) or os.path.isfile(link_path):
-            os.remove(link_path)
-        os.symlink(best_model_name, os.path.join(out_path, link_name))
-        best_loss = current_loss
-    return best_loss
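
For review, a minimal sketch of how the unified helpers in TTS/utils/io.py are expected to be called from a training loop. The names config, model, optimizer, scaler, global_step, epoch, eval_loss, best_loss, output_folder, and checkpoint_path below are placeholders, not part of this patch; model-specific extras such as r or characters would now travel through **kwargs into the saved state dict.

    from TTS.utils.io import load_checkpoint, save_best_model, save_checkpoint

    # periodic checkpoint: writes checkpoint_<step>.pth.tar under output_folder
    save_checkpoint(config, model, optimizer, scaler, global_step, epoch, output_folder)

    # after evaluation: keep the best model so far and refresh the best_model.pth.tar symlink
    best_loss = save_best_model(
        eval_loss, best_loss, config, model, optimizer, scaler, global_step, epoch, output_folder
    )

    # later, restore the weights for inference
    model, state = load_checkpoint(model, checkpoint_path, use_cuda=True, eval=True)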