diff --git a/TTS/encoder/utils/training.py b/TTS/encoder/utils/training.py index f9088003..cc3a78b0 100644 --- a/TTS/encoder/utils/training.py +++ b/TTS/encoder/utils/training.py @@ -29,7 +29,7 @@ def process_args(args, config=None): args (argparse.Namespace or dict like): Parsed input arguments. config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None. Returns: - c (TTS.utils.io.AttrDict): Config paramaters. + c (Coqpit): Config paramaters. out_path (str): Path to save models and logging. audio_path (str): Path to save generated test audios. c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does diff --git a/TTS/utils/io.py b/TTS/utils/io.py deleted file mode 100644 index a837f186..00000000 --- a/TTS/utils/io.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -import pickle as pickle_tts -from typing import Any, Callable, Dict, Union - -import fsspec -import torch -from trainer.io import get_user_data_dir - - -class RenamingUnpickler(pickle_tts.Unpickler): - """Overload default pickler to solve module renaming problem""" - - def find_class(self, module, name): - return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name) - - -class AttrDict(dict): - """A custom dict which converts dict keys - to class attributes""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.__dict__ = self - - -def load_fsspec( - path: str, - map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, - cache: bool = True, - **kwargs, -) -> Any: - """Like torch.load but can load from other locations (e.g. s3:// , gs://). - - Args: - path: Any path or url supported by fsspec. - map_location: torch.device or str. - cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True. - **kwargs: Keyword arguments forwarded to torch.load. - - Returns: - Object stored in path. - """ - is_local = os.path.isdir(path) or os.path.isfile(path) - if cache and not is_local: - with fsspec.open( - f"filecache::{path}", - filecache={"cache_storage": str(get_user_data_dir("tts_cache"))}, - mode="rb", - ) as f: - return torch.load(f, map_location=map_location, **kwargs) - else: - with fsspec.open(path, "rb") as f: - return torch.load(f, map_location=map_location, **kwargs) - - -def load_checkpoint( - model, checkpoint_path, use_cuda=False, eval=False, cache=False -): # pylint: disable=redefined-builtin - try: - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) - except ModuleNotFoundError: - pickle_tts.Unpickler = RenamingUnpickler - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache) - model.load_state_dict(state["model"]) - if use_cuda: - model.cuda() - if eval: - model.eval() - return model, state diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index 1f977755..8d4dd725 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -221,7 +221,7 @@ class GeneratorLoss(nn.Module): changing configurations. Args: - C (AttrDict): model configuration. + C (Coqpit): model configuration. """ def __init__(self, C):