mirror of https://github.com/coqui-ai/TTS.git
some more tests for model saving
This commit is contained in:
parent
21a4ee63fe
commit
9b7b5e238e
3
train.py
3
train.py
|
@ -218,8 +218,7 @@ def main(args):
|
|||
if c.checkpoint:
|
||||
# save model
|
||||
save_checkpoint(model, optimizer, linear_loss.data[0],
|
||||
best_loss, OUT_PATH,
|
||||
current_step, epoch)
|
||||
OUT_PATH, current_step, epoch)
|
||||
|
||||
# Diagnostic visualizations
|
||||
const_spec = linear_output[0].data.cpu().numpy()
|
||||
|
|
|
@ -60,7 +60,7 @@ def _trim_model_state_dict(state_dict):
|
|||
return new_state_dict
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, model_loss, best_loss, out_path,
|
||||
def save_checkpoint(model, optimizer, model_loss, out_path,
|
||||
current_step, epoch):
|
||||
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
||||
|
|
Loading…
Reference in New Issue