mirror of https://github.com/coqui-ai/TTS.git
refactor: use get_user_data_dir from trainer
This commit is contained in:
parent
28296c6458
commit
0fb26f97df
|
@ -60,6 +60,7 @@ class BaseTrainerModel(TrainerModel):
|
||||||
checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
|
checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
|
||||||
eval (bool, optional): If true, init model for inference else for training. Defaults to False.
|
eval (bool, optional): If true, init model for inference else for training. Defaults to False.
|
||||||
strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
|
strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
|
||||||
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
|
cache (bool, optional): If True, cache the file locally for subsequent calls.
|
||||||
|
It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
|
@ -2,11 +2,12 @@ import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
from trainer.io import get_user_data_dir
|
||||||
|
|
||||||
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||||
from TTS.tts.layers.bark.model import GPTConfig
|
from TTS.tts.layers.bark.model import GPTConfig
|
||||||
from TTS.tts.layers.bark.model_fine import FineGPTConfig
|
from TTS.tts.layers.bark.model_fine import FineGPTConfig
|
||||||
from TTS.tts.models.bark import BarkAudioConfig
|
from TTS.tts.models.bark import BarkAudioConfig
|
||||||
from TTS.utils.generic_utils import get_user_data_dir
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -103,7 +103,8 @@ class BaseTacotron(BaseTTS):
|
||||||
config (Coqpi): model configuration.
|
config (Coqpi): model configuration.
|
||||||
checkpoint_path (str): path to checkpoint file.
|
checkpoint_path (str): path to checkpoint file.
|
||||||
eval (bool, optional): whether to load model for evaluation.
|
eval (bool, optional): whether to load model for evaluation.
|
||||||
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
|
cache (bool, optional): If True, cache the file locally for subsequent calls.
|
||||||
|
It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
|
||||||
"""
|
"""
|
||||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||||
self.load_state_dict(state["model"])
|
self.load_state_dict(state["model"])
|
||||||
|
|
|
@ -2,9 +2,7 @@
|
||||||
import datetime
|
import datetime
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
@ -53,28 +51,6 @@ def get_import_path(obj: object) -> str:
|
||||||
return ".".join([type(obj).__module__, type(obj).__name__])
|
return ".".join([type(obj).__module__, type(obj).__name__])
|
||||||
|
|
||||||
|
|
||||||
def get_user_data_dir(appname):
|
|
||||||
TTS_HOME = os.environ.get("TTS_HOME")
|
|
||||||
XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME")
|
|
||||||
if TTS_HOME is not None:
|
|
||||||
ans = Path(TTS_HOME).expanduser().resolve(strict=False)
|
|
||||||
elif XDG_DATA_HOME is not None:
|
|
||||||
ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False)
|
|
||||||
elif sys.platform == "win32":
|
|
||||||
import winreg # pylint: disable=import-outside-toplevel
|
|
||||||
|
|
||||||
key = winreg.OpenKey(
|
|
||||||
winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
|
|
||||||
)
|
|
||||||
dir_, _ = winreg.QueryValueEx(key, "Local AppData")
|
|
||||||
ans = Path(dir_).resolve(strict=False)
|
|
||||||
elif sys.platform == "darwin":
|
|
||||||
ans = Path("~/Library/Application Support/").expanduser()
|
|
||||||
else:
|
|
||||||
ans = Path.home().joinpath(".local/share")
|
|
||||||
return ans.joinpath(appname)
|
|
||||||
|
|
||||||
|
|
||||||
def set_init_dict(model_dict, checkpoint_state, c):
|
def set_init_dict(model_dict, checkpoint_state, c):
|
||||||
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
||||||
for k, v in checkpoint_state.items():
|
for k, v in checkpoint_state.items():
|
||||||
|
|
|
@ -4,8 +4,7 @@ from typing import Any, Callable, Dict, Union
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import torch
|
import torch
|
||||||
|
from trainer.io import get_user_data_dir
|
||||||
from TTS.utils.generic_utils import get_user_data_dir
|
|
||||||
|
|
||||||
|
|
||||||
class RenamingUnpickler(pickle_tts.Unpickler):
|
class RenamingUnpickler(pickle_tts.Unpickler):
|
||||||
|
|
|
@ -11,9 +11,9 @@ from typing import Dict, Tuple
|
||||||
import fsspec
|
import fsspec
|
||||||
import requests
|
import requests
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from trainer.io import get_user_data_dir
|
||||||
|
|
||||||
from TTS.config import load_config, read_json_with_comments
|
from TTS.config import load_config, read_json_with_comments
|
||||||
from TTS.utils.generic_utils import get_user_data_dir
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,8 @@ import os
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from trainer.io import get_user_data_dir
|
||||||
|
|
||||||
from TTS.utils.generic_utils import get_user_data_dir
|
|
||||||
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
|
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
|
@ -4,11 +4,11 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from trainer.io import get_user_data_dir
|
||||||
|
|
||||||
from tests import get_tests_data_path, get_tests_output_path, run_cli
|
from tests import get_tests_data_path, get_tests_output_path, run_cli
|
||||||
from TTS.tts.utils.languages import LanguageManager
|
from TTS.tts.utils.languages import LanguageManager
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.utils.generic_utils import get_user_data_dir
|
|
||||||
from TTS.utils.manage import ModelManager
|
from TTS.utils.manage import ModelManager
|
||||||
|
|
||||||
MODELS_WITH_SEP_TESTS = [
|
MODELS_WITH_SEP_TESTS = [
|
||||||
|
|
Loading…
Reference in New Issue