mirror of https://github.com/coqui-ai/TTS.git
feat: allow both Path and strings where possible and add type hints
This commit is contained in:
parent
cd52907351
commit
a425ba599d
|
@ -157,7 +157,7 @@ class TTS(nn.Module):
|
||||||
|
|
||||||
def download_model_by_name(
|
def download_model_by_name(
|
||||||
self, model_name: str, vocoder_name: Optional[str] = None
|
self, model_name: str, vocoder_name: Optional[str] = None
|
||||||
) -> tuple[Optional[str], Optional[str], Optional[str]]:
|
) -> tuple[Optional[Path], Optional[Path], Optional[Path]]:
|
||||||
model_path, config_path, model_item = self.manager.download_model(model_name)
|
model_path, config_path, model_item = self.manager.download_model(model_name)
|
||||||
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
|
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
|
||||||
# return model directory if there are multiple files
|
# return model directory if there are multiple files
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Dict
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -68,7 +68,7 @@ def _process_model_name(config_dict: Dict) -> str:
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
|
|
||||||
def load_config(config_path: str) -> Coqpit:
|
def load_config(config_path: Union[str, os.PathLike[Any]]) -> Coqpit:
|
||||||
"""Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
|
"""Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
|
||||||
to find the corresponding Config class. Then initialize the Config.
|
to find the corresponding Config class. Then initialize the Config.
|
||||||
|
|
||||||
|
@ -81,6 +81,7 @@ def load_config(config_path: str) -> Coqpit:
|
||||||
Returns:
|
Returns:
|
||||||
Coqpit: TTS config object.
|
Coqpit: TTS config object.
|
||||||
"""
|
"""
|
||||||
|
config_path = str(config_path)
|
||||||
config_dict = {}
|
config_dict = {}
|
||||||
ext = os.path.splitext(config_path)[1]
|
ext = os.path.splitext(config_path)[1]
|
||||||
if ext in (".yml", ".yaml"):
|
if ext in (".yml", ".yaml"):
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -27,8 +27,8 @@ class LanguageManager(BaseIDManager):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
language_ids_file_path: str = "",
|
language_ids_file_path: Union[str, os.PathLike[Any]] = "",
|
||||||
config: Coqpit = None,
|
config: Optional[Coqpit] = None,
|
||||||
):
|
):
|
||||||
super().__init__(id_file_path=language_ids_file_path)
|
super().__init__(id_file_path=language_ids_file_path)
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ class LanguageManager(BaseIDManager):
|
||||||
def set_ids_from_data(self, items: List, parse_key: str) -> Any:
|
def set_ids_from_data(self, items: List, parse_key: str) -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def save_ids_to_file(self, file_path: str) -> None:
|
def save_ids_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
|
||||||
"""Save language IDs to a json file.
|
"""Save language IDs to a json file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
@ -12,7 +13,8 @@ from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
|
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
|
||||||
|
|
||||||
|
|
||||||
def load_file(path: str):
|
def load_file(path: Union[str, os.PathLike[Any]]):
|
||||||
|
path = str(path)
|
||||||
if path.endswith(".json"):
|
if path.endswith(".json"):
|
||||||
with fsspec.open(path, "r") as f:
|
with fsspec.open(path, "r") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
@ -23,7 +25,8 @@ def load_file(path: str):
|
||||||
raise ValueError("Unsupported file type")
|
raise ValueError("Unsupported file type")
|
||||||
|
|
||||||
|
|
||||||
def save_file(obj: Any, path: str):
|
def save_file(obj: Any, path: Union[str, os.PathLike[Any]]):
|
||||||
|
path = str(path)
|
||||||
if path.endswith(".json"):
|
if path.endswith(".json"):
|
||||||
with fsspec.open(path, "w") as f:
|
with fsspec.open(path, "w") as f:
|
||||||
json.dump(obj, f, indent=4)
|
json.dump(obj, f, indent=4)
|
||||||
|
@ -39,20 +42,20 @@ class BaseIDManager:
|
||||||
It defines common `ID` manager specific functions.
|
It defines common `ID` manager specific functions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, id_file_path: str = ""):
|
def __init__(self, id_file_path: Union[str, os.PathLike[Any]] = ""):
|
||||||
self.name_to_id = {}
|
self.name_to_id = {}
|
||||||
|
|
||||||
if id_file_path:
|
if id_file_path:
|
||||||
self.load_ids_from_file(id_file_path)
|
self.load_ids_from_file(id_file_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_json(json_file_path: str) -> Dict:
|
def _load_json(json_file_path: Union[str, os.PathLike[Any]]) -> Dict:
|
||||||
with fsspec.open(json_file_path, "r") as f:
|
with fsspec.open(str(json_file_path), "r") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _save_json(json_file_path: str, data: dict) -> None:
|
def _save_json(json_file_path: Union[str, os.PathLike[Any]], data: dict) -> None:
|
||||||
with fsspec.open(json_file_path, "w") as f:
|
with fsspec.open(str(json_file_path), "w") as f:
|
||||||
json.dump(data, f, indent=4)
|
json.dump(data, f, indent=4)
|
||||||
|
|
||||||
def set_ids_from_data(self, items: List, parse_key: str) -> None:
|
def set_ids_from_data(self, items: List, parse_key: str) -> None:
|
||||||
|
@ -63,7 +66,7 @@ class BaseIDManager:
|
||||||
"""
|
"""
|
||||||
self.name_to_id = self.parse_ids_from_data(items, parse_key=parse_key)
|
self.name_to_id = self.parse_ids_from_data(items, parse_key=parse_key)
|
||||||
|
|
||||||
def load_ids_from_file(self, file_path: str) -> None:
|
def load_ids_from_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
|
||||||
"""Set IDs from a file.
|
"""Set IDs from a file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -71,7 +74,7 @@ class BaseIDManager:
|
||||||
"""
|
"""
|
||||||
self.name_to_id = load_file(file_path)
|
self.name_to_id = load_file(file_path)
|
||||||
|
|
||||||
def save_ids_to_file(self, file_path: str) -> None:
|
def save_ids_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
|
||||||
"""Save IDs to a json file.
|
"""Save IDs to a json file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -130,10 +133,10 @@ class EmbeddingManager(BaseIDManager):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedding_file_path: Union[str, List[str]] = "",
|
embedding_file_path: Union[Union[str, os.PathLike[Any]], list[Union[str, os.PathLike[Any]]]] = "",
|
||||||
id_file_path: str = "",
|
id_file_path: Union[str, os.PathLike[Any]] = "",
|
||||||
encoder_model_path: str = "",
|
encoder_model_path: Union[str, os.PathLike[Any]] = "",
|
||||||
encoder_config_path: str = "",
|
encoder_config_path: Union[str, os.PathLike[Any]] = "",
|
||||||
use_cuda: bool = False,
|
use_cuda: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(id_file_path=id_file_path)
|
super().__init__(id_file_path=id_file_path)
|
||||||
|
@ -176,7 +179,7 @@ class EmbeddingManager(BaseIDManager):
|
||||||
"""Get embedding names."""
|
"""Get embedding names."""
|
||||||
return list(self.embeddings_by_names.keys())
|
return list(self.embeddings_by_names.keys())
|
||||||
|
|
||||||
def save_embeddings_to_file(self, file_path: str) -> None:
|
def save_embeddings_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
|
||||||
"""Save embeddings to a json file.
|
"""Save embeddings to a json file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -185,7 +188,7 @@ class EmbeddingManager(BaseIDManager):
|
||||||
save_file(self.embeddings, file_path)
|
save_file(self.embeddings, file_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def read_embeddings_from_file(file_path: str):
|
def read_embeddings_from_file(file_path: Union[str, os.PathLike[Any]]):
|
||||||
"""Load embeddings from a json file.
|
"""Load embeddings from a json file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -204,7 +207,7 @@ class EmbeddingManager(BaseIDManager):
|
||||||
embeddings_by_names[x["name"]].append(x["embedding"])
|
embeddings_by_names[x["name"]].append(x["embedding"])
|
||||||
return name_to_id, clip_ids, embeddings, embeddings_by_names
|
return name_to_id, clip_ids, embeddings, embeddings_by_names
|
||||||
|
|
||||||
def load_embeddings_from_file(self, file_path: str) -> None:
|
def load_embeddings_from_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
|
||||||
"""Load embeddings from a json file.
|
"""Load embeddings from a json file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -214,7 +217,7 @@ class EmbeddingManager(BaseIDManager):
|
||||||
file_path
|
file_path
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_embeddings_from_list_of_files(self, file_paths: List[str]) -> None:
|
def load_embeddings_from_list_of_files(self, file_paths: list[Union[str, os.PathLike[Any]]]) -> None:
|
||||||
"""Load embeddings from a list of json files and don't allow duplicate keys.
|
"""Load embeddings from a list of json files and don't allow duplicate keys.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -313,7 +316,9 @@ class EmbeddingManager(BaseIDManager):
|
||||||
def get_clips(self) -> List:
|
def get_clips(self) -> List:
|
||||||
return sorted(self.embeddings.keys())
|
return sorted(self.embeddings.keys())
|
||||||
|
|
||||||
def init_encoder(self, model_path: str, config_path: str, use_cuda=False) -> None:
|
def init_encoder(
|
||||||
|
self, model_path: Union[str, os.PathLike[Any]], config_path: Union[str, os.PathLike[Any]], use_cuda=False
|
||||||
|
) -> None:
|
||||||
"""Initialize a speaker encoder model.
|
"""Initialize a speaker encoder model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -325,11 +330,13 @@ class EmbeddingManager(BaseIDManager):
|
||||||
self.encoder_config = load_config(config_path)
|
self.encoder_config = load_config(config_path)
|
||||||
self.encoder = setup_encoder_model(self.encoder_config)
|
self.encoder = setup_encoder_model(self.encoder_config)
|
||||||
self.encoder_criterion = self.encoder.load_checkpoint(
|
self.encoder_criterion = self.encoder.load_checkpoint(
|
||||||
self.encoder_config, model_path, eval=True, use_cuda=use_cuda, cache=True
|
self.encoder_config, str(model_path), eval=True, use_cuda=use_cuda, cache=True
|
||||||
)
|
)
|
||||||
self.encoder_ap = AudioProcessor(**self.encoder_config.audio)
|
self.encoder_ap = AudioProcessor(**self.encoder_config.audio)
|
||||||
|
|
||||||
def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list:
|
def compute_embedding_from_clip(
|
||||||
|
self, wav_file: Union[Union[str, os.PathLike[Any]], List[Union[str, os.PathLike[Any]]]]
|
||||||
|
) -> list:
|
||||||
"""Compute a embedding from a given audio file.
|
"""Compute a embedding from a given audio file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -56,11 +56,11 @@ class SpeakerManager(EmbeddingManager):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data_items: List[List[Any]] = None,
|
data_items: Optional[list[list[Any]]] = None,
|
||||||
d_vectors_file_path: str = "",
|
d_vectors_file_path: str = "",
|
||||||
speaker_id_file_path: str = "",
|
speaker_id_file_path: Union[str, os.PathLike[Any]] = "",
|
||||||
encoder_model_path: str = "",
|
encoder_model_path: Union[str, os.PathLike[Any]] = "",
|
||||||
encoder_config_path: str = "",
|
encoder_config_path: Union[str, os.PathLike[Any]] = "",
|
||||||
use_cuda: bool = False,
|
use_cuda: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -406,7 +407,9 @@ def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.n
|
||||||
return rms_norm(wav=x, db_level=db_level)
|
return rms_norm(wav=x, db_level=db_level)
|
||||||
|
|
||||||
|
|
||||||
def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool = False, **kwargs) -> np.ndarray:
|
def load_wav(
|
||||||
|
*, filename: Union[str, os.PathLike[Any]], sample_rate: Optional[int] = None, resample: bool = False, **kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
|
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
|
||||||
|
|
||||||
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
|
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
|
||||||
|
@ -434,7 +437,7 @@ def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool
|
||||||
def save_wav(
|
def save_wav(
|
||||||
*,
|
*,
|
||||||
wav: np.ndarray,
|
wav: np.ndarray,
|
||||||
path: str,
|
path: Union[str, os.PathLike[Any]],
|
||||||
sample_rate: int,
|
sample_rate: int,
|
||||||
pipe_out=None,
|
pipe_out=None,
|
||||||
do_rms_norm: bool = False,
|
do_rms_norm: bool = False,
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
import os
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -548,7 +549,7 @@ class AudioProcessor:
|
||||||
return volume_norm(x=x)
|
return volume_norm(x=x)
|
||||||
|
|
||||||
### save and load ###
|
### save and load ###
|
||||||
def load_wav(self, filename: str, sr: Optional[int] = None) -> np.ndarray:
|
def load_wav(self, filename: Union[str, os.PathLike[Any]], sr: Optional[int] = None) -> np.ndarray:
|
||||||
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
|
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
|
||||||
|
|
||||||
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
|
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
|
||||||
|
@ -575,7 +576,9 @@ class AudioProcessor:
|
||||||
x = rms_volume_norm(x=x, db_level=self.db_level)
|
x = rms_volume_norm(x=x, db_level=self.db_level)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def save_wav(self, wav: np.ndarray, path: str, sr: Optional[int] = None, pipe_out=None) -> None:
|
def save_wav(
|
||||||
|
self, wav: np.ndarray, path: Union[str, os.PathLike[Any]], sr: Optional[int] = None, pipe_out=None
|
||||||
|
) -> None:
|
||||||
"""Save a waveform to a file using Scipy.
|
"""Save a waveform to a file using Scipy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -4,7 +4,7 @@ import importlib
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Optional, TypeVar, Union
|
from typing import Any, Callable, Dict, Optional, TypeVar, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
@ -133,3 +133,8 @@ def setup_logger(
|
||||||
def is_pytorch_at_least_2_4() -> bool:
|
def is_pytorch_at_least_2_4() -> bool:
|
||||||
"""Check if the installed Pytorch version is 2.4 or higher."""
|
"""Check if the installed Pytorch version is 2.4 or higher."""
|
||||||
return Version(torch.__version__) >= Version("2.4")
|
return Version(torch.__version__) >= Version("2.4")
|
||||||
|
|
||||||
|
|
||||||
|
def optional_to_str(x: Optional[Any]) -> str:
|
||||||
|
"""Convert input to string, using empty string if input is None."""
|
||||||
|
return "" if x is None else str(x)
|
||||||
|
|
|
@ -6,17 +6,35 @@ import tarfile
|
||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile, rmtree
|
from shutil import copyfile, rmtree
|
||||||
from typing import Dict, Tuple
|
from typing import Any, Optional, TypedDict, Union
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import requests
|
import requests
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from trainer.io import get_user_data_dir
|
from trainer.io import get_user_data_dir
|
||||||
|
from typing_extensions import Required
|
||||||
|
|
||||||
from TTS.config import load_config, read_json_with_comments
|
from TTS.config import load_config, read_json_with_comments
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelItem(TypedDict, total=False):
|
||||||
|
model_name: Required[str]
|
||||||
|
model_type: Required[str]
|
||||||
|
description: str
|
||||||
|
license: str
|
||||||
|
author: str
|
||||||
|
contact: str
|
||||||
|
commit: Optional[str]
|
||||||
|
model_hash: str
|
||||||
|
tos_required: bool
|
||||||
|
default_vocoder: Optional[str]
|
||||||
|
model_url: Union[str, list[str]]
|
||||||
|
github_rls_url: Union[str, list[str]]
|
||||||
|
hf_url: list[str]
|
||||||
|
|
||||||
|
|
||||||
LICENSE_URLS = {
|
LICENSE_URLS = {
|
||||||
"cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
|
"cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
|
||||||
"mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
|
"mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
|
||||||
|
@ -40,19 +58,24 @@ class ModelManager(object):
|
||||||
home path.
|
home path.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
models_file (str): path to .model.json file. Defaults to None.
|
models_file (str or Path): path to .model.json file. Defaults to None.
|
||||||
output_prefix (str): prefix to `tts` to download models. Defaults to None
|
output_prefix (str or Path): prefix to `tts` to download models. Defaults to None
|
||||||
progress_bar (bool): print a progress bar when donwloading a file. Defaults to False.
|
progress_bar (bool): print a progress bar when donwloading a file. Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, models_file=None, output_prefix=None, progress_bar=False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
models_file: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
|
output_prefix: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
|
progress_bar: bool = False,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.progress_bar = progress_bar
|
self.progress_bar = progress_bar
|
||||||
if output_prefix is None:
|
if output_prefix is None:
|
||||||
self.output_prefix = get_user_data_dir("tts")
|
self.output_prefix = get_user_data_dir("tts")
|
||||||
else:
|
else:
|
||||||
self.output_prefix = os.path.join(output_prefix, "tts")
|
self.output_prefix = Path(output_prefix) / "tts"
|
||||||
self.models_dict = None
|
self.models_dict = {}
|
||||||
if models_file is not None:
|
if models_file is not None:
|
||||||
self.read_models_file(models_file)
|
self.read_models_file(models_file)
|
||||||
else:
|
else:
|
||||||
|
@ -60,7 +83,7 @@ class ModelManager(object):
|
||||||
path = Path(__file__).parent / "../.models.json"
|
path = Path(__file__).parent / "../.models.json"
|
||||||
self.read_models_file(path)
|
self.read_models_file(path)
|
||||||
|
|
||||||
def read_models_file(self, file_path):
|
def read_models_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
|
||||||
"""Read .models.json as a dict
|
"""Read .models.json as a dict
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -68,7 +91,7 @@ class ModelManager(object):
|
||||||
"""
|
"""
|
||||||
self.models_dict = read_json_with_comments(file_path)
|
self.models_dict = read_json_with_comments(file_path)
|
||||||
|
|
||||||
def _list_models(self, model_type, model_count=0):
|
def _list_models(self, model_type: str, model_count: int = 0) -> list[str]:
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info("Name format: type/language/dataset/model")
|
logger.info("Name format: type/language/dataset/model")
|
||||||
model_list = []
|
model_list = []
|
||||||
|
@ -83,13 +106,13 @@ class ModelManager(object):
|
||||||
model_count += 1
|
model_count += 1
|
||||||
return model_list
|
return model_list
|
||||||
|
|
||||||
def _list_for_model_type(self, model_type):
|
def _list_for_model_type(self, model_type: str) -> list[str]:
|
||||||
models_name_list = []
|
models_name_list = []
|
||||||
model_count = 1
|
model_count = 1
|
||||||
models_name_list.extend(self._list_models(model_type, model_count))
|
models_name_list.extend(self._list_models(model_type, model_count))
|
||||||
return models_name_list
|
return models_name_list
|
||||||
|
|
||||||
def list_models(self):
|
def list_models(self) -> list[str]:
|
||||||
models_name_list = []
|
models_name_list = []
|
||||||
model_count = 1
|
model_count = 1
|
||||||
for model_type in self.models_dict:
|
for model_type in self.models_dict:
|
||||||
|
@ -97,7 +120,7 @@ class ModelManager(object):
|
||||||
models_name_list.extend(model_list)
|
models_name_list.extend(model_list)
|
||||||
return models_name_list
|
return models_name_list
|
||||||
|
|
||||||
def log_model_details(self, model_type, lang, dataset, model):
|
def log_model_details(self, model_type: str, lang: str, dataset: str, model: str) -> None:
|
||||||
logger.info("Model type: %s", model_type)
|
logger.info("Model type: %s", model_type)
|
||||||
logger.info("Language supported: %s", lang)
|
logger.info("Language supported: %s", lang)
|
||||||
logger.info("Dataset used: %s", dataset)
|
logger.info("Dataset used: %s", dataset)
|
||||||
|
@ -112,7 +135,7 @@ class ModelManager(object):
|
||||||
self.models_dict[model_type][lang][dataset][model]["default_vocoder"],
|
self.models_dict[model_type][lang][dataset][model]["default_vocoder"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def model_info_by_idx(self, model_query):
|
def model_info_by_idx(self, model_query: str) -> None:
|
||||||
"""Print the description of the model from .models.json file using model_query_idx
|
"""Print the description of the model from .models.json file using model_query_idx
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -144,7 +167,7 @@ class ModelManager(object):
|
||||||
model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/")
|
model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/")
|
||||||
self.log_model_details(model_type, lang, dataset, model)
|
self.log_model_details(model_type, lang, dataset, model)
|
||||||
|
|
||||||
def model_info_by_full_name(self, model_query_name):
|
def model_info_by_full_name(self, model_query_name: str) -> None:
|
||||||
"""Print the description of the model from .models.json file using model_full_name
|
"""Print the description of the model from .models.json file using model_full_name
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -165,35 +188,35 @@ class ModelManager(object):
|
||||||
return
|
return
|
||||||
self.log_model_details(model_type, lang, dataset, model)
|
self.log_model_details(model_type, lang, dataset, model)
|
||||||
|
|
||||||
def list_tts_models(self):
|
def list_tts_models(self) -> list[str]:
|
||||||
"""Print all `TTS` models and return a list of model names
|
"""Print all `TTS` models and return a list of model names
|
||||||
|
|
||||||
Format is `language/dataset/model`
|
Format is `language/dataset/model`
|
||||||
"""
|
"""
|
||||||
return self._list_for_model_type("tts_models")
|
return self._list_for_model_type("tts_models")
|
||||||
|
|
||||||
def list_vocoder_models(self):
|
def list_vocoder_models(self) -> list[str]:
|
||||||
"""Print all the `vocoder` models and return a list of model names
|
"""Print all the `vocoder` models and return a list of model names
|
||||||
|
|
||||||
Format is `language/dataset/model`
|
Format is `language/dataset/model`
|
||||||
"""
|
"""
|
||||||
return self._list_for_model_type("vocoder_models")
|
return self._list_for_model_type("vocoder_models")
|
||||||
|
|
||||||
def list_vc_models(self):
|
def list_vc_models(self) -> list[str]:
|
||||||
"""Print all the voice conversion models and return a list of model names
|
"""Print all the voice conversion models and return a list of model names
|
||||||
|
|
||||||
Format is `language/dataset/model`
|
Format is `language/dataset/model`
|
||||||
"""
|
"""
|
||||||
return self._list_for_model_type("voice_conversion_models")
|
return self._list_for_model_type("voice_conversion_models")
|
||||||
|
|
||||||
def list_langs(self):
|
def list_langs(self) -> None:
|
||||||
"""Print all the available languages"""
|
"""Print all the available languages"""
|
||||||
logger.info("Name format: type/language")
|
logger.info("Name format: type/language")
|
||||||
for model_type in self.models_dict:
|
for model_type in self.models_dict:
|
||||||
for lang in self.models_dict[model_type]:
|
for lang in self.models_dict[model_type]:
|
||||||
logger.info(" %s/%s", model_type, lang)
|
logger.info(" %s/%s", model_type, lang)
|
||||||
|
|
||||||
def list_datasets(self):
|
def list_datasets(self) -> None:
|
||||||
"""Print all the datasets"""
|
"""Print all the datasets"""
|
||||||
logger.info("Name format: type/language/dataset")
|
logger.info("Name format: type/language/dataset")
|
||||||
for model_type in self.models_dict:
|
for model_type in self.models_dict:
|
||||||
|
@ -202,7 +225,7 @@ class ModelManager(object):
|
||||||
logger.info(" %s/%s/%s", model_type, lang, dataset)
|
logger.info(" %s/%s/%s", model_type, lang, dataset)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def print_model_license(model_item: Dict):
|
def print_model_license(model_item: ModelItem) -> None:
|
||||||
"""Print the license of a model
|
"""Print the license of a model
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -217,27 +240,27 @@ class ModelManager(object):
|
||||||
else:
|
else:
|
||||||
logger.info("Model's license - No license information available")
|
logger.info("Model's license - No license information available")
|
||||||
|
|
||||||
def _download_github_model(self, model_item: Dict, output_path: str):
|
def _download_github_model(self, model_item: ModelItem, output_path: Path) -> None:
|
||||||
if isinstance(model_item["github_rls_url"], list):
|
if isinstance(model_item["github_rls_url"], list):
|
||||||
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
|
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||||
else:
|
else:
|
||||||
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
|
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||||
|
|
||||||
def _download_hf_model(self, model_item: Dict, output_path: str):
|
def _download_hf_model(self, model_item: ModelItem, output_path: Path) -> None:
|
||||||
if isinstance(model_item["hf_url"], list):
|
if isinstance(model_item["hf_url"], list):
|
||||||
self._download_model_files(model_item["hf_url"], output_path, self.progress_bar)
|
self._download_model_files(model_item["hf_url"], output_path, self.progress_bar)
|
||||||
else:
|
else:
|
||||||
self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar)
|
self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar)
|
||||||
|
|
||||||
def download_fairseq_model(self, model_name, output_path):
|
def download_fairseq_model(self, model_name: str, output_path: Path) -> None:
|
||||||
URI_PREFIX = "https://dl.fbaipublicfiles.com/mms/tts/"
|
URI_PREFIX = "https://dl.fbaipublicfiles.com/mms/tts/"
|
||||||
_, lang, _, _ = model_name.split("/")
|
_, lang, _, _ = model_name.split("/")
|
||||||
model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz")
|
model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz")
|
||||||
self._download_tar_file(model_download_uri, output_path, self.progress_bar)
|
self._download_tar_file(model_download_uri, output_path, self.progress_bar)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_model_url(model_item: Dict):
|
def set_model_url(model_item: ModelItem) -> ModelItem:
|
||||||
model_item["model_url"] = None
|
model_item["model_url"] = ""
|
||||||
if "github_rls_url" in model_item:
|
if "github_rls_url" in model_item:
|
||||||
model_item["model_url"] = model_item["github_rls_url"]
|
model_item["model_url"] = model_item["github_rls_url"]
|
||||||
elif "hf_url" in model_item:
|
elif "hf_url" in model_item:
|
||||||
|
@ -248,18 +271,18 @@ class ModelManager(object):
|
||||||
model_item["model_url"] = "https://huggingface.co/coqui/"
|
model_item["model_url"] = "https://huggingface.co/coqui/"
|
||||||
return model_item
|
return model_item
|
||||||
|
|
||||||
def _set_model_item(self, model_name):
|
def _set_model_item(self, model_name: str) -> tuple[ModelItem, str, str, Optional[str]]:
|
||||||
# fetch model info from the dict
|
# fetch model info from the dict
|
||||||
if "fairseq" in model_name:
|
if "fairseq" in model_name:
|
||||||
model_type, lang, dataset, model = model_name.split("/")
|
model_type, lang, dataset, model = model_name.split("/")
|
||||||
model_item = {
|
model_item: ModelItem = {
|
||||||
|
"model_name": model_name,
|
||||||
"model_type": "tts_models",
|
"model_type": "tts_models",
|
||||||
"license": "CC BY-NC 4.0",
|
"license": "CC BY-NC 4.0",
|
||||||
"default_vocoder": None,
|
"default_vocoder": None,
|
||||||
"author": "fairseq",
|
"author": "fairseq",
|
||||||
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
|
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
|
||||||
}
|
}
|
||||||
model_item["model_name"] = model_name
|
|
||||||
elif "xtts" in model_name and len(model_name.split("/")) != 4:
|
elif "xtts" in model_name and len(model_name.split("/")) != 4:
|
||||||
# loading xtts models with only model name (e.g. xtts_v2.0.2)
|
# loading xtts models with only model name (e.g. xtts_v2.0.2)
|
||||||
# check model name has the version number with regex
|
# check model name has the version number with regex
|
||||||
|
@ -273,6 +296,8 @@ class ModelManager(object):
|
||||||
dataset = "multi-dataset"
|
dataset = "multi-dataset"
|
||||||
model = model_name
|
model = model_name
|
||||||
model_item = {
|
model_item = {
|
||||||
|
"model_name": model_name,
|
||||||
|
"model_type": model_type,
|
||||||
"default_vocoder": None,
|
"default_vocoder": None,
|
||||||
"license": "CPML",
|
"license": "CPML",
|
||||||
"contact": "info@coqui.ai",
|
"contact": "info@coqui.ai",
|
||||||
|
@ -297,9 +322,9 @@ class ModelManager(object):
|
||||||
return model_item, model_full_name, model, md5hash
|
return model_item, model_full_name, model, md5hash
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def ask_tos(model_full_path):
|
def ask_tos(model_full_path: Path) -> bool:
|
||||||
"""Ask the user to agree to the terms of service"""
|
"""Ask the user to agree to the terms of service"""
|
||||||
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
|
tos_path = model_full_path / "tos_agreed.txt"
|
||||||
print(" > You must confirm the following:")
|
print(" > You must confirm the following:")
|
||||||
print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"')
|
print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"')
|
||||||
print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]')
|
print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]')
|
||||||
|
@ -311,7 +336,7 @@ class ModelManager(object):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tos_agreed(model_item, model_full_path):
|
def tos_agreed(model_item: ModelItem, model_full_path: Path) -> bool:
|
||||||
"""Check if the user has agreed to the terms of service"""
|
"""Check if the user has agreed to the terms of service"""
|
||||||
if "tos_required" in model_item and model_item["tos_required"]:
|
if "tos_required" in model_item and model_item["tos_required"]:
|
||||||
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
|
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
|
||||||
|
@ -320,12 +345,12 @@ class ModelManager(object):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def create_dir_and_download_model(self, model_name, model_item, output_path):
|
def create_dir_and_download_model(self, model_name: str, model_item: ModelItem, output_path: Path) -> None:
|
||||||
os.makedirs(output_path, exist_ok=True)
|
output_path.mkdir(exist_ok=True, parents=True)
|
||||||
# handle TOS
|
# handle TOS
|
||||||
if not self.tos_agreed(model_item, output_path):
|
if not self.tos_agreed(model_item, output_path):
|
||||||
if not self.ask_tos(output_path):
|
if not self.ask_tos(output_path):
|
||||||
os.rmdir(output_path)
|
output_path.rmdir()
|
||||||
raise Exception(" [!] You must agree to the terms of service to use this model.")
|
raise Exception(" [!] You must agree to the terms of service to use this model.")
|
||||||
logger.info("Downloading model to %s", output_path)
|
logger.info("Downloading model to %s", output_path)
|
||||||
try:
|
try:
|
||||||
|
@ -342,7 +367,7 @@ class ModelManager(object):
|
||||||
raise e
|
raise e
|
||||||
self.print_model_license(model_item=model_item)
|
self.print_model_license(model_item=model_item)
|
||||||
|
|
||||||
def check_if_configs_are_equal(self, model_name, model_item, output_path):
|
def check_if_configs_are_equal(self, model_name: str, model_item: ModelItem, output_path: Path) -> None:
|
||||||
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
|
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
|
||||||
config_local = json.load(f)
|
config_local = json.load(f)
|
||||||
remote_url = None
|
remote_url = None
|
||||||
|
@ -358,7 +383,7 @@ class ModelManager(object):
|
||||||
logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name)
|
logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name)
|
||||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||||
|
|
||||||
def download_model(self, model_name):
|
def download_model(self, model_name: str) -> tuple[Path, Optional[Path], ModelItem]:
|
||||||
"""Download model files given the full model name.
|
"""Download model files given the full model name.
|
||||||
Model name is in the format
|
Model name is in the format
|
||||||
'type/language/dataset/model'
|
'type/language/dataset/model'
|
||||||
|
@ -374,12 +399,12 @@ class ModelManager(object):
|
||||||
"""
|
"""
|
||||||
model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
|
model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
|
||||||
# set the model specific output path
|
# set the model specific output path
|
||||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
output_path = Path(self.output_prefix) / model_full_name
|
||||||
if os.path.exists(output_path):
|
if output_path.is_dir():
|
||||||
if md5sum is not None:
|
if md5sum is not None:
|
||||||
md5sum_file = os.path.join(output_path, "hash.md5")
|
md5sum_file = output_path / "hash.md5"
|
||||||
if os.path.isfile(md5sum_file):
|
if md5sum_file.is_file():
|
||||||
with open(md5sum_file, mode="r") as f:
|
with md5sum_file.open() as f:
|
||||||
if not f.read() == md5sum:
|
if not f.read() == md5sum:
|
||||||
logger.info("%s has been updated, clearing model cache...", model_name)
|
logger.info("%s has been updated, clearing model cache...", model_name)
|
||||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||||
|
@ -407,12 +432,14 @@ class ModelManager(object):
|
||||||
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
|
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
|
||||||
): # TODO:This is stupid but don't care for now.
|
): # TODO:This is stupid but don't care for now.
|
||||||
output_model_path, output_config_path = self._find_files(output_path)
|
output_model_path, output_config_path = self._find_files(output_path)
|
||||||
|
else:
|
||||||
|
output_config_path = output_model_path / "config.json"
|
||||||
# update paths in the config.json
|
# update paths in the config.json
|
||||||
self._update_paths(output_path, output_config_path)
|
self._update_paths(output_path, output_config_path)
|
||||||
return output_model_path, output_config_path, model_item
|
return output_model_path, output_config_path, model_item
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _find_files(output_path: str) -> Tuple[str, str]:
|
def _find_files(output_path: Path) -> tuple[Path, Path]:
|
||||||
"""Find the model and config files in the output path
|
"""Find the model and config files in the output path
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -423,11 +450,11 @@ class ModelManager(object):
|
||||||
"""
|
"""
|
||||||
model_file = None
|
model_file = None
|
||||||
config_file = None
|
config_file = None
|
||||||
for file_name in os.listdir(output_path):
|
for f in output_path.iterdir():
|
||||||
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]:
|
if f.name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]:
|
||||||
model_file = os.path.join(output_path, file_name)
|
model_file = f
|
||||||
elif file_name == "config.json":
|
elif f.name == "config.json":
|
||||||
config_file = os.path.join(output_path, file_name)
|
config_file = f
|
||||||
if model_file is None:
|
if model_file is None:
|
||||||
raise ValueError(" [!] Model file not found in the output path")
|
raise ValueError(" [!] Model file not found in the output path")
|
||||||
if config_file is None:
|
if config_file is None:
|
||||||
|
@ -435,7 +462,7 @@ class ModelManager(object):
|
||||||
return model_file, config_file
|
return model_file, config_file
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _find_speaker_encoder(output_path: str) -> str:
|
def _find_speaker_encoder(output_path: Path) -> Optional[Path]:
|
||||||
"""Find the speaker encoder file in the output path
|
"""Find the speaker encoder file in the output path
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -445,24 +472,24 @@ class ModelManager(object):
|
||||||
str: path to the speaker encoder file
|
str: path to the speaker encoder file
|
||||||
"""
|
"""
|
||||||
speaker_encoder_file = None
|
speaker_encoder_file = None
|
||||||
for file_name in os.listdir(output_path):
|
for f in output_path.iterdir():
|
||||||
if file_name in ["model_se.pth", "model_se.pth.tar"]:
|
if f.name in ["model_se.pth", "model_se.pth.tar"]:
|
||||||
speaker_encoder_file = os.path.join(output_path, file_name)
|
speaker_encoder_file = f
|
||||||
return speaker_encoder_file
|
return speaker_encoder_file
|
||||||
|
|
||||||
def _update_paths(self, output_path: str, config_path: str) -> None:
|
def _update_paths(self, output_path: Path, config_path: Path) -> None:
|
||||||
"""Update paths for certain files in config.json after download.
|
"""Update paths for certain files in config.json after download.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output_path (str): local path the model is downloaded to.
|
output_path (str): local path the model is downloaded to.
|
||||||
config_path (str): local config.json path.
|
config_path (str): local config.json path.
|
||||||
"""
|
"""
|
||||||
output_stats_path = os.path.join(output_path, "scale_stats.npy")
|
output_stats_path = output_path / "scale_stats.npy"
|
||||||
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
|
output_d_vector_file_path = output_path / "speakers.json"
|
||||||
output_d_vector_file_pth_path = os.path.join(output_path, "speakers.pth")
|
output_d_vector_file_pth_path = output_path / "speakers.pth"
|
||||||
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
|
output_speaker_ids_file_path = output_path / "speaker_ids.json"
|
||||||
output_speaker_ids_file_pth_path = os.path.join(output_path, "speaker_ids.pth")
|
output_speaker_ids_file_pth_path = output_path / "speaker_ids.pth"
|
||||||
speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
|
speaker_encoder_config_path = output_path / "config_se.json"
|
||||||
speaker_encoder_model_path = self._find_speaker_encoder(output_path)
|
speaker_encoder_model_path = self._find_speaker_encoder(output_path)
|
||||||
|
|
||||||
# update the scale_path.npy file path in the model config.json
|
# update the scale_path.npy file path in the model config.json
|
||||||
|
@ -487,10 +514,10 @@ class ModelManager(object):
|
||||||
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
|
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_path(field_name, new_path, config_path):
|
def _update_path(field_name: str, new_path: Optional[Path], config_path: Path) -> None:
|
||||||
"""Update the path in the model config.json for the current environment after download"""
|
"""Update the path in the model config.json for the current environment after download"""
|
||||||
if new_path and os.path.exists(new_path):
|
if new_path is not None and new_path.is_file():
|
||||||
config = load_config(config_path)
|
config = load_config(str(config_path))
|
||||||
field_names = field_name.split(".")
|
field_names = field_name.split(".")
|
||||||
if len(field_names) > 1:
|
if len(field_names) > 1:
|
||||||
# field name points to a sub-level field
|
# field name points to a sub-level field
|
||||||
|
@ -515,7 +542,7 @@ class ModelManager(object):
|
||||||
config.save_json(config_path)
|
config.save_json(config_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _download_zip_file(file_url, output_folder, progress_bar):
|
def _download_zip_file(file_url: str, output_folder: Path, progress_bar: bool) -> None:
|
||||||
"""Download the github releases"""
|
"""Download the github releases"""
|
||||||
# download the file
|
# download the file
|
||||||
r = requests.get(file_url, stream=True)
|
r = requests.get(file_url, stream=True)
|
||||||
|
@ -525,7 +552,7 @@ class ModelManager(object):
|
||||||
block_size = 1024 # 1 Kibibyte
|
block_size = 1024 # 1 Kibibyte
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||||
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
|
temp_zip_name = output_folder / file_url.split("/")[-1]
|
||||||
with open(temp_zip_name, "wb") as file:
|
with open(temp_zip_name, "wb") as file:
|
||||||
for data in r.iter_content(block_size):
|
for data in r.iter_content(block_size):
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
|
@ -533,24 +560,24 @@ class ModelManager(object):
|
||||||
file.write(data)
|
file.write(data)
|
||||||
with zipfile.ZipFile(temp_zip_name) as z:
|
with zipfile.ZipFile(temp_zip_name) as z:
|
||||||
z.extractall(output_folder)
|
z.extractall(output_folder)
|
||||||
os.remove(temp_zip_name) # delete zip after extract
|
temp_zip_name.unlink() # delete zip after extract
|
||||||
except zipfile.BadZipFile:
|
except zipfile.BadZipFile:
|
||||||
logger.exception("Bad zip file - %s", file_url)
|
logger.exception("Bad zip file - %s", file_url)
|
||||||
raise zipfile.BadZipFile # pylint: disable=raise-missing-from
|
raise zipfile.BadZipFile # pylint: disable=raise-missing-from
|
||||||
# move the files to the outer path
|
# move the files to the outer path
|
||||||
for file_path in z.namelist():
|
for file_path in z.namelist():
|
||||||
src_path = os.path.join(output_folder, file_path)
|
src_path = output_folder / file_path
|
||||||
if os.path.isfile(src_path):
|
if src_path.is_file():
|
||||||
dst_path = os.path.join(output_folder, os.path.basename(file_path))
|
dst_path = output_folder / os.path.basename(file_path)
|
||||||
if src_path != dst_path:
|
if src_path != dst_path:
|
||||||
copyfile(src_path, dst_path)
|
copyfile(src_path, dst_path)
|
||||||
# remove redundant (hidden or not) folders
|
# remove redundant (hidden or not) folders
|
||||||
for file_path in z.namelist():
|
for file_path in z.namelist():
|
||||||
if os.path.isdir(os.path.join(output_folder, file_path)):
|
if (output_folder / file_path).is_dir():
|
||||||
rmtree(os.path.join(output_folder, file_path))
|
rmtree(output_folder / file_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _download_tar_file(file_url, output_folder, progress_bar):
|
def _download_tar_file(file_url: str, output_folder: Path, progress_bar: bool) -> None:
|
||||||
"""Download the github releases"""
|
"""Download the github releases"""
|
||||||
# download the file
|
# download the file
|
||||||
r = requests.get(file_url, stream=True)
|
r = requests.get(file_url, stream=True)
|
||||||
|
@ -560,7 +587,7 @@ class ModelManager(object):
|
||||||
block_size = 1024 # 1 Kibibyte
|
block_size = 1024 # 1 Kibibyte
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||||
temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
|
temp_tar_name = output_folder / file_url.split("/")[-1]
|
||||||
with open(temp_tar_name, "wb") as file:
|
with open(temp_tar_name, "wb") as file:
|
||||||
for data in r.iter_content(block_size):
|
for data in r.iter_content(block_size):
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
|
@ -569,43 +596,37 @@ class ModelManager(object):
|
||||||
with tarfile.open(temp_tar_name) as t:
|
with tarfile.open(temp_tar_name) as t:
|
||||||
t.extractall(output_folder)
|
t.extractall(output_folder)
|
||||||
tar_names = t.getnames()
|
tar_names = t.getnames()
|
||||||
os.remove(temp_tar_name) # delete tar after extract
|
temp_tar_name.unlink() # delete tar after extract
|
||||||
except tarfile.ReadError:
|
except tarfile.ReadError:
|
||||||
logger.exception("Bad tar file - %s", file_url)
|
logger.exception("Bad tar file - %s", file_url)
|
||||||
raise tarfile.ReadError # pylint: disable=raise-missing-from
|
raise tarfile.ReadError # pylint: disable=raise-missing-from
|
||||||
# move the files to the outer path
|
# move the files to the outer path
|
||||||
for file_path in os.listdir(os.path.join(output_folder, tar_names[0])):
|
for file_path in (output_folder / tar_names[0]).iterdir():
|
||||||
src_path = os.path.join(output_folder, tar_names[0], file_path)
|
src_path = file_path
|
||||||
dst_path = os.path.join(output_folder, os.path.basename(file_path))
|
dst_path = output_folder / file_path.name
|
||||||
if src_path != dst_path:
|
if src_path != dst_path:
|
||||||
copyfile(src_path, dst_path)
|
copyfile(src_path, dst_path)
|
||||||
# remove the extracted folder
|
# remove the extracted folder
|
||||||
rmtree(os.path.join(output_folder, tar_names[0]))
|
rmtree(output_folder / tar_names[0])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _download_model_files(file_urls, output_folder, progress_bar):
|
def _download_model_files(
|
||||||
|
file_urls: list[str], output_folder: Union[str, os.PathLike[Any]], progress_bar: bool
|
||||||
|
) -> None:
|
||||||
"""Download the github releases"""
|
"""Download the github releases"""
|
||||||
|
output_folder = Path(output_folder)
|
||||||
for file_url in file_urls:
|
for file_url in file_urls:
|
||||||
# download the file
|
# download the file
|
||||||
r = requests.get(file_url, stream=True)
|
r = requests.get(file_url, stream=True)
|
||||||
# extract the file
|
# extract the file
|
||||||
bease_filename = file_url.split("/")[-1]
|
base_filename = file_url.split("/")[-1]
|
||||||
temp_zip_name = os.path.join(output_folder, bease_filename)
|
file_path = output_folder / base_filename
|
||||||
total_size_in_bytes = int(r.headers.get("content-length", 0))
|
total_size_in_bytes = int(r.headers.get("content-length", 0))
|
||||||
block_size = 1024 # 1 Kibibyte
|
block_size = 1024 # 1 Kibibyte
|
||||||
with open(temp_zip_name, "wb") as file:
|
with open(file_path, "wb") as f:
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||||
for data in r.iter_content(block_size):
|
for data in r.iter_content(block_size):
|
||||||
if progress_bar:
|
if progress_bar:
|
||||||
ModelManager.tqdm_progress.update(len(data))
|
ModelManager.tqdm_progress.update(len(data))
|
||||||
file.write(data)
|
f.write(data)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _check_dict_key(my_dict, key):
|
|
||||||
if key in my_dict.keys() and my_dict[key] is not None:
|
|
||||||
if not isinstance(key, str):
|
|
||||||
return True
|
|
||||||
if isinstance(key, str) and len(my_dict[key]) > 0:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pysbd
|
import pysbd
|
||||||
|
@ -16,6 +16,7 @@ from TTS.tts.models.vits import Vits
|
||||||
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
|
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.audio.numpy_transforms import save_wav
|
from TTS.utils.audio.numpy_transforms import save_wav
|
||||||
|
from TTS.utils.generic_utils import optional_to_str
|
||||||
from TTS.vc.configs.openvoice_config import OpenVoiceConfig
|
from TTS.vc.configs.openvoice_config import OpenVoiceConfig
|
||||||
from TTS.vc.models import setup_model as setup_vc_model
|
from TTS.vc.models import setup_model as setup_vc_model
|
||||||
from TTS.vc.models.openvoice import OpenVoice
|
from TTS.vc.models.openvoice import OpenVoice
|
||||||
|
@ -29,18 +30,18 @@ class Synthesizer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
tts_checkpoint: str = "",
|
tts_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
tts_config_path: str = "",
|
tts_config_path: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
tts_speakers_file: str = "",
|
tts_speakers_file: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
tts_languages_file: str = "",
|
tts_languages_file: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
vocoder_checkpoint: str = "",
|
vocoder_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
vocoder_config: str = "",
|
vocoder_config: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
encoder_checkpoint: str = "",
|
encoder_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
encoder_config: str = "",
|
encoder_config: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
vc_checkpoint: str = "",
|
vc_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
vc_config: str = "",
|
vc_config: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
model_dir: str = "",
|
model_dir: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
voice_dir: str = None,
|
voice_dir: Optional[Union[str, os.PathLike[Any]]] = None,
|
||||||
use_cuda: bool = False,
|
use_cuda: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""General 🐸 TTS interface for inference. It takes a tts and a vocoder
|
"""General 🐸 TTS interface for inference. It takes a tts and a vocoder
|
||||||
|
@ -66,16 +67,17 @@ class Synthesizer(nn.Module):
|
||||||
use_cuda (bool, optional): enable/disable cuda. Defaults to False.
|
use_cuda (bool, optional): enable/disable cuda. Defaults to False.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tts_checkpoint = tts_checkpoint
|
self.tts_checkpoint = optional_to_str(tts_checkpoint)
|
||||||
self.tts_config_path = tts_config_path
|
self.tts_config_path = optional_to_str(tts_config_path)
|
||||||
self.tts_speakers_file = tts_speakers_file
|
self.tts_speakers_file = optional_to_str(tts_speakers_file)
|
||||||
self.tts_languages_file = tts_languages_file
|
self.tts_languages_file = optional_to_str(tts_languages_file)
|
||||||
self.vocoder_checkpoint = vocoder_checkpoint
|
self.vocoder_checkpoint = optional_to_str(vocoder_checkpoint)
|
||||||
self.vocoder_config = vocoder_config
|
self.vocoder_config = optional_to_str(vocoder_config)
|
||||||
self.encoder_checkpoint = encoder_checkpoint
|
self.encoder_checkpoint = optional_to_str(encoder_checkpoint)
|
||||||
self.encoder_config = encoder_config
|
self.encoder_config = optional_to_str(encoder_config)
|
||||||
self.vc_checkpoint = vc_checkpoint
|
self.vc_checkpoint = optional_to_str(vc_checkpoint)
|
||||||
self.vc_config = vc_config
|
self.vc_config = optional_to_str(vc_config)
|
||||||
|
model_dir = optional_to_str(model_dir)
|
||||||
self.use_cuda = use_cuda
|
self.use_cuda = use_cuda
|
||||||
|
|
||||||
self.tts_model = None
|
self.tts_model = None
|
||||||
|
@ -89,18 +91,18 @@ class Synthesizer(nn.Module):
|
||||||
self.d_vector_dim = 0
|
self.d_vector_dim = 0
|
||||||
self.seg = self._get_segmenter("en")
|
self.seg = self._get_segmenter("en")
|
||||||
self.use_cuda = use_cuda
|
self.use_cuda = use_cuda
|
||||||
self.voice_dir = voice_dir
|
self.voice_dir = optional_to_str(voice_dir)
|
||||||
if self.use_cuda:
|
if self.use_cuda:
|
||||||
assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
|
assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
|
||||||
|
|
||||||
if tts_checkpoint:
|
if tts_checkpoint:
|
||||||
self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
|
self._load_tts(self.tts_checkpoint, self.tts_config_path, use_cuda)
|
||||||
|
|
||||||
if vocoder_checkpoint:
|
if vocoder_checkpoint:
|
||||||
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
|
self._load_vocoder(self.vocoder_checkpoint, self.vocoder_config, use_cuda)
|
||||||
|
|
||||||
if vc_checkpoint and model_dir is None:
|
if vc_checkpoint and model_dir == "":
|
||||||
self._load_vc(vc_checkpoint, vc_config, use_cuda)
|
self._load_vc(self.vc_checkpoint, self.vc_config, use_cuda)
|
||||||
|
|
||||||
if model_dir:
|
if model_dir:
|
||||||
if "fairseq" in model_dir:
|
if "fairseq" in model_dir:
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
#!/usr/bin/env python3`
|
#!/usr/bin/env python3`
|
||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
@ -30,22 +29,22 @@ def run_models(offset=0, step=1):
|
||||||
print(f"\n > Run - {model_name}")
|
print(f"\n > Run - {model_name}")
|
||||||
model_path, _, _ = manager.download_model(model_name)
|
model_path, _, _ = manager.download_model(model_name)
|
||||||
if "tts_models" in model_name:
|
if "tts_models" in model_name:
|
||||||
local_download_dir = os.path.dirname(model_path)
|
local_download_dir = model_path.parent
|
||||||
# download and run the model
|
# download and run the model
|
||||||
speaker_files = glob.glob(local_download_dir + "/speaker*")
|
speaker_files = list(local_download_dir.glob("speaker*"))
|
||||||
language_files = glob.glob(local_download_dir + "/language*")
|
language_files = list(local_download_dir.glob("language*"))
|
||||||
speaker_arg = ""
|
speaker_arg = ""
|
||||||
language_arg = ""
|
language_arg = ""
|
||||||
if len(speaker_files) > 0:
|
if len(speaker_files) > 0:
|
||||||
# multi-speaker model
|
# multi-speaker model
|
||||||
if "speaker_ids" in speaker_files[0]:
|
if "speaker_ids" in speaker_files[0].stem:
|
||||||
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
|
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
|
||||||
elif "speakers" in speaker_files[0]:
|
elif "speakers" in speaker_files[0].stem:
|
||||||
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
|
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
|
||||||
speakers = list(speaker_manager.name_to_id.keys())
|
speakers = list(speaker_manager.name_to_id.keys())
|
||||||
if len(speakers) > 1:
|
if len(speakers) > 1:
|
||||||
speaker_arg = f'--speaker_idx "{speakers[0]}"'
|
speaker_arg = f'--speaker_idx "{speakers[0]}"'
|
||||||
if len(language_files) > 0 and "language_ids" in language_files[0]:
|
if len(language_files) > 0 and "language_ids" in language_files[0].stem:
|
||||||
# multi-lingual model
|
# multi-lingual model
|
||||||
language_manager = LanguageManager(language_ids_file_path=language_files[0])
|
language_manager = LanguageManager(language_ids_file_path=language_files[0])
|
||||||
languages = language_manager.language_names
|
languages = language_manager.language_names
|
||||||
|
|
Loading…
Reference in New Issue