mirror of https://github.com/coqui-ai/TTS.git
chore: remove unused TTS.utils.io module
All uses of these methods were replaced with the equivalents from coqui-tts-trainer
This commit is contained in:
parent
e869b9b658
commit
2d06aeb79b
|
@ -29,7 +29,7 @@ def process_args(args, config=None):
|
|||
args (argparse.Namespace or dict like): Parsed input arguments.
|
||||
config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
|
||||
Returns:
|
||||
c (TTS.utils.io.AttrDict): Config paramaters.
|
||||
c (Coqpit): Config paramaters.
|
||||
out_path (str): Path to save models and logging.
|
||||
audio_path (str): Path to save generated test audios.
|
||||
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
|
||||
|
|
|
@ -1,69 +0,0 @@
|
|||
import os
|
||||
import pickle as pickle_tts
|
||||
from typing import Any, Callable, Dict, Union
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
from trainer.io import get_user_data_dir
|
||||
|
||||
|
||||
class RenamingUnpickler(pickle_tts.Unpickler):
|
||||
"""Overload default pickler to solve module renaming problem"""
|
||||
|
||||
def find_class(self, module, name):
|
||||
return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name)
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
"""A custom dict which converts dict keys
|
||||
to class attributes"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def load_fsspec(
|
||||
path: str,
|
||||
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
|
||||
cache: bool = True,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Like torch.load but can load from other locations (e.g. s3:// , gs://).
|
||||
|
||||
Args:
|
||||
path: Any path or url supported by fsspec.
|
||||
map_location: torch.device or str.
|
||||
cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True.
|
||||
**kwargs: Keyword arguments forwarded to torch.load.
|
||||
|
||||
Returns:
|
||||
Object stored in path.
|
||||
"""
|
||||
is_local = os.path.isdir(path) or os.path.isfile(path)
|
||||
if cache and not is_local:
|
||||
with fsspec.open(
|
||||
f"filecache::{path}",
|
||||
filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
|
||||
mode="rb",
|
||||
) as f:
|
||||
return torch.load(f, map_location=map_location, **kwargs)
|
||||
else:
|
||||
with fsspec.open(path, "rb") as f:
|
||||
return torch.load(f, map_location=map_location, **kwargs)
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
model, checkpoint_path, use_cuda=False, eval=False, cache=False
|
||||
): # pylint: disable=redefined-builtin
|
||||
try:
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
except ModuleNotFoundError:
|
||||
pickle_tts.Unpickler = RenamingUnpickler
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache)
|
||||
model.load_state_dict(state["model"])
|
||||
if use_cuda:
|
||||
model.cuda()
|
||||
if eval:
|
||||
model.eval()
|
||||
return model, state
|
|
@ -221,7 +221,7 @@ class GeneratorLoss(nn.Module):
|
|||
changing configurations.
|
||||
|
||||
Args:
|
||||
C (AttrDict): model configuration.
|
||||
C (Coqpit): model configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, C):
|
||||
|
|
Loading…
Reference in New Issue