enable saving model characters in io.py

This commit is contained in:
Eren Gölge 2021-02-12 12:04:41 +00:00
parent 918f007a11
commit 2abfff17f9
1 changed files with 20 additions and 6 deletions

View File

@ -38,7 +38,15 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False
return model, state return model, state
def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs): def save_model(model,
optimizer,
current_step,
epoch,
r,
output_path,
characters,
amp_state_dict=None,
**kwargs):
"""Save ```TTS.tts.models``` states with extra fields. """Save ```TTS.tts.models``` states with extra fields.
Args: Args:
@ -48,6 +56,7 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
epoch (int): current number of training epochs. epoch (int): current number of training epochs.
r (int): model reduction rate for Tacotron models. r (int): model reduction rate for Tacotron models.
output_path (str): output path to save the model file. output_path (str): output path to save the model file.
characters (list): list of characters used in the model.
amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None. amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None.
""" """
if hasattr(model, 'module'): if hasattr(model, 'module'):
@ -60,7 +69,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
'step': current_step, 'step': current_step,
'epoch': epoch, 'epoch': epoch,
'date': datetime.date.today().strftime("%B %d, %Y"), 'date': datetime.date.today().strftime("%B %d, %Y"),
'r': r 'r': r,
'characters': characters
} }
if amp_state_dict: if amp_state_dict:
state['amp'] = amp_state_dict state['amp'] = amp_state_dict
@ -68,7 +78,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
torch.save(state, output_path) torch.save(state, output_path)
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs): def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder,
characters, **kwargs):
"""Save model checkpoint, intended for saving checkpoints at training. """Save model checkpoint, intended for saving checkpoints at training.
Args: Args:
@ -78,14 +89,16 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **k
epoch (int): current number of training epochs. epoch (int): current number of training epochs.
r (int): model reduction rate for Tacotron models. r (int): model reduction rate for Tacotron models.
output_path (str): output path to save the model file. output_path (str): output path to save the model file.
characters (list): list of characters used in the model.
""" """
file_name = 'checkpoint_{}.pth.tar'.format(current_step) file_name = 'checkpoint_{}.pth.tar'.format(current_step)
checkpoint_path = os.path.join(output_folder, file_name) checkpoint_path = os.path.join(output_folder, file_name)
print(" > CHECKPOINT : {}".format(checkpoint_path)) print(" > CHECKPOINT : {}".format(checkpoint_path))
save_model(model, optimizer, current_step, epoch, r, checkpoint_path, **kwargs) save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, **kwargs)
def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs): def save_best_model(target_loss, best_loss, model, optimizer, current_step,
epoch, r, output_folder, characters, **kwargs):
"""Save model checkpoint, intended for saving the best model after each epoch. """Save model checkpoint, intended for saving the best model after each epoch.
It compares the current model loss with the best loss so far and saves the It compares the current model loss with the best loss so far and saves the
model if the current loss is better. model if the current loss is better.
@ -99,6 +112,7 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoc
epoch (int): current number of training epochs. epoch (int): current number of training epochs.
r (int): model reduction rate for Tacotron models. r (int): model reduction rate for Tacotron models.
output_path (str): output path to save the model file. output_path (str): output path to save the model file.
characters (list): list of characters used in the model.
Returns: Returns:
float: updated current best loss. float: updated current best loss.
@ -107,6 +121,6 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoc
file_name = 'best_model.pth.tar' file_name = 'best_model.pth.tar'
checkpoint_path = os.path.join(output_folder, file_name) checkpoint_path = os.path.join(output_folder, file_name)
print(" >> BEST MODEL : {}".format(checkpoint_path)) print(" >> BEST MODEL : {}".format(checkpoint_path))
save_model(model, optimizer, current_step, epoch, r, checkpoint_path, model_loss=target_loss, **kwargs) save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs)
best_loss = target_loss best_loss = target_loss
return best_loss return best_loss