mirror of https://github.com/coqui-ai/TTS.git
refactor: use copy_model_files() from Trainer
This commit is contained in:
parent
5119e651a1
commit
96678c7ba2
|
@ -8,6 +8,7 @@ import traceback
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from trainer.io import copy_model_files
|
||||||
from trainer.torch import NoamLR
|
from trainer.torch import NoamLR
|
||||||
from trainer.trainer_utils import get_optimizer
|
from trainer.trainer_utils import get_optimizer
|
||||||
|
|
||||||
|
@ -18,7 +19,6 @@ from TTS.encoder.utils.visual import plot_embeddings
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
|
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
|
||||||
from TTS.utils.io import copy_model_files
|
|
||||||
from TTS.utils.samplers import PerfectBatchSampler
|
from TTS.utils.samplers import PerfectBatchSampler
|
||||||
from TTS.utils.training import check_update
|
from TTS.utils.training import check_update
|
||||||
|
|
||||||
|
@ -276,7 +276,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
if c.loss == "softmaxproto" and c.model != "speaker_encoder":
|
if c.loss == "softmaxproto" and c.model != "speaker_encoder":
|
||||||
c.map_classid_to_classname = map_classid_to_classname
|
c.map_classid_to_classname = map_classid_to_classname
|
||||||
copy_model_files(c, OUT_PATH)
|
copy_model_files(c, OUT_PATH, new_fields={})
|
||||||
|
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
criterion, args.restore_step = model.load_checkpoint(
|
criterion, args.restore_step = model.load_checkpoint(
|
||||||
|
|
|
@ -3,13 +3,13 @@ from dataclasses import dataclass, field
|
||||||
|
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from trainer import TrainerArgs, get_last_checkpoint
|
from trainer import TrainerArgs, get_last_checkpoint
|
||||||
|
from trainer.io import copy_model_files
|
||||||
from trainer.logging import logger_factory
|
from trainer.logging import logger_factory
|
||||||
from trainer.logging.console_logger import ConsoleLogger
|
from trainer.logging.console_logger import ConsoleLogger
|
||||||
|
|
||||||
from TTS.config import load_config, register_config
|
from TTS.config import load_config, register_config
|
||||||
from TTS.tts.utils.text.characters import parse_symbols
|
from TTS.tts.utils.text.characters import parse_symbols
|
||||||
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
||||||
from TTS.utils.io import copy_model_files
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -1,12 +1,9 @@
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import pickle as pickle_tts
|
import pickle as pickle_tts
|
||||||
import shutil
|
|
||||||
from typing import Any, Callable, Dict, Union
|
from typing import Any, Callable, Dict, Union
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
|
||||||
|
|
||||||
from TTS.utils.generic_utils import get_user_data_dir
|
from TTS.utils.generic_utils import get_user_data_dir
|
||||||
|
|
||||||
|
@ -27,34 +24,6 @@ class AttrDict(dict):
|
||||||
self.__dict__ = self
|
self.__dict__ = self
|
||||||
|
|
||||||
|
|
||||||
def copy_model_files(config: Coqpit, out_path, new_fields=None):
|
|
||||||
"""Copy config.json and other model files to training folder and add
|
|
||||||
new fields.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (Coqpit): Coqpit config defining the training run.
|
|
||||||
out_path (str): output path to copy the file.
|
|
||||||
new_fields (dict): new fileds to be added or edited
|
|
||||||
in the config file.
|
|
||||||
"""
|
|
||||||
copy_config_path = os.path.join(out_path, "config.json")
|
|
||||||
# add extra information fields
|
|
||||||
if new_fields:
|
|
||||||
config.update(new_fields, allow_new=True)
|
|
||||||
# TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
|
|
||||||
with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
|
|
||||||
json.dump(config.to_dict(), f, indent=4)
|
|
||||||
|
|
||||||
# copy model stats file if available
|
|
||||||
if config.audio.stats_path is not None:
|
|
||||||
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
|
|
||||||
filesystem = fsspec.get_mapper(copy_stats_path).fs
|
|
||||||
if not filesystem.exists(copy_stats_path):
|
|
||||||
with fsspec.open(config.audio.stats_path, "rb") as source_file:
|
|
||||||
with fsspec.open(copy_stats_path, "wb") as target_file:
|
|
||||||
shutil.copyfileobj(source_file, target_file)
|
|
||||||
|
|
||||||
|
|
||||||
def load_fsspec(
|
def load_fsspec(
|
||||||
path: str,
|
path: str,
|
||||||
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
|
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
|
||||||
|
|
Loading…
Reference in New Issue