mirror of https://github.com/coqui-ai/TTS.git
fix checkpointing
This commit is contained in:
parent
0ec42fa279
commit
88bde77061
16
utils/io.py
16
utils/io.py
|
@ -47,9 +47,7 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False):
|
|||
return model, state
|
||||
|
||||
|
||||
def save_model(model, optimizer, current_step, epoch, r, output_folder, file_name, **kwargs):
|
||||
checkpoint_path = os.path.join(output_folder, file_name)
|
||||
|
||||
def save_model(model, optimizer, current_step, epoch, r, output_path, **kwargs):
|
||||
new_state_dict = model.state_dict()
|
||||
state = {
|
||||
'model': new_state_dict,
|
||||
|
@ -60,19 +58,21 @@ def save_model(model, optimizer, current_step, epoch, r, output_folder, file_nam
|
|||
'r': model.decoder.r
|
||||
}
|
||||
state.update(kwargs)
|
||||
torch.save(state, checkpoint_path)
|
||||
torch.save(state, output_path)
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
|
||||
print(" > CHECKPOINT : {}".format(checkpoint_path))
|
||||
file_name = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||
save_model(model, optimizer, current_step, epoch ,r, output_folder, file_name, **kwargs)
|
||||
checkpoint_path = os.path.join(output_folder, file_name)
|
||||
print(" > CHECKPOINT : {}".format(checkpoint_path))
|
||||
save_model(model, optimizer, current_step, epoch ,r, checkpoint_path, **kwargs)
|
||||
|
||||
|
||||
def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs):
|
||||
if target_loss < best_loss:
|
||||
print(" > BEST MODEL : {}".format(checkpoint_path))
|
||||
file_name = 'best_model.pth.tar'
|
||||
save_model(model, optimizer, current_step, epoch ,r, output_folder, file_name, model_loss=target_loss)
|
||||
checkpoint_path = os.path.join(output_folder, file_name)
|
||||
print(" > BEST MODEL : {}".format(checkpoint_path))
|
||||
save_model(model, optimizer, current_step, epoch ,r, checkpoint_path, model_loss=target_loss)
|
||||
best_loss = target_loss
|
||||
return best_loss
|
Loading…
Reference in New Issue