mirror of https://github.com/coqui-ai/TTS.git
refactor: use save_checkpoint()/save_best_model() from Trainer
This commit is contained in:
parent
96678c7ba2
commit
0fb0d67de7
|
@ -8,12 +8,12 @@ import traceback
|
|||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from trainer.io import copy_model_files
|
||||
from trainer.io import copy_model_files, save_best_model, save_checkpoint
|
||||
from trainer.torch import NoamLR
|
||||
from trainer.trainer_utils import get_optimizer
|
||||
|
||||
from TTS.encoder.dataset import EncoderDataset
|
||||
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
|
||||
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||||
from TTS.encoder.utils.training import init_training
|
||||
from TTS.encoder.utils.visual import plot_embeddings
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
|
@ -222,7 +222,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
|
|||
|
||||
if global_step % c.save_step == 0:
|
||||
# save model
|
||||
save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)
|
||||
save_checkpoint(
|
||||
c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
|
@ -245,7 +247,18 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
|
|||
flush=True,
|
||||
)
|
||||
# save the best checkpoint
|
||||
best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch)
|
||||
best_loss = save_best_model(
|
||||
eval_loss,
|
||||
best_loss,
|
||||
c,
|
||||
model,
|
||||
optimizer,
|
||||
None,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
criterion=criterion.state_dict(),
|
||||
)
|
||||
model.train()
|
||||
|
||||
return best_loss, global_step
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
import datetime
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from trainer.io import save_fsspec
|
||||
|
||||
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
|
||||
|
@ -136,41 +134,3 @@ def setup_encoder_model(config: "Coqpit"):
|
|||
audio_config=config.audio,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
    """Write a training checkpoint named ``checkpoint_<current_step>.pth`` under ``out_path``.

    The stored dictionary holds the model, optimizer and criterion state
    dicts together with the step, epoch, loss and today's date.

    Args:
        model: Object whose ``state_dict()`` is stored under ``"model"``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``"optimizer"``; when ``None``, ``None`` is stored instead.
        criterion: Object whose ``state_dict()`` is stored under ``"criterion"``.
        model_loss: Value stored under ``"loss"``.
        out_path: Directory the checkpoint file name is joined onto.
        current_step: Step number; used in the file name and stored under ``"step"``.
        epoch: Value stored under ``"epoch"``.
    """
    target_file = os.path.join(out_path, "checkpoint_{}.pth".format(current_step))
    print(" | | > Checkpoint saving : {}".format(target_file))

    # Keyword arguments are evaluated left to right, so the state dicts are
    # collected in the same order as the keys appear in the saved payload.
    payload = dict(
        model=model.state_dict(),
        optimizer=None if optimizer is None else optimizer.state_dict(),
        criterion=criterion.state_dict(),
        step=current_step,
        epoch=epoch,
        loss=model_loss,
        date=datetime.date.today().strftime("%B %d, %Y"),
    )
    save_fsspec(payload, target_file)
