refactor: use copy_model_files() from Trainer

This commit is contained in:
Enno Hermann 2023-11-17 00:12:09 +01:00
parent 5119e651a1
commit 96678c7ba2
3 changed files with 3 additions and 34 deletions

View File

@ -8,6 +8,7 @@ import traceback
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from trainer.io import copy_model_files
from trainer.torch import NoamLR from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer from trainer.trainer_utils import get_optimizer
@ -18,7 +19,6 @@ from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
from TTS.utils.io import copy_model_files
from TTS.utils.samplers import PerfectBatchSampler from TTS.utils.samplers import PerfectBatchSampler
from TTS.utils.training import check_update from TTS.utils.training import check_update
@ -276,7 +276,7 @@ def main(args): # pylint: disable=redefined-outer-name
if c.loss == "softmaxproto" and c.model != "speaker_encoder": if c.loss == "softmaxproto" and c.model != "speaker_encoder":
c.map_classid_to_classname = map_classid_to_classname c.map_classid_to_classname = map_classid_to_classname
copy_model_files(c, OUT_PATH) copy_model_files(c, OUT_PATH, new_fields={})
if args.restore_path: if args.restore_path:
criterion, args.restore_step = model.load_checkpoint( criterion, args.restore_step = model.load_checkpoint(

View File

@ -3,13 +3,13 @@ from dataclasses import dataclass, field
from coqpit import Coqpit from coqpit import Coqpit
from trainer import TrainerArgs, get_last_checkpoint from trainer import TrainerArgs, get_last_checkpoint
from trainer.io import copy_model_files
from trainer.logging import logger_factory from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger from trainer.logging.console_logger import ConsoleLogger
from TTS.config import load_config, register_config from TTS.config import load_config, register_config
from TTS.tts.utils.text.characters import parse_symbols from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.io import copy_model_files
@dataclass @dataclass

View File

@ -1,12 +1,9 @@
import json
import os import os
import pickle as pickle_tts import pickle as pickle_tts
import shutil
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, Union
import fsspec import fsspec
import torch import torch
from coqpit import Coqpit
from TTS.utils.generic_utils import get_user_data_dir from TTS.utils.generic_utils import get_user_data_dir
@ -27,34 +24,6 @@ class AttrDict(dict):
self.__dict__ = self self.__dict__ = self
def copy_model_files(config: Coqpit, out_path, new_fields=None):
"""Copy config.json and other model files to training folder and add
new fields.
Args:
config (Coqpit): Coqpit config defining the training run.
out_path (str): output path to copy the file.
new_fields (dict): new fileds to be added or edited
in the config file.
"""
copy_config_path = os.path.join(out_path, "config.json")
# add extra information fields
if new_fields:
config.update(new_fields, allow_new=True)
# TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
json.dump(config.to_dict(), f, indent=4)
# copy model stats file if available
if config.audio.stats_path is not None:
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
filesystem = fsspec.get_mapper(copy_stats_path).fs
if not filesystem.exists(copy_stats_path):
with fsspec.open(config.audio.stats_path, "rb") as source_file:
with fsspec.open(copy_stats_path, "wb") as target_file:
shutil.copyfileobj(source_file, target_file)
def load_fsspec( def load_fsspec(
path: str, path: str,
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,