mirror of https://github.com/coqui-ai/TTS.git
some more tests for model saving
This commit is contained in:
parent
21a4ee63fe
commit
9b7b5e238e
3
train.py
3
train.py
|
@ -218,8 +218,7 @@ def main(args):
|
||||||
if c.checkpoint:
|
if c.checkpoint:
|
||||||
# save model
|
# save model
|
||||||
save_checkpoint(model, optimizer, linear_loss.data[0],
|
save_checkpoint(model, optimizer, linear_loss.data[0],
|
||||||
best_loss, OUT_PATH,
|
OUT_PATH, current_step, epoch)
|
||||||
current_step, epoch)
|
|
||||||
|
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
const_spec = linear_output[0].data.cpu().numpy()
|
const_spec = linear_output[0].data.cpu().numpy()
|
||||||
|
|
|
@ -60,7 +60,7 @@ def _trim_model_state_dict(state_dict):
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(model, optimizer, model_loss, best_loss, out_path,
|
def save_checkpoint(model, optimizer, model_loss, out_path,
|
||||||
current_step, epoch):
|
current_step, epoch):
|
||||||
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||||
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
||||||
|
|
Loading…
Reference in New Issue