|
||||
|
||||
|
||||
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
    """Save ``best_model.pth`` under ``out_path`` when ``model_loss`` beats ``best_loss``.

    Nothing is written unless ``model_loss`` is strictly lower than
    ``best_loss``.

    Args:
        model: Object whose ``state_dict()`` is stored under ``"model"``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``"optimizer"``; when ``None``, ``None`` is stored instead
            (the same tolerance ``save_checkpoint`` in this module has).
        criterion: Object whose ``state_dict()`` is stored under ``"criterion"``.
        model_loss: Loss of the current model; also stored under ``"loss"``.
        best_loss: Best loss seen so far.
        out_path: Directory that ``best_model.pth`` is joined onto.
        current_step: Value stored under ``"step"``.
        epoch: Value stored under ``"epoch"``.

    Returns:
        ``model_loss`` when it was strictly lower than ``best_loss`` (a file
        was written in that case), otherwise ``best_loss`` unchanged.
    """
    # Written as ``not <`` rather than ``>=`` so that a NaN loss is treated
    # as "no improvement", exactly like the original strict comparison.
    if not model_loss < best_loss:
        return best_loss

    state = {
        "model": model.state_dict(),
        # Guard against a missing optimizer instead of raising AttributeError.
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "criterion": criterion.state_dict(),
        "step": current_step,
        "epoch": epoch,
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    bestmodel_path = os.path.join(out_path, "best_model.pth")
    print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
    save_fsspec(state, bestmodel_path)
    return model_loss
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
import datetime
|
||||
import os
|
||||
|
||||
from trainer.io import save_fsspec
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
    """Write a checkpoint named ``checkpoint_<current_step>.pth`` under ``out_path``.

    The stored dictionary holds the model and optimizer state dicts together
    with the step, loss and today's date.

    Args:
        model: Object whose ``state_dict()`` is stored under ``"model"``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``"optimizer"``; when ``None``, ``None`` is stored instead.
        model_loss: Value stored under ``"loss"``.
        out_path: Directory the checkpoint file name is joined onto.
        current_step: Step number; used in the file name and stored under ``"step"``.
    """
    target_file = os.path.join(out_path, "checkpoint_{}.pth".format(current_step))
    print(" | | > Checkpoint saving : {}".format(target_file))

    # Keyword arguments are evaluated left to right, so the state dicts are
    # collected in the same order as the keys appear in the saved payload.
    payload = dict(
        model=model.state_dict(),
        optimizer=None if optimizer is None else optimizer.state_dict(),
        step=current_step,
        loss=model_loss,
        date=datetime.date.today().strftime("%B %d, %Y"),
    )
    save_fsspec(payload, target_file)
|
||||
|
||||
|
||||
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
    """Save ``best_model.pth`` under ``out_path`` when ``model_loss`` beats ``best_loss``.

    Nothing is written unless ``model_loss`` is strictly lower than
    ``best_loss``.

    Args:
        model: Object whose ``state_dict()`` is stored under ``"model"``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``"optimizer"``; when ``None``, ``None`` is stored instead
            (the same tolerance ``save_checkpoint`` in this module has).
        model_loss: Loss of the current model; also stored under ``"loss"``.
        best_loss: Best loss seen so far.
        out_path: Directory that ``best_model.pth`` is joined onto.
        current_step: Value stored under ``"step"``.

    Returns:
        ``model_loss`` when it was strictly lower than ``best_loss`` (a file
        was written in that case), otherwise ``best_loss`` unchanged.
    """
    # Written as ``not <`` rather than ``>=`` so that a NaN loss is treated
    # as "no improvement", exactly like the original strict comparison.
    if not model_loss < best_loss:
        return best_loss

    state = {
        "model": model.state_dict(),
        # Guard against a missing optimizer instead of raising AttributeError.
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "step": current_step,
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    bestmodel_path = os.path.join(out_path, "best_model.pth")
    print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
    save_fsspec(state, bestmodel_path)
    return model_loss
|
|
@ -3,11 +3,11 @@ import unittest
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from trainer.io import save_checkpoint
|
||||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.config import load_config
|
||||
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||||
from TTS.encoder.utils.io import save_checkpoint
|
||||
from TTS.tts.utils.managers import EmbeddingManager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
@ -31,7 +31,7 @@ class EmbeddingManagerTest(unittest.TestCase):
|
|||
|
||||
# create a dummy speaker encoder
|
||||
model = setup_encoder_model(config)
|
||||
save_checkpoint(model, None, None, get_tests_input_path(), 0)
|
||||
save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
|
||||
|
||||
# load audio processor and speaker encoder
|
||||
manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)
|
||||
|
|
|
@ -3,11 +3,11 @@ import unittest
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from trainer.io import save_checkpoint
|
||||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.config import load_config
|
||||
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||||
from TTS.encoder.utils.io import save_checkpoint
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
@ -30,7 +30,7 @@ class SpeakerManagerTest(unittest.TestCase):
|
|||
|
||||
# create a dummy speaker encoder
|
||||
model = setup_encoder_model(config)
|
||||
save_checkpoint(model, None, None, get_tests_input_path(), 0)
|
||||
save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
|
||||
|
||||
# load audio processor and speaker encoder
|
||||
ap = AudioProcessor(**config.audio)
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
from trainer.io import save_checkpoint
|
||||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.models import setup_model
|
||||
from TTS.utils.io import save_checkpoint
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue