diff --git a/TTS/tts/utils/io.py b/TTS/tts/utils/io.py index 63e04283..fe94d98d 100644 --- a/TTS/tts/utils/io.py +++ b/TTS/tts/utils/io.py @@ -38,7 +38,15 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False return model, state -def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs): +def save_model(model, + optimizer, + current_step, + epoch, + r, + output_path, + characters, + amp_state_dict=None, + **kwargs): """Save ```TTS.tts.models``` states with extra fields. Args: @@ -48,6 +56,7 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_ epoch (int): current number of training epochs. r (int): model reduction rate for Tacotron models. output_path (str): output path to save the model file. + characters (list): list of characters used in the model. amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None. """ if hasattr(model, 'module'): @@ -60,7 +69,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_ 'step': current_step, 'epoch': epoch, 'date': datetime.date.today().strftime("%B %d, %Y"), - 'r': r + 'r': r, + 'characters': characters } if amp_state_dict: state['amp'] = amp_state_dict @@ -68,7 +78,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_ torch.save(state, output_path) -def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs): +def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, + characters, **kwargs): """Save model checkpoint, intended for saving checkpoints at training. Args: @@ -78,14 +89,16 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **k epoch (int): current number of training epochs. r (int): model reduction rate for Tacotron models. output_path (str): output path to save the model file. + characters (list): list of characters used in the model. """ file_name = 'checkpoint_{}.pth.tar'.format(current_step) checkpoint_path = os.path.join(output_folder, file_name) print(" > CHECKPOINT : {}".format(checkpoint_path)) - save_model(model, optimizer, current_step, epoch, r, checkpoint_path, **kwargs) + save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, **kwargs) -def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs): +def save_best_model(target_loss, best_loss, model, optimizer, current_step, + epoch, r, output_folder, characters, **kwargs): """Save model checkpoint, intended for saving the best model after each epoch. It compares the current model loss with the best loss so far and saves the model if the current loss is better. @@ -99,6 +112,7 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoc epoch (int): current number of training epochs. r (int): model reduction rate for Tacotron models. output_path (str): output path to save the model file. + characters (list): list of characters used in the model. Returns: float: updated current best loss. @@ -107,6 +121,6 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoc file_name = 'best_model.pth.tar' checkpoint_path = os.path.join(output_folder, file_name) print(" >> BEST MODEL : {}".format(checkpoint_path)) - save_model(model, optimizer, current_step, epoch, r, checkpoint_path, model_loss=target_loss, **kwargs) + save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs) best_loss = target_loss return best_loss