From 0fb0d67de7bd05ef4afd80f05e242217e9800c80 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 17 Nov 2023 00:39:11 +0100 Subject: [PATCH] refactor: use save_checkpoint()/save_best_model() from Trainer --- TTS/bin/train_encoder.py | 21 +++++++++--- TTS/encoder/utils/generic_utils.py | 40 ----------------------- TTS/encoder/utils/io.py | 38 --------------------- tests/aux_tests/test_embedding_manager.py | 4 +-- tests/aux_tests/test_speaker_manager.py | 4 +-- tests/inference_tests/test_synthesizer.py | 3 +- 6 files changed, 23 insertions(+), 87 deletions(-) delete mode 100644 TTS/encoder/utils/io.py diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index c4fb920f..448fefc7 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -8,12 +8,12 @@ import traceback import torch from torch.utils.data import DataLoader -from trainer.io import copy_model_files +from trainer.io import copy_model_files, save_best_model, save_checkpoint from trainer.torch import NoamLR from trainer.trainer_utils import get_optimizer from TTS.encoder.dataset import EncoderDataset -from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model +from TTS.encoder.utils.generic_utils import setup_encoder_model from TTS.encoder.utils.training import init_training from TTS.encoder.utils.visual import plot_embeddings from TTS.tts.datasets import load_tts_samples @@ -222,7 +222,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, if global_step % c.save_step == 0: # save model - save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch) + save_checkpoint( + c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict() + ) end_time = time.time() @@ -245,7 +247,18 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, flush=True, ) # save the best checkpoint - best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch) + best_loss = save_best_model( + eval_loss, + best_loss, + c, + model, + optimizer, + None, + global_step, + epoch, + OUT_PATH, + criterion=criterion.state_dict(), + ) model.train() return best_loss, global_step diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py index 2b003ac8..236d6fe9 100644 --- a/TTS/encoder/utils/generic_utils.py +++ b/TTS/encoder/utils/generic_utils.py @@ -1,11 +1,9 @@ -import datetime import glob import os import random import numpy as np from scipy import signal -from trainer.io import save_fsspec from TTS.encoder.models.lstm import LSTMSpeakerEncoder from TTS.encoder.models.resnet import ResNetSpeakerEncoder @@ -136,41 +134,3 @@ def setup_encoder_model(config: "Coqpit"): audio_config=config.audio, ) return model - - -def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch): - checkpoint_path = "checkpoint_{}.pth".format(current_step) - checkpoint_path = os.path.join(out_path, checkpoint_path) - print(" | | > Checkpoint saving : {}".format(checkpoint_path)) - - new_state_dict = model.state_dict() - state = { - "model": new_state_dict, - "optimizer": optimizer.state_dict() if optimizer is not None else None, - "criterion": criterion.state_dict(), - "step": current_step, - "epoch": epoch, - "loss": model_loss, - "date": datetime.date.today().strftime("%B %d, %Y"), - } - save_fsspec(state, checkpoint_path) - - -def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch): - if model_loss < best_loss: - new_state_dict = model.state_dict() - state = { - "model": new_state_dict, - "optimizer": optimizer.state_dict(), - "criterion": criterion.state_dict(), - "step": current_step, - "epoch": epoch, - "loss": model_loss, - "date": datetime.date.today().strftime("%B %d, %Y"), - } - best_loss = model_loss - bestmodel_path = "best_model.pth" - bestmodel_path = os.path.join(out_path, bestmodel_path) - print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) - save_fsspec(state, bestmodel_path) - return best_loss diff --git a/TTS/encoder/utils/io.py b/TTS/encoder/utils/io.py deleted file mode 100644 index a8359be1..00000000 --- a/TTS/encoder/utils/io.py +++ /dev/null @@ -1,38 +0,0 @@ -import datetime -import os - -from trainer.io import save_fsspec - - -def save_checkpoint(model, optimizer, model_loss, out_path, current_step): - checkpoint_path = "checkpoint_{}.pth".format(current_step) - checkpoint_path = os.path.join(out_path, checkpoint_path) - print(" | | > Checkpoint saving : {}".format(checkpoint_path)) - - new_state_dict = model.state_dict() - state = { - "model": new_state_dict, - "optimizer": optimizer.state_dict() if optimizer is not None else None, - "step": current_step, - "loss": model_loss, - "date": datetime.date.today().strftime("%B %d, %Y"), - } - save_fsspec(state, checkpoint_path) - - -def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step): - if model_loss < best_loss: - new_state_dict = model.state_dict() - state = { - "model": new_state_dict, - "optimizer": optimizer.state_dict(), - "step": current_step, - "loss": model_loss, - "date": datetime.date.today().strftime("%B %d, %Y"), - } - best_loss = model_loss - bestmodel_path = "best_model.pth" - bestmodel_path = os.path.join(out_path, bestmodel_path) - print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) - save_fsspec(state, bestmodel_path) - return best_loss diff --git a/tests/aux_tests/test_embedding_manager.py b/tests/aux_tests/test_embedding_manager.py index 73921501..e3acd62b 100644 --- a/tests/aux_tests/test_embedding_manager.py +++ b/tests/aux_tests/test_embedding_manager.py @@ -3,11 +3,11 @@ import unittest import numpy as np import torch +from trainer.io import save_checkpoint from tests import get_tests_input_path from TTS.config import load_config from TTS.encoder.utils.generic_utils import setup_encoder_model -from TTS.encoder.utils.io import save_checkpoint from TTS.tts.utils.managers import EmbeddingManager from TTS.utils.audio import AudioProcessor @@ -31,7 +31,7 @@ class EmbeddingManagerTest(unittest.TestCase): # create a dummy speaker encoder model = setup_encoder_model(config) - save_checkpoint(model, None, None, get_tests_input_path(), 0) + save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path()) # load audio processor and speaker encoder manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path) diff --git a/tests/aux_tests/test_speaker_manager.py b/tests/aux_tests/test_speaker_manager.py index 397f9c81..402fbca4 100644 --- a/tests/aux_tests/test_speaker_manager.py +++ b/tests/aux_tests/test_speaker_manager.py @@ -3,11 +3,11 @@ import unittest import numpy as np import torch +from trainer.io import save_checkpoint from tests import get_tests_input_path from TTS.config import load_config from TTS.encoder.utils.generic_utils import setup_encoder_model -from TTS.encoder.utils.io import save_checkpoint from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor @@ -30,7 +30,7 @@ class SpeakerManagerTest(unittest.TestCase): # create a dummy speaker encoder model = setup_encoder_model(config) - save_checkpoint(model, None, None, get_tests_input_path(), 0) + save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path()) # load audio processor and speaker encoder ap = AudioProcessor(**config.audio) diff --git a/tests/inference_tests/test_synthesizer.py b/tests/inference_tests/test_synthesizer.py index 40e83017..ce4fc751 100644 --- a/tests/inference_tests/test_synthesizer.py +++ b/tests/inference_tests/test_synthesizer.py @@ -1,10 +1,11 @@ import os import unittest +from trainer.io import save_checkpoint + from tests import get_tests_input_path from TTS.config import load_config from TTS.tts.models import setup_model -from TTS.utils.io import save_checkpoint from TTS.utils.synthesizer import Synthesizer