chore: remove unused TTS.utils.io module

All uses of its helpers were replaced with the equivalents from coqui-tts-trainer.
Enno Hermann 2024-06-27 11:15:57 +02:00
parent e869b9b658
commit 2d06aeb79b
3 changed files with 2 additions and 71 deletions
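
A minimal sketch of the migration for downstream code, assuming the equivalent
helpers live in coqui-tts-trainer's `trainer.io` module (the removed file itself
already imported `get_user_data_dir` from there); exact names should be verified
against the installed trainer version:

# Before this commit, checkpoint helpers came from the removed module:
# from TTS.utils.io import load_fsspec

# After this commit, the assumed equivalent is imported from the trainer package:
from trainer.io import load_fsspec

# Load a local or fsspec-addressable checkpoint onto the CPU.
state = load_fsspec("path/to/checkpoint.pth", map_location="cpu")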


@@ -29,7 +29,7 @@ def process_args(args, config=None):
         args (argparse.Namespace or dict like): Parsed input arguments.
         config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
     Returns:
-        c (TTS.utils.io.AttrDict): Config paramaters.
+        c (Coqpit): Config paramaters.
         out_path (str): Path to save models and logging.
         audio_path (str): Path to save generated test audios.
         c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does


@@ -1,69 +0,0 @@
-import os
-import pickle as pickle_tts
-from typing import Any, Callable, Dict, Union
-
-import fsspec
-import torch
-from trainer.io import get_user_data_dir
-
-
-class RenamingUnpickler(pickle_tts.Unpickler):
-    """Overload default pickler to solve module renaming problem"""
-
-    def find_class(self, module, name):
-        return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name)
-
-
-class AttrDict(dict):
-    """A custom dict which converts dict keys
-    to class attributes"""
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
-def load_fsspec(
-    path: str,
-    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
-    cache: bool = True,
-    **kwargs,
-) -> Any:
-    """Like torch.load but can load from other locations (e.g. s3:// , gs://).
-
-    Args:
-        path: Any path or url supported by fsspec.
-        map_location: torch.device or str.
-        cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True.
-        **kwargs: Keyword arguments forwarded to torch.load.
-
-    Returns:
-        Object stored in path.
-    """
-    is_local = os.path.isdir(path) or os.path.isfile(path)
-    if cache and not is_local:
-        with fsspec.open(
-            f"filecache::{path}",
-            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
-            mode="rb",
-        ) as f:
-            return torch.load(f, map_location=map_location, **kwargs)
-    else:
-        with fsspec.open(path, "rb") as f:
-            return torch.load(f, map_location=map_location, **kwargs)
-
-
-def load_checkpoint(
-    model, checkpoint_path, use_cuda=False, eval=False, cache=False
-):  # pylint: disable=redefined-builtin
-    try:
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
-    except ModuleNotFoundError:
-        pickle_tts.Unpickler = RenamingUnpickler
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache)
-    model.load_state_dict(state["model"])
-    if use_cuda:
-        model.cuda()
-    if eval:
-        model.eval()
-    return model, state
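
For context, the removed `load_fsspec` was a thin wrapper around fsspec's chained
`filecache::` protocol: a remote file is downloaded once into a local cache
directory and reused on later calls. A minimal standalone sketch of the same
pattern (the URL and cache directory below are purely illustrative):

import fsspec
import torch

# The first call downloads the remote file into cache_storage; subsequent calls
# read the cached local copy instead of fetching it again.
with fsspec.open(
    "filecache::https://example.com/checkpoint.pth",
    filecache={"cache_storage": "/tmp/tts_cache"},
    mode="rb",
) as f:
    state = torch.load(f, map_location="cpu")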


@@ -221,7 +221,7 @@ class GeneratorLoss(nn.Module):
     changing configurations.

     Args:
-        C (AttrDict): model configuration.
+        C (Coqpit): model configuration.
     """

     def __init__(self, C):