From 96678c7ba227871d0929f2366d083219ccfa9262 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 17 Nov 2023 00:12:09 +0100 Subject: [PATCH] refactor: use copy_model_files() from Trainer --- TTS/bin/train_encoder.py | 4 ++-- TTS/encoder/utils/training.py | 2 +- TTS/utils/io.py | 31 ------------------------------- 3 files changed, 3 insertions(+), 34 deletions(-) diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index f2e7779c..c4fb920f 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -8,6 +8,7 @@ import traceback import torch from torch.utils.data import DataLoader +from trainer.io import copy_model_files from trainer.torch import NoamLR from trainer.trainer_utils import get_optimizer @@ -18,7 +19,6 @@ from TTS.encoder.utils.visual import plot_embeddings from TTS.tts.datasets import load_tts_samples from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters, remove_experiment_folder -from TTS.utils.io import copy_model_files from TTS.utils.samplers import PerfectBatchSampler from TTS.utils.training import check_update @@ -276,7 +276,7 @@ def main(args): # pylint: disable=redefined-outer-name if c.loss == "softmaxproto" and c.model != "speaker_encoder": c.map_classid_to_classname = map_classid_to_classname - copy_model_files(c, OUT_PATH) + copy_model_files(c, OUT_PATH, new_fields={}) if args.restore_path: criterion, args.restore_step = model.load_checkpoint( diff --git a/TTS/encoder/utils/training.py b/TTS/encoder/utils/training.py index 7c58a232..ff8f271d 100644 --- a/TTS/encoder/utils/training.py +++ b/TTS/encoder/utils/training.py @@ -3,13 +3,13 @@ from dataclasses import dataclass, field from coqpit import Coqpit from trainer import TrainerArgs, get_last_checkpoint +from trainer.io import copy_model_files from trainer.logging import logger_factory from trainer.logging.console_logger import ConsoleLogger from TTS.config import load_config, register_config from TTS.tts.utils.text.characters import parse_symbols from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch -from TTS.utils.io import copy_model_files @dataclass diff --git a/TTS/utils/io.py b/TTS/utils/io.py index 7aaedbe2..3107ba66 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -1,12 +1,9 @@ -import json import os import pickle as pickle_tts -import shutil from typing import Any, Callable, Dict, Union import fsspec import torch -from coqpit import Coqpit from TTS.utils.generic_utils import get_user_data_dir @@ -27,34 +24,6 @@ class AttrDict(dict): self.__dict__ = self -def copy_model_files(config: Coqpit, out_path, new_fields=None): - """Copy config.json and other model files to training folder and add - new fields. - - Args: - config (Coqpit): Coqpit config defining the training run. - out_path (str): output path to copy the file. - new_fields (dict): new fileds to be added or edited - in the config file. - """ - copy_config_path = os.path.join(out_path, "config.json") - # add extra information fields - if new_fields: - config.update(new_fields, allow_new=True) - # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths. - with fsspec.open(copy_config_path, "w", encoding="utf8") as f: - json.dump(config.to_dict(), f, indent=4) - - # copy model stats file if available - if config.audio.stats_path is not None: - copy_stats_path = os.path.join(out_path, "scale_stats.npy") - filesystem = fsspec.get_mapper(copy_stats_path).fs - if not filesystem.exists(copy_stats_path): - with fsspec.open(config.audio.stats_path, "rb") as source_file: - with fsspec.open(copy_stats_path, "wb") as target_file: - shutil.copyfileobj(source_file, target_file) - - def load_fsspec( path: str, map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,