refactor: use get_user_data_dir from trainer

This commit is contained in:
Enno Hermann 2024-06-27 10:46:15 +02:00
parent 28296c6458
commit 0fb26f97df
8 changed files with 10 additions and 32 deletions

View File

@ -60,6 +60,7 @@ class BaseTrainerModel(TrainerModel):
checkpoint_path (str | os.PathLike): Path to the model checkpoint file. checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
eval (bool, optional): If true, init model for inference else for training. Defaults to False. eval (bool, optional): If true, init model for inference else for training. Defaults to False.
strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False. cache (bool, optional): If True, cache the file locally for subsequent calls.
It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
""" """
... ...

View File

@ -2,11 +2,12 @@ import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict from typing import Dict
from trainer.io import get_user_data_dir
from TTS.tts.configs.shared_configs import BaseTTSConfig from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.layers.bark.model import GPTConfig from TTS.tts.layers.bark.model import GPTConfig
from TTS.tts.layers.bark.model_fine import FineGPTConfig from TTS.tts.layers.bark.model_fine import FineGPTConfig
from TTS.tts.models.bark import BarkAudioConfig from TTS.tts.models.bark import BarkAudioConfig
from TTS.utils.generic_utils import get_user_data_dir
@dataclass @dataclass

View File

@ -103,7 +103,8 @@ class BaseTacotron(BaseTTS):
config (Coqpi): model configuration. config (Coqpi): model configuration.
checkpoint_path (str): path to checkpoint file. checkpoint_path (str): path to checkpoint file.
eval (bool, optional): whether to load model for evaluation. eval (bool, optional): whether to load model for evaluation.
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False. cache (bool, optional): If True, cache the file locally for subsequent calls.
It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
""" """
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])

View File

@ -2,9 +2,7 @@
import datetime import datetime
import importlib import importlib
import logging import logging
import os
import re import re
import sys
from pathlib import Path from pathlib import Path
from typing import Dict, Optional from typing import Dict, Optional
@ -53,28 +51,6 @@ def get_import_path(obj: object) -> str:
return ".".join([type(obj).__module__, type(obj).__name__]) return ".".join([type(obj).__module__, type(obj).__name__])
def get_user_data_dir(appname):
TTS_HOME = os.environ.get("TTS_HOME")
XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME")
if TTS_HOME is not None:
ans = Path(TTS_HOME).expanduser().resolve(strict=False)
elif XDG_DATA_HOME is not None:
ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False)
elif sys.platform == "win32":
import winreg # pylint: disable=import-outside-toplevel
key = winreg.OpenKey(
winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
)
dir_, _ = winreg.QueryValueEx(key, "Local AppData")
ans = Path(dir_).resolve(strict=False)
elif sys.platform == "darwin":
ans = Path("~/Library/Application Support/").expanduser()
else:
ans = Path.home().joinpath(".local/share")
return ans.joinpath(appname)
def set_init_dict(model_dict, checkpoint_state, c): def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped. # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items(): for k, v in checkpoint_state.items():

View File

@ -4,8 +4,7 @@ from typing import Any, Callable, Dict, Union
import fsspec import fsspec
import torch import torch
from trainer.io import get_user_data_dir
from TTS.utils.generic_utils import get_user_data_dir
class RenamingUnpickler(pickle_tts.Unpickler): class RenamingUnpickler(pickle_tts.Unpickler):

View File

@ -11,9 +11,9 @@ from typing import Dict, Tuple
import fsspec import fsspec
import requests import requests
from tqdm import tqdm from tqdm import tqdm
from trainer.io import get_user_data_dir
from TTS.config import load_config, read_json_with_comments from TTS.config import load_config, read_json_with_comments
from TTS.utils.generic_utils import get_user_data_dir
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -3,8 +3,8 @@ import os
import urllib.request import urllib.request
import torch import torch
from trainer.io import get_user_data_dir
from TTS.utils.generic_utils import get_user_data_dir
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -4,11 +4,11 @@ import os
import shutil import shutil
import torch import torch
from trainer.io import get_user_data_dir
from tests import get_tests_data_path, get_tests_output_path, run_cli from tests import get_tests_data_path, get_tests_output_path, run_cli
from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.generic_utils import get_user_data_dir
from TTS.utils.manage import ModelManager from TTS.utils.manage import ModelManager
MODELS_WITH_SEP_TESTS = [ MODELS_WITH_SEP_TESTS = [