enable saving model characters in io.py

Eren Gölge 2021-02-12 12:04:41 +00:00 committed by Eren Gölge
parent f9fe167537
commit 0b33acdcca
1 changed file with 20 additions and 6 deletions


@@ -38,7 +38,15 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False
     return model, state


-def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
+def save_model(model,
+               optimizer,
+               current_step,
+               epoch,
+               r,
+               output_path,
+               characters,
+               amp_state_dict=None,
+               **kwargs):
     """Save ```TTS.tts.models``` states with extra fields.

     Args:
@@ -48,6 +56,7 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
         epoch (int): current number of training epochs.
         r (int): model reduction rate for Tacotron models.
         output_path (str): output path to save the model file.
+        characters (list): list of characters used in the model.
         amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None.
     """
     if hasattr(model, 'module'):
@@ -60,7 +69,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
         'step': current_step,
         'epoch': epoch,
         'date': datetime.date.today().strftime("%B %d, %Y"),
-        'r': r
+        'r': r,
+        'characters': characters
     }
     if amp_state_dict:
         state['amp'] = amp_state_dict
@@ -68,7 +78,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
     torch.save(state, output_path)


-def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
+def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder,
+                    characters, **kwargs):
     """Save model checkpoint, intended for saving checkpoints at training.

     Args:
@@ -78,14 +89,16 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **k
         epoch (int): current number of training epochs.
         r (int): model reduction rate for Tacotron models.
         output_path (str): output path to save the model file.
+        characters (list): list of characters used in the model.
     """
     file_name = 'checkpoint_{}.pth.tar'.format(current_step)
     checkpoint_path = os.path.join(output_folder, file_name)
     print(" > CHECKPOINT : {}".format(checkpoint_path))
-    save_model(model, optimizer, current_step, epoch, r, checkpoint_path, **kwargs)
+    save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, **kwargs)


-def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs):
+def save_best_model(target_loss, best_loss, model, optimizer, current_step,
+                    epoch, r, output_folder, characters, **kwargs):
     """Save model checkpoint, intended for saving the best model after each epoch.
     It compares the current model loss with the best loss so far and saves the
     model if the current loss is better.
@@ -99,6 +112,7 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoc
         epoch (int): current number of training epochs.
         r (int): model reduction rate for Tacotron models.
         output_path (str): output path to save the model file.
+        characters (list): list of characters used in the model.

     Returns:
         float: updated current best loss.
@@ -107,6 +121,6 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoc
         file_name = 'best_model.pth.tar'
         checkpoint_path = os.path.join(output_folder, file_name)
         print(" >> BEST MODEL : {}".format(checkpoint_path))
-        save_model(model, optimizer, current_step, epoch, r, checkpoint_path, model_loss=target_loss, **kwargs)
+        save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs)
         best_loss = target_loss
     return best_loss
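
With this change every saved checkpoint also records the character set the model was trained with, so inference can recover the exact symbol table from the checkpoint itself. Below is a minimal usage sketch; the import path, the stand-in torch model/optimizer, and the example character list are assumptions for illustration only and are not part of this commit.

import os
import torch

# Import path is an assumption; adjust to wherever io.py lives in the repo.
from TTS.tts.utils.io import save_checkpoint

# Stand-in model/optimizer so the sketch is self-contained; a real run uses a Tacotron model.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())

# Illustrative character list; in training this normally comes from the config / symbol set.
characters = sorted(set("abcdefghijklmnopqrstuvwxyz !,.?"))

os.makedirs("checkpoints", exist_ok=True)

# `characters` is now a required argument and is stored inside the checkpoint's state dict.
save_checkpoint(model, optimizer, current_step=10000, epoch=5, r=2,
                output_folder="checkpoints", characters=characters)

# The character set can later be read back straight from the checkpoint file.
state = torch.load(os.path.join("checkpoints", "checkpoint_10000.pth.tar"), map_location="cpu")
print(state["r"], state["characters"])

Checkpoints written before this commit will simply lack the 'characters' key, so downstream loaders should treat it as optional.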