refactor: use save_checkpoint()/save_best_model() from Trainer

This commit is contained in:
Enno Hermann 2023-11-17 00:39:11 +01:00
parent 96678c7ba2
commit 0fb0d67de7
6 changed files with 23 additions and 87 deletions

View File

@ -8,12 +8,12 @@ import traceback
import torch
from torch.utils.data import DataLoader
from trainer.io import copy_model_files
from trainer.io import copy_model_files, save_best_model, save_checkpoint
from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer
from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
@ -222,7 +222,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
if global_step % c.save_step == 0:
# save model
save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)
save_checkpoint(
c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
)
end_time = time.time()
@ -245,7 +247,18 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
flush=True,
)
# save the best checkpoint
best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch)
best_loss = save_best_model(
eval_loss,
best_loss,
c,
model,
optimizer,
None,
global_step,
epoch,
OUT_PATH,
criterion=criterion.state_dict(),
)
model.train()
return best_loss, global_step

View File

@ -1,11 +1,9 @@
import datetime
import glob
import os
import random
import numpy as np
from scipy import signal
from trainer.io import save_fsspec
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
@ -136,41 +134,3 @@ def setup_encoder_model(config: "Coqpit"):
audio_config=config.audio,
)
return model
def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
checkpoint_path = "checkpoint_{}.pth".format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
new_state_dict = model.state_dict()
state = {
"model": new_state_dict,
"optimizer": optimizer.state_dict() if optimizer is not None else None,
"criterion": criterion.state_dict(),
"step": current_step,
"epoch": epoch,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
if model_loss < best_loss:
new_state_dict = model.state_dict()
state = {
"model": new_state_dict,
"optimizer": optimizer.state_dict(),
"criterion": criterion.state_dict(),
"step": current_step,
"epoch": epoch,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
best_loss = model_loss
bestmodel_path = "best_model.pth"
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
save_fsspec(state, bestmodel_path)
return best_loss

View File

@ -1,38 +0,0 @@
import datetime
import os
from trainer.io import save_fsspec
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
checkpoint_path = "checkpoint_{}.pth".format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
new_state_dict = model.state_dict()
state = {
"model": new_state_dict,
"optimizer": optimizer.state_dict() if optimizer is not None else None,
"step": current_step,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
if model_loss < best_loss:
new_state_dict = model.state_dict()
state = {
"model": new_state_dict,
"optimizer": optimizer.state_dict(),
"step": current_step,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
best_loss = model_loss
bestmodel_path = "best_model.pth"
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
save_fsspec(state, bestmodel_path)
return best_loss

View File

@ -3,11 +3,11 @@ import unittest
import numpy as np
import torch
from trainer.io import save_checkpoint
from tests import get_tests_input_path
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.managers import EmbeddingManager
from TTS.utils.audio import AudioProcessor
@ -31,7 +31,7 @@ class EmbeddingManagerTest(unittest.TestCase):
# create a dummy speaker encoder
model = setup_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0)
save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
# load audio processor and speaker encoder
manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)

View File

@ -3,11 +3,11 @@ import unittest
import numpy as np
import torch
from trainer.io import save_checkpoint
from tests import get_tests_input_path
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
@ -30,7 +30,7 @@ class SpeakerManagerTest(unittest.TestCase):
# create a dummy speaker encoder
model = setup_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0)
save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
# load audio processor and speaker encoder
ap = AudioProcessor(**config.audio)

View File

@ -1,10 +1,11 @@
import os
import unittest
from trainer.io import save_checkpoint
from tests import get_tests_input_path
from TTS.config import load_config
from TTS.tts.models import setup_model
from TTS.utils.io import save_checkpoint
from TTS.utils.synthesizer import Synthesizer