refactor: use save_checkpoint()/save_best_model() from Trainer

This commit is contained in:
Enno Hermann 2023-11-17 00:39:11 +01:00
parent 96678c7ba2
commit 0fb0d67de7
6 changed files with 23 additions and 87 deletions

View File

@@ -8,12 +8,12 @@ import traceback
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from trainer.io import copy_model_files from trainer.io import copy_model_files, save_best_model, save_checkpoint
from trainer.torch import NoamLR from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer from trainer.trainer_utils import get_optimizer
from TTS.encoder.dataset import EncoderDataset from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.training import init_training from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
@@ -222,7 +222,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
if global_step % c.save_step == 0: if global_step % c.save_step == 0:
# save model # save model
save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch) save_checkpoint(
c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
)
end_time = time.time() end_time = time.time()
@@ -245,7 +247,18 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
flush=True, flush=True,
) )
# save the best checkpoint # save the best checkpoint
best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch) best_loss = save_best_model(
eval_loss,
best_loss,
c,
model,
optimizer,
None,
global_step,
epoch,
OUT_PATH,
criterion=criterion.state_dict(),
)
model.train() model.train()
return best_loss, global_step return best_loss, global_step

View File

@@ -1,11 +1,9 @@
import datetime
import glob import glob
import os import os
import random import random
import numpy as np import numpy as np
from scipy import signal from scipy import signal
from trainer.io import save_fsspec
from TTS.encoder.models.lstm import LSTMSpeakerEncoder from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder from TTS.encoder.models.resnet import ResNetSpeakerEncoder
@@ -136,41 +134,3 @@ def setup_encoder_model(config: "Coqpit"):
audio_config=config.audio, audio_config=config.audio,
) )
return model return model
def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
    """Write a numbered encoder training checkpoint under ``out_path``.

    Saves ``checkpoint_<current_step>.pth`` holding the model, optimizer and
    criterion state dicts together with bookkeeping fields (step, epoch, loss,
    human-readable date). The file is written via ``save_fsspec`` so remote
    (fsspec) output paths work as well as local ones.
    """
    target_path = os.path.join(out_path, "checkpoint_{}.format".replace("format", "pth").format(current_step)) if False else os.path.join(out_path, "checkpoint_{}.pth".format(current_step))
    print(" | | > Checkpoint saving : {}".format(target_path))
    # The optimizer may legitimately be absent (e.g. checkpoints saved outside
    # a training run), so serialize it only when present.
    optimizer_state = optimizer.state_dict() if optimizer is not None else None
    payload = {
        "model": model.state_dict(),
        "optimizer": optimizer_state,
        "criterion": criterion.state_dict(),
        "step": current_step,
        "epoch": epoch,
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    save_fsspec(payload, target_path)
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
    """Save ``best_model.pth`` if ``model_loss`` improves on ``best_loss``.

    Args:
        model: Encoder model whose ``state_dict()`` is serialized.
        optimizer: Optimizer to serialize, or ``None`` when unavailable.
        criterion: Loss module whose ``state_dict()`` is serialized.
        model_loss (float): Current evaluation loss.
        best_loss (float): Best loss seen so far.
        out_path (str): Directory the checkpoint is written into.
        current_step (int): Global training step, stored for bookkeeping.
        epoch (int): Current epoch, stored for bookkeeping.

    Returns:
        float: The updated best loss (``model_loss`` if it improved,
        otherwise ``best_loss`` unchanged).
    """
    if model_loss < best_loss:
        state = {
            "model": model.state_dict(),
            # Fix: guard against a missing optimizer, matching save_checkpoint();
            # the original unconditionally called optimizer.state_dict() and
            # raised AttributeError when optimizer was None.
            "optimizer": optimizer.state_dict() if optimizer is not None else None,
            "criterion": criterion.state_dict(),
            "step": current_step,
            "epoch": epoch,
            "loss": model_loss,
            "date": datetime.date.today().strftime("%B %d, %Y"),
        }
        best_loss = model_loss
        bestmodel_path = os.path.join(out_path, "best_model.pth")
        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
        save_fsspec(state, bestmodel_path)
    return best_loss

View File

@@ -1,38 +0,0 @@
import datetime
import os
from trainer.io import save_fsspec
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
    """Persist a numbered training checkpoint to ``out_path``.

    Writes ``checkpoint_<current_step>.pth`` containing the model state dict,
    the optimizer state dict (or ``None`` when no optimizer is given), the
    step, the loss, and a human-readable date, via ``save_fsspec``.
    """
    file_name = "checkpoint_{}.pth".format(current_step)
    destination = os.path.join(out_path, file_name)
    print(" | | > Checkpoint saving : {}".format(destination))
    # Serialize the optimizer only when one was supplied.
    if optimizer is None:
        optimizer_state = None
    else:
        optimizer_state = optimizer.state_dict()
    save_fsspec(
        {
            "model": model.state_dict(),
            "optimizer": optimizer_state,
            "step": current_step,
            "loss": model_loss,
            "date": datetime.date.today().strftime("%B %d, %Y"),
        },
        destination,
    )
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
    """Save ``best_model.pth`` if ``model_loss`` improves on ``best_loss``.

    Args:
        model: Model whose ``state_dict()`` is serialized.
        optimizer: Optimizer to serialize, or ``None`` when unavailable.
        model_loss (float): Current evaluation loss.
        best_loss (float): Best loss seen so far.
        out_path (str): Directory the checkpoint is written into.
        current_step (int): Global training step, stored for bookkeeping.

    Returns:
        float: The updated best loss (``model_loss`` if it improved,
        otherwise ``best_loss`` unchanged).
    """
    if model_loss < best_loss:
        state = {
            "model": model.state_dict(),
            # Fix: tolerate a missing optimizer, matching save_checkpoint();
            # the original unconditionally called optimizer.state_dict() and
            # raised AttributeError when optimizer was None.
            "optimizer": optimizer.state_dict() if optimizer is not None else None,
            "step": current_step,
            "loss": model_loss,
            "date": datetime.date.today().strftime("%B %d, %Y"),
        }
        best_loss = model_loss
        bestmodel_path = os.path.join(out_path, "best_model.pth")
        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
        save_fsspec(state, bestmodel_path)
    return best_loss

View File

@@ -3,11 +3,11 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from trainer.io import save_checkpoint
from tests import get_tests_input_path from tests import get_tests_input_path
from TTS.config import load_config from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.managers import EmbeddingManager from TTS.tts.utils.managers import EmbeddingManager
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
@@ -31,7 +31,7 @@ class EmbeddingManagerTest(unittest.TestCase):
# create a dummy speaker encoder # create a dummy speaker encoder
model = setup_encoder_model(config) model = setup_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0) save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
# load audio processor and speaker encoder # load audio processor and speaker encoder
manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path) manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)

View File

@@ -3,11 +3,11 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from trainer.io import save_checkpoint
from tests import get_tests_input_path from tests import get_tests_input_path
from TTS.config import load_config from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
@@ -30,7 +30,7 @@ class SpeakerManagerTest(unittest.TestCase):
# create a dummy speaker encoder # create a dummy speaker encoder
model = setup_encoder_model(config) model = setup_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0) save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
# load audio processor and speaker encoder # load audio processor and speaker encoder
ap = AudioProcessor(**config.audio) ap = AudioProcessor(**config.audio)

View File

@@ -1,10 +1,11 @@
import os import os
import unittest import unittest
from trainer.io import save_checkpoint
from tests import get_tests_input_path from tests import get_tests_input_path
from TTS.config import load_config from TTS.config import load_config
from TTS.tts.models import setup_model from TTS.tts.models import setup_model
from TTS.utils.io import save_checkpoint
from TTS.utils.synthesizer import Synthesizer from TTS.utils.synthesizer import Synthesizer