chore(utils.io): remove unused code

These are all available in Trainer.
This commit is contained in:
Enno Hermann 2023-11-16 23:52:28 +01:00
parent 39fe38bda4
commit 5119e651a1
1 changed files with 0 additions and 104 deletions

View File

@ -1,4 +1,3 @@
import datetime
import json import json
import os import os
import pickle as pickle_tts import pickle as pickle_tts
@ -8,7 +7,6 @@ from typing import Any, Callable, Dict, Union
import fsspec import fsspec
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
from trainer.io import save_fsspec
from TTS.utils.generic_utils import get_user_data_dir from TTS.utils.generic_utils import get_user_data_dir
@ -101,105 +99,3 @@ def load_checkpoint(
if eval: if eval:
model.eval() model.eval()
return model, state return model, state
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
if hasattr(model, "module"):
model_state = model.module.state_dict()
else:
model_state = model.state_dict()
if isinstance(optimizer, list):
optimizer_state = [optim.state_dict() for optim in optimizer]
elif optimizer.__class__.__name__ == "CapacitronOptimizer":
optimizer_state = [optimizer.primary_optimizer.state_dict(), optimizer.secondary_optimizer.state_dict()]
else:
optimizer_state = optimizer.state_dict() if optimizer is not None else None
if isinstance(scaler, list):
scaler_state = [s.state_dict() for s in scaler]
else:
scaler_state = scaler.state_dict() if scaler is not None else None
if isinstance(config, Coqpit):
config = config.to_dict()
state = {
"config": config,
"model": model_state,
"optimizer": optimizer_state,
"scaler": scaler_state,
"step": current_step,
"epoch": epoch,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
state.update(kwargs)
save_fsspec(state, output_path)
def save_checkpoint(
config,
model,
optimizer,
scaler,
current_step,
epoch,
output_folder,
**kwargs,
):
file_name = "checkpoint_{}.pth".format(current_step)
checkpoint_path = os.path.join(output_folder, file_name)
print("\n > CHECKPOINT : {}".format(checkpoint_path))
save_model(
config,
model,
optimizer,
scaler,
current_step,
epoch,
checkpoint_path,
**kwargs,
)
def save_best_model(
current_loss,
best_loss,
config,
model,
optimizer,
scaler,
current_step,
epoch,
out_path,
keep_all_best=False,
keep_after=10000,
**kwargs,
):
if current_loss < best_loss:
best_model_name = f"best_model_{current_step}.pth"
checkpoint_path = os.path.join(out_path, best_model_name)
print(" > BEST MODEL : {}".format(checkpoint_path))
save_model(
config,
model,
optimizer,
scaler,
current_step,
epoch,
checkpoint_path,
model_loss=current_loss,
**kwargs,
)
fs = fsspec.get_mapper(out_path).fs
# only delete previous if current is saved successfully
if not keep_all_best or (current_step < keep_after):
model_names = fs.glob(os.path.join(out_path, "best_model*.pth"))
for model_name in model_names:
if os.path.basename(model_name) != best_model_name:
fs.rm(model_name)
# create a shortcut which always points to the currently best model
shortcut_name = "best_model.pth"
shortcut_path = os.path.join(out_path, shortcut_name)
fs.copy(checkpoint_path, shortcut_path)
best_loss = current_loss
return best_loss