mirror of https://github.com/coqui-ai/TTS.git
platform indep. way to fetch user data folder
This commit is contained in:
parent
0117c811a9
commit
4f32e77006
|
@ -3,6 +3,8 @@ import glob
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -67,6 +69,22 @@ def count_parameters(model):
|
||||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_data_dir(appname):
|
||||||
|
if sys.platform == "win32":
|
||||||
|
import winreg # pylint: disable=import-outside-toplevel
|
||||||
|
key = winreg.OpenKey(
|
||||||
|
winreg.HKEY_CURRENT_USER,
|
||||||
|
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
|
||||||
|
)
|
||||||
|
dir_, _ = winreg.QueryValueEx(key, "Local AppData")
|
||||||
|
ans = Path(dir_).resolve(strict=False)
|
||||||
|
elif sys.platform == 'darwin':
|
||||||
|
ans = Path('~/Library/Application Support/').expanduser()
|
||||||
|
else:
|
||||||
|
ans = Path.home().joinpath('.local/share')
|
||||||
|
return ans.joinpath(appname)
|
||||||
|
|
||||||
|
|
||||||
def set_init_dict(model_dict, checkpoint_state, c):
|
def set_init_dict(model_dict, checkpoint_state, c):
|
||||||
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
||||||
for k, v in checkpoint_state.items():
|
for k, v in checkpoint_state.items():
|
||||||
|
@ -97,6 +115,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
|
||||||
len(model_dict)))
|
len(model_dict)))
|
||||||
return model_dict
|
return model_dict
|
||||||
|
|
||||||
|
|
||||||
class KeepAverage():
|
class KeepAverage():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.avg_values = {}
|
self.avg_values = {}
|
||||||
|
|
|
@ -4,7 +4,7 @@ from pathlib import Path
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from TTS.utils.io import load_config
|
from TTS.utils.io import load_config
|
||||||
|
from TTS.utils.generic_utils import get_user_data_dir
|
||||||
|
|
||||||
class ModelManager(object):
|
class ModelManager(object):
|
||||||
"""Manage TTS models defined in .models.json.
|
"""Manage TTS models defined in .models.json.
|
||||||
|
@ -19,7 +19,7 @@ class ModelManager(object):
|
||||||
"""
|
"""
|
||||||
def __init__(self, models_file):
|
def __init__(self, models_file):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.output_prefix = os.path.join(str(Path.home()), '.tts')
|
self.output_prefix = get_user_data_dir('tts')
|
||||||
self.url_prefix = "https://drive.google.com/uc?id="
|
self.url_prefix = "https://drive.google.com/uc?id="
|
||||||
self.models_dict = None
|
self.models_dict = None
|
||||||
self.read_models_file(models_file)
|
self.read_models_file(models_file)
|
||||||
|
|
Loading…
Reference in New Issue