Merge pull request #3243 from idiap/checkpoints

Remove duplicate/unused code
This commit is contained in:
Eren Gölge 2023-11-22 23:52:06 +01:00 committed by GitHub
commit b47d9c6e36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 25 additions and 241 deletions

View File

@ -8,17 +8,17 @@ import traceback
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from trainer.io import copy_model_files, save_best_model, save_checkpoint
from trainer.torch import NoamLR from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer from trainer.trainer_utils import get_optimizer
from TTS.encoder.dataset import EncoderDataset from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.training import init_training from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
from TTS.utils.io import copy_model_files
from TTS.utils.samplers import PerfectBatchSampler from TTS.utils.samplers import PerfectBatchSampler
from TTS.utils.training import check_update from TTS.utils.training import check_update
@ -222,7 +222,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
if global_step % c.save_step == 0: if global_step % c.save_step == 0:
# save model # save model
save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch) save_checkpoint(
c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
)
end_time = time.time() end_time = time.time()
@ -245,7 +247,18 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
flush=True, flush=True,
) )
# save the best checkpoint # save the best checkpoint
best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch) best_loss = save_best_model(
eval_loss,
best_loss,
c,
model,
optimizer,
None,
global_step,
epoch,
OUT_PATH,
criterion=criterion.state_dict(),
)
model.train() model.train()
return best_loss, global_step return best_loss, global_step
@ -276,7 +289,7 @@ def main(args): # pylint: disable=redefined-outer-name
if c.loss == "softmaxproto" and c.model != "speaker_encoder": if c.loss == "softmaxproto" and c.model != "speaker_encoder":
c.map_classid_to_classname = map_classid_to_classname c.map_classid_to_classname = map_classid_to_classname
copy_model_files(c, OUT_PATH) copy_model_files(c, OUT_PATH, new_fields={})
if args.restore_path: if args.restore_path:
criterion, args.restore_step = model.load_checkpoint( criterion, args.restore_step = model.load_checkpoint(

View File

@ -1,15 +1,12 @@
import datetime
import glob import glob
import os import os
import random import random
import re
import numpy as np import numpy as np
from scipy import signal from scipy import signal
from TTS.encoder.models.lstm import LSTMSpeakerEncoder from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder from TTS.encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.io import save_fsspec
class AugmentWAV(object): class AugmentWAV(object):
@ -118,11 +115,6 @@ class AugmentWAV(object):
return self.additive_noise(noise_type, audio) return self.additive_noise(noise_type, audio)
def to_camel(text):
text = text.capitalize()
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_encoder_model(config: "Coqpit"): def setup_encoder_model(config: "Coqpit"):
if config.model_params["model_name"].lower() == "lstm": if config.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder( model = LSTMSpeakerEncoder(
@ -142,41 +134,3 @@ def setup_encoder_model(config: "Coqpit"):
audio_config=config.audio, audio_config=config.audio,
) )
return model return model
def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
checkpoint_path = "checkpoint_{}.pth".format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
new_state_dict = model.state_dict()
state = {
"model": new_state_dict,
"optimizer": optimizer.state_dict() if optimizer is not None else None,
"criterion": criterion.state_dict(),
"step": current_step,
"epoch": epoch,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
if model_loss < best_loss:
new_state_dict = model.state_dict()
state = {
"model": new_state_dict,
"optimizer": optimizer.state_dict(),
"criterion": criterion.state_dict(),
"step": current_step,
"epoch": epoch,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
best_loss = model_loss
bestmodel_path = "best_model.pth"
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
save_fsspec(state, bestmodel_path)
return best_loss

View File

@ -1,38 +0,0 @@
import datetime
import os
from TTS.utils.io import save_fsspec
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
checkpoint_path = "checkpoint_{}.pth".format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
new_state_dict = model.state_dict()
state = {
"model": new_state_dict,
"optimizer": optimizer.state_dict() if optimizer is not None else None,
"step": current_step,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
if model_loss < best_loss:
new_state_dict = model.state_dict()
state = {
"model": new_state_dict,
"optimizer": optimizer.state_dict(),
"step": current_step,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
best_loss = model_loss
bestmodel_path = "best_model.pth"
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
save_fsspec(state, bestmodel_path)
return best_loss

View File

@ -3,13 +3,13 @@ from dataclasses import dataclass, field
from coqpit import Coqpit from coqpit import Coqpit
from trainer import TrainerArgs, get_last_checkpoint from trainer import TrainerArgs, get_last_checkpoint
from trainer.io import copy_model_files
from trainer.logging import logger_factory from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger from trainer.logging.console_logger import ConsoleLogger
from TTS.config import load_config, register_config from TTS.config import load_config, register_config
from TTS.tts.utils.text.characters import parse_symbols from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.io import copy_model_files
@dataclass @dataclass

View File

@ -1,13 +1,9 @@
import datetime
import json
import os import os
import pickle as pickle_tts import pickle as pickle_tts
import shutil
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, Union
import fsspec import fsspec
import torch import torch
from coqpit import Coqpit
from TTS.utils.generic_utils import get_user_data_dir from TTS.utils.generic_utils import get_user_data_dir
@ -28,34 +24,6 @@ class AttrDict(dict):
self.__dict__ = self self.__dict__ = self
def copy_model_files(config: Coqpit, out_path, new_fields=None):
"""Copy config.json and other model files to training folder and add
new fields.
Args:
config (Coqpit): Coqpit config defining the training run.
out_path (str): output path to copy the file.
new_fields (dict): new fileds to be added or edited
in the config file.
"""
copy_config_path = os.path.join(out_path, "config.json")
# add extra information fields
if new_fields:
config.update(new_fields, allow_new=True)
# TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
json.dump(config.to_dict(), f, indent=4)
# copy model stats file if available
if config.audio.stats_path is not None:
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
filesystem = fsspec.get_mapper(copy_stats_path).fs
if not filesystem.exists(copy_stats_path):
with fsspec.open(config.audio.stats_path, "rb") as source_file:
with fsspec.open(copy_stats_path, "wb") as target_file:
shutil.copyfileobj(source_file, target_file)
def load_fsspec( def load_fsspec(
path: str, path: str,
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
@ -100,117 +68,3 @@ def load_checkpoint(
if eval: if eval:
model.eval() model.eval()
return model, state return model, state
def save_fsspec(state: Any, path: str, **kwargs):
"""Like torch.save but can save to other locations (e.g. s3:// , gs://).
Args:
state: State object to save
path: Any path or url supported by fsspec.
**kwargs: Keyword arguments forwarded to torch.save.
"""
with fsspec.open(path, "wb") as f:
torch.save(state, f, **kwargs)
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
if hasattr(model, "module"):
model_state = model.module.state_dict()
else:
model_state = model.state_dict()
if isinstance(optimizer, list):
optimizer_state = [optim.state_dict() for optim in optimizer]
elif optimizer.__class__.__name__ == "CapacitronOptimizer":
optimizer_state = [optimizer.primary_optimizer.state_dict(), optimizer.secondary_optimizer.state_dict()]
else:
optimizer_state = optimizer.state_dict() if optimizer is not None else None
if isinstance(scaler, list):
scaler_state = [s.state_dict() for s in scaler]
else:
scaler_state = scaler.state_dict() if scaler is not None else None
if isinstance(config, Coqpit):
config = config.to_dict()
state = {
"config": config,
"model": model_state,
"optimizer": optimizer_state,
"scaler": scaler_state,
"step": current_step,
"epoch": epoch,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
state.update(kwargs)
save_fsspec(state, output_path)
def save_checkpoint(
config,
model,
optimizer,
scaler,
current_step,
epoch,
output_folder,
**kwargs,
):
file_name = "checkpoint_{}.pth".format(current_step)
checkpoint_path = os.path.join(output_folder, file_name)
print("\n > CHECKPOINT : {}".format(checkpoint_path))
save_model(
config,
model,
optimizer,
scaler,
current_step,
epoch,
checkpoint_path,
**kwargs,
)
def save_best_model(
current_loss,
best_loss,
config,
model,
optimizer,
scaler,
current_step,
epoch,
out_path,
keep_all_best=False,
keep_after=10000,
**kwargs,
):
if current_loss < best_loss:
best_model_name = f"best_model_{current_step}.pth"
checkpoint_path = os.path.join(out_path, best_model_name)
print(" > BEST MODEL : {}".format(checkpoint_path))
save_model(
config,
model,
optimizer,
scaler,
current_step,
epoch,
checkpoint_path,
model_loss=current_loss,
**kwargs,
)
fs = fsspec.get_mapper(out_path).fs
# only delete previous if current is saved successfully
if not keep_all_best or (current_step < keep_after):
model_names = fs.glob(os.path.join(out_path, "best_model*.pth"))
for model_name in model_names:
if os.path.basename(model_name) != best_model_name:
fs.rm(model_name)
# create a shortcut which always points to the currently best model
shortcut_name = "best_model.pth"
shortcut_path = os.path.join(out_path, shortcut_name)
fs.copy(checkpoint_path, shortcut_path)
best_loss = current_loss
return best_loss

View File

@ -3,11 +3,11 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from trainer.io import save_checkpoint
from tests import get_tests_input_path from tests import get_tests_input_path
from TTS.config import load_config from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.managers import EmbeddingManager from TTS.tts.utils.managers import EmbeddingManager
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
@ -31,7 +31,7 @@ class EmbeddingManagerTest(unittest.TestCase):
# create a dummy speaker encoder # create a dummy speaker encoder
model = setup_encoder_model(config) model = setup_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0) save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
# load audio processor and speaker encoder # load audio processor and speaker encoder
manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path) manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)

View File

@ -3,11 +3,11 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from trainer.io import save_checkpoint
from tests import get_tests_input_path from tests import get_tests_input_path
from TTS.config import load_config from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
@ -30,7 +30,7 @@ class SpeakerManagerTest(unittest.TestCase):
# create a dummy speaker encoder # create a dummy speaker encoder
model = setup_encoder_model(config) model = setup_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0) save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
# load audio processor and speaker encoder # load audio processor and speaker encoder
ap = AudioProcessor(**config.audio) ap = AudioProcessor(**config.audio)

View File

@ -1,10 +1,11 @@
import os import os
import unittest import unittest
from trainer.io import save_checkpoint
from tests import get_tests_input_path from tests import get_tests_input_path
from TTS.config import load_config from TTS.config import load_config
from TTS.tts.models import setup_model from TTS.tts.models import setup_model
from TTS.utils.io import save_checkpoint
from TTS.utils.synthesizer import Synthesizer from TTS.utils.synthesizer import Synthesizer