feat: allow both Path and strings where possible and add type hints

This commit is contained in:
Enno Hermann 2024-12-14 00:28:01 +01:00
parent cd52907351
commit a425ba599d
11 changed files with 204 additions and 163 deletions

View File

@ -157,7 +157,7 @@ class TTS(nn.Module):
def download_model_by_name( def download_model_by_name(
self, model_name: str, vocoder_name: Optional[str] = None self, model_name: str, vocoder_name: Optional[str] = None
) -> tuple[Optional[str], Optional[str], Optional[str]]: ) -> tuple[Optional[Path], Optional[Path], Optional[Path]]:
model_path, config_path, model_item = self.manager.download_model(model_name) model_path, config_path, model_item = self.manager.download_model(model_name)
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)): if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
# return model directory if there are multiple files # return model directory if there are multiple files

View File

@ -1,7 +1,7 @@
import json import json
import os import os
import re import re
from typing import Dict from typing import Any, Dict, Union
import fsspec import fsspec
import yaml import yaml
@ -68,7 +68,7 @@ def _process_model_name(config_dict: Dict) -> str:
return model_name return model_name
def load_config(config_path: str) -> Coqpit: def load_config(config_path: Union[str, os.PathLike[Any]]) -> Coqpit:
"""Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name """Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
to find the corresponding Config class. Then initialize the Config. to find the corresponding Config class. Then initialize the Config.
@ -81,6 +81,7 @@ def load_config(config_path: str) -> Coqpit:
Returns: Returns:
Coqpit: TTS config object. Coqpit: TTS config object.
""" """
config_path = str(config_path)
config_dict = {} config_dict = {}
ext = os.path.splitext(config_path)[1] ext = os.path.splitext(config_path)[1]
if ext in (".yml", ".yaml"): if ext in (".yml", ".yaml"):

View File

@ -1,5 +1,5 @@
import os import os
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Union
import fsspec import fsspec
import numpy as np import numpy as np
@ -27,8 +27,8 @@ class LanguageManager(BaseIDManager):
def __init__( def __init__(
self, self,
language_ids_file_path: str = "", language_ids_file_path: Union[str, os.PathLike[Any]] = "",
config: Coqpit = None, config: Optional[Coqpit] = None,
): ):
super().__init__(id_file_path=language_ids_file_path) super().__init__(id_file_path=language_ids_file_path)
@ -76,7 +76,7 @@ class LanguageManager(BaseIDManager):
def set_ids_from_data(self, items: List, parse_key: str) -> Any: def set_ids_from_data(self, items: List, parse_key: str) -> Any:
raise NotImplementedError raise NotImplementedError
def save_ids_to_file(self, file_path: str) -> None: def save_ids_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""Save language IDs to a json file. """Save language IDs to a json file.
Args: Args:

View File

@ -1,4 +1,5 @@
import json import json
import os
import random import random
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple, Union
@ -12,7 +13,8 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import is_pytorch_at_least_2_4 from TTS.utils.generic_utils import is_pytorch_at_least_2_4
def load_file(path: str): def load_file(path: Union[str, os.PathLike[Any]]):
path = str(path)
if path.endswith(".json"): if path.endswith(".json"):
with fsspec.open(path, "r") as f: with fsspec.open(path, "r") as f:
return json.load(f) return json.load(f)
@ -23,7 +25,8 @@ def load_file(path: str):
raise ValueError("Unsupported file type") raise ValueError("Unsupported file type")
def save_file(obj: Any, path: str): def save_file(obj: Any, path: Union[str, os.PathLike[Any]]):
path = str(path)
if path.endswith(".json"): if path.endswith(".json"):
with fsspec.open(path, "w") as f: with fsspec.open(path, "w") as f:
json.dump(obj, f, indent=4) json.dump(obj, f, indent=4)
@ -39,20 +42,20 @@ class BaseIDManager:
It defines common `ID` manager specific functions. It defines common `ID` manager specific functions.
""" """
def __init__(self, id_file_path: str = ""): def __init__(self, id_file_path: Union[str, os.PathLike[Any]] = ""):
self.name_to_id = {} self.name_to_id = {}
if id_file_path: if id_file_path:
self.load_ids_from_file(id_file_path) self.load_ids_from_file(id_file_path)
@staticmethod @staticmethod
def _load_json(json_file_path: str) -> Dict: def _load_json(json_file_path: Union[str, os.PathLike[Any]]) -> Dict:
with fsspec.open(json_file_path, "r") as f: with fsspec.open(str(json_file_path), "r") as f:
return json.load(f) return json.load(f)
@staticmethod @staticmethod
def _save_json(json_file_path: str, data: dict) -> None: def _save_json(json_file_path: Union[str, os.PathLike[Any]], data: dict) -> None:
with fsspec.open(json_file_path, "w") as f: with fsspec.open(str(json_file_path), "w") as f:
json.dump(data, f, indent=4) json.dump(data, f, indent=4)
def set_ids_from_data(self, items: List, parse_key: str) -> None: def set_ids_from_data(self, items: List, parse_key: str) -> None:
@ -63,7 +66,7 @@ class BaseIDManager:
""" """
self.name_to_id = self.parse_ids_from_data(items, parse_key=parse_key) self.name_to_id = self.parse_ids_from_data(items, parse_key=parse_key)
def load_ids_from_file(self, file_path: str) -> None: def load_ids_from_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""Set IDs from a file. """Set IDs from a file.
Args: Args:
@ -71,7 +74,7 @@ class BaseIDManager:
""" """
self.name_to_id = load_file(file_path) self.name_to_id = load_file(file_path)
def save_ids_to_file(self, file_path: str) -> None: def save_ids_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""Save IDs to a json file. """Save IDs to a json file.
Args: Args:
@ -130,10 +133,10 @@ class EmbeddingManager(BaseIDManager):
def __init__( def __init__(
self, self,
embedding_file_path: Union[str, List[str]] = "", embedding_file_path: Union[Union[str, os.PathLike[Any]], list[Union[str, os.PathLike[Any]]]] = "",
id_file_path: str = "", id_file_path: Union[str, os.PathLike[Any]] = "",
encoder_model_path: str = "", encoder_model_path: Union[str, os.PathLike[Any]] = "",
encoder_config_path: str = "", encoder_config_path: Union[str, os.PathLike[Any]] = "",
use_cuda: bool = False, use_cuda: bool = False,
): ):
super().__init__(id_file_path=id_file_path) super().__init__(id_file_path=id_file_path)
@ -176,7 +179,7 @@ class EmbeddingManager(BaseIDManager):
"""Get embedding names.""" """Get embedding names."""
return list(self.embeddings_by_names.keys()) return list(self.embeddings_by_names.keys())
def save_embeddings_to_file(self, file_path: str) -> None: def save_embeddings_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""Save embeddings to a json file. """Save embeddings to a json file.
Args: Args:
@ -185,7 +188,7 @@ class EmbeddingManager(BaseIDManager):
save_file(self.embeddings, file_path) save_file(self.embeddings, file_path)
@staticmethod @staticmethod
def read_embeddings_from_file(file_path: str): def read_embeddings_from_file(file_path: Union[str, os.PathLike[Any]]):
"""Load embeddings from a json file. """Load embeddings from a json file.
Args: Args:
@ -204,7 +207,7 @@ class EmbeddingManager(BaseIDManager):
embeddings_by_names[x["name"]].append(x["embedding"]) embeddings_by_names[x["name"]].append(x["embedding"])
return name_to_id, clip_ids, embeddings, embeddings_by_names return name_to_id, clip_ids, embeddings, embeddings_by_names
def load_embeddings_from_file(self, file_path: str) -> None: def load_embeddings_from_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""Load embeddings from a json file. """Load embeddings from a json file.
Args: Args:
@ -214,7 +217,7 @@ class EmbeddingManager(BaseIDManager):
file_path file_path
) )
def load_embeddings_from_list_of_files(self, file_paths: List[str]) -> None: def load_embeddings_from_list_of_files(self, file_paths: list[Union[str, os.PathLike[Any]]]) -> None:
"""Load embeddings from a list of json files and don't allow duplicate keys. """Load embeddings from a list of json files and don't allow duplicate keys.
Args: Args:
@ -313,7 +316,9 @@ class EmbeddingManager(BaseIDManager):
def get_clips(self) -> List: def get_clips(self) -> List:
return sorted(self.embeddings.keys()) return sorted(self.embeddings.keys())
def init_encoder(self, model_path: str, config_path: str, use_cuda=False) -> None: def init_encoder(
self, model_path: Union[str, os.PathLike[Any]], config_path: Union[str, os.PathLike[Any]], use_cuda=False
) -> None:
"""Initialize a speaker encoder model. """Initialize a speaker encoder model.
Args: Args:
@ -325,11 +330,13 @@ class EmbeddingManager(BaseIDManager):
self.encoder_config = load_config(config_path) self.encoder_config = load_config(config_path)
self.encoder = setup_encoder_model(self.encoder_config) self.encoder = setup_encoder_model(self.encoder_config)
self.encoder_criterion = self.encoder.load_checkpoint( self.encoder_criterion = self.encoder.load_checkpoint(
self.encoder_config, model_path, eval=True, use_cuda=use_cuda, cache=True self.encoder_config, str(model_path), eval=True, use_cuda=use_cuda, cache=True
) )
self.encoder_ap = AudioProcessor(**self.encoder_config.audio) self.encoder_ap = AudioProcessor(**self.encoder_config.audio)
def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list: def compute_embedding_from_clip(
self, wav_file: Union[Union[str, os.PathLike[Any]], List[Union[str, os.PathLike[Any]]]]
) -> list:
"""Compute a embedding from a given audio file. """Compute a embedding from a given audio file.
Args: Args:

View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
import os import os
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Optional, Union
import fsspec import fsspec
import numpy as np import numpy as np
@ -56,11 +56,11 @@ class SpeakerManager(EmbeddingManager):
def __init__( def __init__(
self, self,
data_items: List[List[Any]] = None, data_items: Optional[list[list[Any]]] = None,
d_vectors_file_path: str = "", d_vectors_file_path: str = "",
speaker_id_file_path: str = "", speaker_id_file_path: Union[str, os.PathLike[Any]] = "",
encoder_model_path: str = "", encoder_model_path: Union[str, os.PathLike[Any]] = "",
encoder_config_path: str = "", encoder_config_path: Union[str, os.PathLike[Any]] = "",
use_cuda: bool = False, use_cuda: bool = False,
): ):
super().__init__( super().__init__(

View File

@ -1,6 +1,7 @@
import logging import logging
import os
from io import BytesIO from io import BytesIO
from typing import Optional from typing import Any, Optional, Union
import librosa import librosa
import numpy as np import numpy as np
@ -406,7 +407,9 @@ def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.n
return rms_norm(wav=x, db_level=db_level) return rms_norm(wav=x, db_level=db_level)
def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool = False, **kwargs) -> np.ndarray: def load_wav(
*, filename: Union[str, os.PathLike[Any]], sample_rate: Optional[int] = None, resample: bool = False, **kwargs
) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize. """Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before. Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
@ -434,7 +437,7 @@ def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool
def save_wav( def save_wav(
*, *,
wav: np.ndarray, wav: np.ndarray,
path: str, path: Union[str, os.PathLike[Any]],
sample_rate: int, sample_rate: int,
pipe_out=None, pipe_out=None,
do_rms_norm: bool = False, do_rms_norm: bool = False,

View File

@ -1,5 +1,6 @@
import logging import logging
from typing import Optional import os
from typing import Any, Optional, Union
import librosa import librosa
import numpy as np import numpy as np
@ -548,7 +549,7 @@ class AudioProcessor:
return volume_norm(x=x) return volume_norm(x=x)
### save and load ### ### save and load ###
def load_wav(self, filename: str, sr: Optional[int] = None) -> np.ndarray: def load_wav(self, filename: Union[str, os.PathLike[Any]], sr: Optional[int] = None) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize. """Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before. Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
@ -575,7 +576,9 @@ class AudioProcessor:
x = rms_volume_norm(x=x, db_level=self.db_level) x = rms_volume_norm(x=x, db_level=self.db_level)
return x return x
def save_wav(self, wav: np.ndarray, path: str, sr: Optional[int] = None, pipe_out=None) -> None: def save_wav(
self, wav: np.ndarray, path: Union[str, os.PathLike[Any]], sr: Optional[int] = None, pipe_out=None
) -> None:
"""Save a waveform to a file using Scipy. """Save a waveform to a file using Scipy.
Args: Args:

View File

@ -4,7 +4,7 @@ import importlib
import logging import logging
import re import re
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, Optional, TypeVar, Union from typing import Any, Callable, Dict, Optional, TypeVar, Union
import torch import torch
from packaging.version import Version from packaging.version import Version
@ -133,3 +133,8 @@ def setup_logger(
def is_pytorch_at_least_2_4() -> bool: def is_pytorch_at_least_2_4() -> bool:
"""Check if the installed Pytorch version is 2.4 or higher.""" """Check if the installed Pytorch version is 2.4 or higher."""
return Version(torch.__version__) >= Version("2.4") return Version(torch.__version__) >= Version("2.4")
def optional_to_str(x: Optional[Any]) -> str:
    """Convert input to string, using empty string if input is None.

    Args:
        x: Any value, or None.

    Returns:
        ``""`` when ``x`` is None, otherwise ``str(x)``.
    """
    # Compare against None explicitly so that falsy values such as 0 or
    # False are still stringified instead of being collapsed to "".
    return "" if x is None else str(x)

View File

@ -6,17 +6,35 @@ import tarfile
import zipfile import zipfile
from pathlib import Path from pathlib import Path
from shutil import copyfile, rmtree from shutil import copyfile, rmtree
from typing import Dict, Tuple from typing import Any, Optional, TypedDict, Union
import fsspec import fsspec
import requests import requests
from tqdm import tqdm from tqdm import tqdm
from trainer.io import get_user_data_dir from trainer.io import get_user_data_dir
from typing_extensions import Required
from TTS.config import load_config, read_json_with_comments from TTS.config import load_config, read_json_with_comments
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ModelItem(TypedDict, total=False):
model_name: Required[str]
model_type: Required[str]
description: str
license: str
author: str
contact: str
commit: Optional[str]
model_hash: str
tos_required: bool
default_vocoder: Optional[str]
model_url: Union[str, list[str]]
github_rls_url: Union[str, list[str]]
hf_url: list[str]
LICENSE_URLS = { LICENSE_URLS = {
"cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/", "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
"mpl": "https://www.mozilla.org/en-US/MPL/2.0/", "mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
@ -40,19 +58,24 @@ class ModelManager(object):
home path. home path.
Args: Args:
models_file (str): path to .model.json file. Defaults to None. models_file (str or Path): path to .model.json file. Defaults to None.
output_prefix (str): prefix to `tts` to download models. Defaults to None output_prefix (str or Path): prefix to `tts` to download models. Defaults to None
progress_bar (bool): print a progress bar when downloading a file. Defaults to False. progress_bar (bool): print a progress bar when downloading a file. Defaults to False.
""" """
def __init__(self, models_file=None, output_prefix=None, progress_bar=False): def __init__(
self,
models_file: Optional[Union[str, os.PathLike[Any]]] = None,
output_prefix: Optional[Union[str, os.PathLike[Any]]] = None,
progress_bar: bool = False,
) -> None:
super().__init__() super().__init__()
self.progress_bar = progress_bar self.progress_bar = progress_bar
if output_prefix is None: if output_prefix is None:
self.output_prefix = get_user_data_dir("tts") self.output_prefix = get_user_data_dir("tts")
else: else:
self.output_prefix = os.path.join(output_prefix, "tts") self.output_prefix = Path(output_prefix) / "tts"
self.models_dict = None self.models_dict = {}
if models_file is not None: if models_file is not None:
self.read_models_file(models_file) self.read_models_file(models_file)
else: else:
@ -60,7 +83,7 @@ class ModelManager(object):
path = Path(__file__).parent / "../.models.json" path = Path(__file__).parent / "../.models.json"
self.read_models_file(path) self.read_models_file(path)
def read_models_file(self, file_path): def read_models_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""Read .models.json as a dict """Read .models.json as a dict
Args: Args:
@ -68,7 +91,7 @@ class ModelManager(object):
""" """
self.models_dict = read_json_with_comments(file_path) self.models_dict = read_json_with_comments(file_path)
def _list_models(self, model_type, model_count=0): def _list_models(self, model_type: str, model_count: int = 0) -> list[str]:
logger.info("") logger.info("")
logger.info("Name format: type/language/dataset/model") logger.info("Name format: type/language/dataset/model")
model_list = [] model_list = []
@ -83,13 +106,13 @@ class ModelManager(object):
model_count += 1 model_count += 1
return model_list return model_list
def _list_for_model_type(self, model_type): def _list_for_model_type(self, model_type: str) -> list[str]:
models_name_list = [] models_name_list = []
model_count = 1 model_count = 1
models_name_list.extend(self._list_models(model_type, model_count)) models_name_list.extend(self._list_models(model_type, model_count))
return models_name_list return models_name_list
def list_models(self): def list_models(self) -> list[str]:
models_name_list = [] models_name_list = []
model_count = 1 model_count = 1
for model_type in self.models_dict: for model_type in self.models_dict:
@ -97,7 +120,7 @@ class ModelManager(object):
models_name_list.extend(model_list) models_name_list.extend(model_list)
return models_name_list return models_name_list
def log_model_details(self, model_type, lang, dataset, model): def log_model_details(self, model_type: str, lang: str, dataset: str, model: str) -> None:
logger.info("Model type: %s", model_type) logger.info("Model type: %s", model_type)
logger.info("Language supported: %s", lang) logger.info("Language supported: %s", lang)
logger.info("Dataset used: %s", dataset) logger.info("Dataset used: %s", dataset)
@ -112,7 +135,7 @@ class ModelManager(object):
self.models_dict[model_type][lang][dataset][model]["default_vocoder"], self.models_dict[model_type][lang][dataset][model]["default_vocoder"],
) )
def model_info_by_idx(self, model_query): def model_info_by_idx(self, model_query: str) -> None:
"""Print the description of the model from .models.json file using model_query_idx """Print the description of the model from .models.json file using model_query_idx
Args: Args:
@ -144,7 +167,7 @@ class ModelManager(object):
model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/") model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/")
self.log_model_details(model_type, lang, dataset, model) self.log_model_details(model_type, lang, dataset, model)
def model_info_by_full_name(self, model_query_name): def model_info_by_full_name(self, model_query_name: str) -> None:
"""Print the description of the model from .models.json file using model_full_name """Print the description of the model from .models.json file using model_full_name
Args: Args:
@ -165,35 +188,35 @@ class ModelManager(object):
return return
self.log_model_details(model_type, lang, dataset, model) self.log_model_details(model_type, lang, dataset, model)
def list_tts_models(self): def list_tts_models(self) -> list[str]:
"""Print all `TTS` models and return a list of model names """Print all `TTS` models and return a list of model names
Format is `language/dataset/model` Format is `language/dataset/model`
""" """
return self._list_for_model_type("tts_models") return self._list_for_model_type("tts_models")
def list_vocoder_models(self): def list_vocoder_models(self) -> list[str]:
"""Print all the `vocoder` models and return a list of model names """Print all the `vocoder` models and return a list of model names
Format is `language/dataset/model` Format is `language/dataset/model`
""" """
return self._list_for_model_type("vocoder_models") return self._list_for_model_type("vocoder_models")
def list_vc_models(self): def list_vc_models(self) -> list[str]:
"""Print all the voice conversion models and return a list of model names """Print all the voice conversion models and return a list of model names
Format is `language/dataset/model` Format is `language/dataset/model`
""" """
return self._list_for_model_type("voice_conversion_models") return self._list_for_model_type("voice_conversion_models")
def list_langs(self): def list_langs(self) -> None:
"""Print all the available languages""" """Print all the available languages"""
logger.info("Name format: type/language") logger.info("Name format: type/language")
for model_type in self.models_dict: for model_type in self.models_dict:
for lang in self.models_dict[model_type]: for lang in self.models_dict[model_type]:
logger.info(" %s/%s", model_type, lang) logger.info(" %s/%s", model_type, lang)
def list_datasets(self): def list_datasets(self) -> None:
"""Print all the datasets""" """Print all the datasets"""
logger.info("Name format: type/language/dataset") logger.info("Name format: type/language/dataset")
for model_type in self.models_dict: for model_type in self.models_dict:
@ -202,7 +225,7 @@ class ModelManager(object):
logger.info(" %s/%s/%s", model_type, lang, dataset) logger.info(" %s/%s/%s", model_type, lang, dataset)
@staticmethod @staticmethod
def print_model_license(model_item: Dict): def print_model_license(model_item: ModelItem) -> None:
"""Print the license of a model """Print the license of a model
Args: Args:
@ -217,27 +240,27 @@ class ModelManager(object):
else: else:
logger.info("Model's license - No license information available") logger.info("Model's license - No license information available")
def _download_github_model(self, model_item: Dict, output_path: str): def _download_github_model(self, model_item: ModelItem, output_path: Path) -> None:
if isinstance(model_item["github_rls_url"], list): if isinstance(model_item["github_rls_url"], list):
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
else: else:
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
def _download_hf_model(self, model_item: Dict, output_path: str): def _download_hf_model(self, model_item: ModelItem, output_path: Path) -> None:
if isinstance(model_item["hf_url"], list): if isinstance(model_item["hf_url"], list):
self._download_model_files(model_item["hf_url"], output_path, self.progress_bar) self._download_model_files(model_item["hf_url"], output_path, self.progress_bar)
else: else:
self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar) self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar)
def download_fairseq_model(self, model_name, output_path): def download_fairseq_model(self, model_name: str, output_path: Path) -> None:
URI_PREFIX = "https://dl.fbaipublicfiles.com/mms/tts/" URI_PREFIX = "https://dl.fbaipublicfiles.com/mms/tts/"
_, lang, _, _ = model_name.split("/") _, lang, _, _ = model_name.split("/")
model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz") model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz")
self._download_tar_file(model_download_uri, output_path, self.progress_bar) self._download_tar_file(model_download_uri, output_path, self.progress_bar)
@staticmethod @staticmethod
def set_model_url(model_item: Dict): def set_model_url(model_item: ModelItem) -> ModelItem:
model_item["model_url"] = None model_item["model_url"] = ""
if "github_rls_url" in model_item: if "github_rls_url" in model_item:
model_item["model_url"] = model_item["github_rls_url"] model_item["model_url"] = model_item["github_rls_url"]
elif "hf_url" in model_item: elif "hf_url" in model_item:
@ -248,18 +271,18 @@ class ModelManager(object):
model_item["model_url"] = "https://huggingface.co/coqui/" model_item["model_url"] = "https://huggingface.co/coqui/"
return model_item return model_item
def _set_model_item(self, model_name): def _set_model_item(self, model_name: str) -> tuple[ModelItem, str, str, Optional[str]]:
# fetch model info from the dict # fetch model info from the dict
if "fairseq" in model_name: if "fairseq" in model_name:
model_type, lang, dataset, model = model_name.split("/") model_type, lang, dataset, model = model_name.split("/")
model_item = { model_item: ModelItem = {
"model_name": model_name,
"model_type": "tts_models", "model_type": "tts_models",
"license": "CC BY-NC 4.0", "license": "CC BY-NC 4.0",
"default_vocoder": None, "default_vocoder": None,
"author": "fairseq", "author": "fairseq",
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.", "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
} }
model_item["model_name"] = model_name
elif "xtts" in model_name and len(model_name.split("/")) != 4: elif "xtts" in model_name and len(model_name.split("/")) != 4:
# loading xtts models with only model name (e.g. xtts_v2.0.2) # loading xtts models with only model name (e.g. xtts_v2.0.2)
# check model name has the version number with regex # check model name has the version number with regex
@ -273,6 +296,8 @@ class ModelManager(object):
dataset = "multi-dataset" dataset = "multi-dataset"
model = model_name model = model_name
model_item = { model_item = {
"model_name": model_name,
"model_type": model_type,
"default_vocoder": None, "default_vocoder": None,
"license": "CPML", "license": "CPML",
"contact": "info@coqui.ai", "contact": "info@coqui.ai",
@ -297,9 +322,9 @@ class ModelManager(object):
return model_item, model_full_name, model, md5hash return model_item, model_full_name, model, md5hash
@staticmethod @staticmethod
def ask_tos(model_full_path): def ask_tos(model_full_path: Path) -> bool:
"""Ask the user to agree to the terms of service""" """Ask the user to agree to the terms of service"""
tos_path = os.path.join(model_full_path, "tos_agreed.txt") tos_path = model_full_path / "tos_agreed.txt"
print(" > You must confirm the following:") print(" > You must confirm the following:")
print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"') print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"')
print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]') print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]')
@ -311,7 +336,7 @@ class ModelManager(object):
return False return False
@staticmethod @staticmethod
def tos_agreed(model_item, model_full_path): def tos_agreed(model_item: ModelItem, model_full_path: Path) -> bool:
"""Check if the user has agreed to the terms of service""" """Check if the user has agreed to the terms of service"""
if "tos_required" in model_item and model_item["tos_required"]: if "tos_required" in model_item and model_item["tos_required"]:
tos_path = os.path.join(model_full_path, "tos_agreed.txt") tos_path = os.path.join(model_full_path, "tos_agreed.txt")
@ -320,12 +345,12 @@ class ModelManager(object):
return False return False
return True return True
def create_dir_and_download_model(self, model_name, model_item, output_path): def create_dir_and_download_model(self, model_name: str, model_item: ModelItem, output_path: Path) -> None:
os.makedirs(output_path, exist_ok=True) output_path.mkdir(exist_ok=True, parents=True)
# handle TOS # handle TOS
if not self.tos_agreed(model_item, output_path): if not self.tos_agreed(model_item, output_path):
if not self.ask_tos(output_path): if not self.ask_tos(output_path):
os.rmdir(output_path) output_path.rmdir()
raise Exception(" [!] You must agree to the terms of service to use this model.") raise Exception(" [!] You must agree to the terms of service to use this model.")
logger.info("Downloading model to %s", output_path) logger.info("Downloading model to %s", output_path)
try: try:
@ -342,7 +367,7 @@ class ModelManager(object):
raise e raise e
self.print_model_license(model_item=model_item) self.print_model_license(model_item=model_item)
def check_if_configs_are_equal(self, model_name, model_item, output_path): def check_if_configs_are_equal(self, model_name: str, model_item: ModelItem, output_path: Path) -> None:
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f: with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
config_local = json.load(f) config_local = json.load(f)
remote_url = None remote_url = None
@ -358,7 +383,7 @@ class ModelManager(object):
logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name) logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path) self.create_dir_and_download_model(model_name, model_item, output_path)
def download_model(self, model_name): def download_model(self, model_name: str) -> tuple[Path, Optional[Path], ModelItem]:
"""Download model files given the full model name. """Download model files given the full model name.
Model name is in the format Model name is in the format
'type/language/dataset/model' 'type/language/dataset/model'
@ -374,12 +399,12 @@ class ModelManager(object):
""" """
model_item, model_full_name, model, md5sum = self._set_model_item(model_name) model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
# set the model specific output path # set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name) output_path = Path(self.output_prefix) / model_full_name
if os.path.exists(output_path): if output_path.is_dir():
if md5sum is not None: if md5sum is not None:
md5sum_file = os.path.join(output_path, "hash.md5") md5sum_file = output_path / "hash.md5"
if os.path.isfile(md5sum_file): if md5sum_file.is_file():
with open(md5sum_file, mode="r") as f: with md5sum_file.open() as f:
if not f.read() == md5sum: if not f.read() == md5sum:
logger.info("%s has been updated, clearing model cache...", model_name) logger.info("%s has been updated, clearing model cache...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path) self.create_dir_and_download_model(model_name, model_item, output_path)
@ -407,12 +432,14 @@ class ModelManager(object):
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
): # TODO:This is stupid but don't care for now. ): # TODO:This is stupid but don't care for now.
output_model_path, output_config_path = self._find_files(output_path) output_model_path, output_config_path = self._find_files(output_path)
else:
output_config_path = output_model_path / "config.json"
# update paths in the config.json # update paths in the config.json
self._update_paths(output_path, output_config_path) self._update_paths(output_path, output_config_path)
return output_model_path, output_config_path, model_item return output_model_path, output_config_path, model_item
@staticmethod @staticmethod
def _find_files(output_path: str) -> Tuple[str, str]: def _find_files(output_path: Path) -> tuple[Path, Path]:
"""Find the model and config files in the output path """Find the model and config files in the output path
Args: Args:
@ -423,11 +450,11 @@ class ModelManager(object):
""" """
model_file = None model_file = None
config_file = None config_file = None
for file_name in os.listdir(output_path): for f in output_path.iterdir():
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]: if f.name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]:
model_file = os.path.join(output_path, file_name) model_file = f
elif file_name == "config.json": elif f.name == "config.json":
config_file = os.path.join(output_path, file_name) config_file = f
if model_file is None: if model_file is None:
raise ValueError(" [!] Model file not found in the output path") raise ValueError(" [!] Model file not found in the output path")
if config_file is None: if config_file is None:
@ -435,7 +462,7 @@ class ModelManager(object):
return model_file, config_file return model_file, config_file
@staticmethod @staticmethod
def _find_speaker_encoder(output_path: Union[str, os.PathLike[Any]]) -> Optional[Path]:
    """Find the speaker encoder checkpoint in the output path.

    Only the direct entries of ``output_path`` are inspected. An entry counts
    as the speaker encoder when its name is ``model_se.pth`` or
    ``model_se.pth.tar``. If both are present, ``model_se.pth`` is returned so
    that the result does not depend on the directory listing order.

    Args:
        output_path (Union[str, os.PathLike]): directory holding the model files.

    Returns:
        Optional[Path]: path to the speaker encoder file, or ``None`` when the
            directory contains neither candidate.

    Raises:
        OSError: propagated from ``Path.iterdir`` when ``output_path`` cannot
            be listed (for example, it does not exist).
    """
    output_path = Path(output_path)
    # List the directory once; a missing directory still raises as before.
    entry_names = {entry.name for entry in output_path.iterdir()}
    # Fixed preference order keeps the choice deterministic across platforms.
    for candidate_name in ("model_se.pth", "model_se.pth.tar"):
        if candidate_name in entry_names:
            return output_path / candidate_name
    return None
def _update_paths(self, output_path: str, config_path: str) -> None: def _update_paths(self, output_path: Path, config_path: Path) -> None:
"""Update paths for certain files in config.json after download. """Update paths for certain files in config.json after download.
Args: Args:
output_path (str): local path the model is downloaded to. output_path (str): local path the model is downloaded to.
config_path (str): local config.json path. config_path (str): local config.json path.
""" """
output_stats_path = os.path.join(output_path, "scale_stats.npy") output_stats_path = output_path / "scale_stats.npy"
output_d_vector_file_path = os.path.join(output_path, "speakers.json") output_d_vector_file_path = output_path / "speakers.json"
output_d_vector_file_pth_path = os.path.join(output_path, "speakers.pth") output_d_vector_file_pth_path = output_path / "speakers.pth"
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") output_speaker_ids_file_path = output_path / "speaker_ids.json"
output_speaker_ids_file_pth_path = os.path.join(output_path, "speaker_ids.pth") output_speaker_ids_file_pth_path = output_path / "speaker_ids.pth"
speaker_encoder_config_path = os.path.join(output_path, "config_se.json") speaker_encoder_config_path = output_path / "config_se.json"
speaker_encoder_model_path = self._find_speaker_encoder(output_path) speaker_encoder_model_path = self._find_speaker_encoder(output_path)
# update the scale_path.npy file path in the model config.json # update the scale_path.npy file path in the model config.json
@ -487,10 +514,10 @@ class ModelManager(object):
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path) self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
@staticmethod @staticmethod
def _update_path(field_name, new_path, config_path): def _update_path(field_name: str, new_path: Optional[Path], config_path: Path) -> None:
"""Update the path in the model config.json for the current environment after download""" """Update the path in the model config.json for the current environment after download"""
if new_path and os.path.exists(new_path): if new_path is not None and new_path.is_file():
config = load_config(config_path) config = load_config(str(config_path))
field_names = field_name.split(".") field_names = field_name.split(".")
if len(field_names) > 1: if len(field_names) > 1:
# field name points to a sub-level field # field name points to a sub-level field
@ -515,7 +542,7 @@ class ModelManager(object):
config.save_json(config_path) config.save_json(config_path)
@staticmethod @staticmethod
def _download_zip_file(file_url, output_folder, progress_bar): def _download_zip_file(file_url: str, output_folder: Path, progress_bar: bool) -> None:
"""Download the github releases""" """Download the github releases"""
# download the file # download the file
r = requests.get(file_url, stream=True) r = requests.get(file_url, stream=True)
@ -525,7 +552,7 @@ class ModelManager(object):
block_size = 1024 # 1 Kibibyte block_size = 1024 # 1 Kibibyte
if progress_bar: if progress_bar:
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1]) temp_zip_name = output_folder / file_url.split("/")[-1]
with open(temp_zip_name, "wb") as file: with open(temp_zip_name, "wb") as file:
for data in r.iter_content(block_size): for data in r.iter_content(block_size):
if progress_bar: if progress_bar:
@ -533,24 +560,24 @@ class ModelManager(object):
file.write(data) file.write(data)
with zipfile.ZipFile(temp_zip_name) as z: with zipfile.ZipFile(temp_zip_name) as z:
z.extractall(output_folder) z.extractall(output_folder)
os.remove(temp_zip_name) # delete zip after extract temp_zip_name.unlink() # delete zip after extract
except zipfile.BadZipFile: except zipfile.BadZipFile:
logger.exception("Bad zip file - %s", file_url) logger.exception("Bad zip file - %s", file_url)
raise zipfile.BadZipFile # pylint: disable=raise-missing-from raise zipfile.BadZipFile # pylint: disable=raise-missing-from
# move the files to the outer path # move the files to the outer path
for file_path in z.namelist(): for file_path in z.namelist():
src_path = os.path.join(output_folder, file_path) src_path = output_folder / file_path
if os.path.isfile(src_path): if src_path.is_file():
dst_path = os.path.join(output_folder, os.path.basename(file_path)) dst_path = output_folder / os.path.basename(file_path)
if src_path != dst_path: if src_path != dst_path:
copyfile(src_path, dst_path) copyfile(src_path, dst_path)
# remove redundant (hidden or not) folders # remove redundant (hidden or not) folders
for file_path in z.namelist(): for file_path in z.namelist():
if os.path.isdir(os.path.join(output_folder, file_path)): if (output_folder / file_path).is_dir():
rmtree(os.path.join(output_folder, file_path)) rmtree(output_folder / file_path)
@staticmethod @staticmethod
def _download_tar_file(file_url, output_folder, progress_bar): def _download_tar_file(file_url: str, output_folder: Path, progress_bar: bool) -> None:
"""Download the github releases""" """Download the github releases"""
# download the file # download the file
r = requests.get(file_url, stream=True) r = requests.get(file_url, stream=True)
@ -560,7 +587,7 @@ class ModelManager(object):
block_size = 1024 # 1 Kibibyte block_size = 1024 # 1 Kibibyte
if progress_bar: if progress_bar:
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1]) temp_tar_name = output_folder / file_url.split("/")[-1]
with open(temp_tar_name, "wb") as file: with open(temp_tar_name, "wb") as file:
for data in r.iter_content(block_size): for data in r.iter_content(block_size):
if progress_bar: if progress_bar:
@ -569,43 +596,37 @@ class ModelManager(object):
with tarfile.open(temp_tar_name) as t: with tarfile.open(temp_tar_name) as t:
t.extractall(output_folder) t.extractall(output_folder)
tar_names = t.getnames() tar_names = t.getnames()
os.remove(temp_tar_name) # delete tar after extract temp_tar_name.unlink() # delete tar after extract
except tarfile.ReadError: except tarfile.ReadError:
logger.exception("Bad tar file - %s", file_url) logger.exception("Bad tar file - %s", file_url)
raise tarfile.ReadError # pylint: disable=raise-missing-from raise tarfile.ReadError # pylint: disable=raise-missing-from
# move the files to the outer path # move the files to the outer path
for file_path in os.listdir(os.path.join(output_folder, tar_names[0])): for file_path in (output_folder / tar_names[0]).iterdir():
src_path = os.path.join(output_folder, tar_names[0], file_path) src_path = file_path
dst_path = os.path.join(output_folder, os.path.basename(file_path)) dst_path = output_folder / file_path.name
if src_path != dst_path: if src_path != dst_path:
copyfile(src_path, dst_path) copyfile(src_path, dst_path)
# remove the extracted folder # remove the extracted folder
rmtree(os.path.join(output_folder, tar_names[0])) rmtree(output_folder / tar_names[0])
@staticmethod @staticmethod
def _download_model_files(
    file_urls: list[str], output_folder: Union[str, os.PathLike[Any]], progress_bar: bool
) -> None:
    """Download every URL in ``file_urls`` into ``output_folder``.

    Each file is streamed to ``output_folder`` under the last path segment of
    its URL. Nothing is extracted; the downloaded bytes are written as they
    are received.

    Args:
        file_urls (list[str]): URLs of the files to download.
        output_folder (Union[str, os.PathLike]): existing directory to write into.
        progress_bar (bool): if True, show a tqdm progress bar per file. The bar
            is also stored on ``ModelManager.tqdm_progress``.

    Raises:
        requests.HTTPError: if the server answers with an error status. Without
            this check an HTML/JSON error body would be saved as a model file.
    """
    output_folder = Path(output_folder)
    block_size = 1024  # 1 Kibibyte
    for file_url in file_urls:
        file_path = output_folder / file_url.split("/")[-1]
        # The context manager releases the streamed connection even on failure.
        with requests.get(file_url, stream=True) as r:
            # Fail before touching the disk so no bogus file is left behind.
            r.raise_for_status()
            total_size_in_bytes = int(r.headers.get("content-length", 0))
            if progress_bar:
                ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
            try:
                with open(file_path, "wb") as f:
                    for data in r.iter_content(block_size):
                        if progress_bar:
                            ModelManager.tqdm_progress.update(len(data))
                        f.write(data)
            finally:
                # Close the bar so consecutive downloads do not leave stale bars open.
                if progress_bar:
                    ModelManager.tqdm_progress.close()
@staticmethod
def _check_dict_key(my_dict, key):
    """Return True when ``key`` is present in ``my_dict`` with a usable value.

    A value is usable when it is not ``None`` and, if it has a length, that
    length is greater than zero. Values without a length (numbers, booleans)
    count as usable; previously they raised ``TypeError`` from ``len()``.

    Args:
        my_dict: mapping to look the key up in.
        key: key to check.

    Returns:
        bool: True if the key maps to a usable value, False otherwise.
    """
    if key not in my_dict or my_dict[key] is None:
        return False
    if not isinstance(key, str):
        # Kept from the original: non-string keys only need a non-None value.
        return True
    value = my_dict[key]
    try:
        return len(value) > 0
    except TypeError:
        # Unsized value (e.g. int): present and not None, so it counts.
        return True

View File

@ -2,7 +2,7 @@ import logging
import os import os
import time import time
from pathlib import Path from pathlib import Path
from typing import List from typing import Any, List, Optional, Union
import numpy as np import numpy as np
import pysbd import pysbd
@ -16,6 +16,7 @@ from TTS.tts.models.vits import Vits
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import save_wav from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.generic_utils import optional_to_str
from TTS.vc.configs.openvoice_config import OpenVoiceConfig from TTS.vc.configs.openvoice_config import OpenVoiceConfig
from TTS.vc.models import setup_model as setup_vc_model from TTS.vc.models import setup_model as setup_vc_model
from TTS.vc.models.openvoice import OpenVoice from TTS.vc.models.openvoice import OpenVoice
@ -29,18 +30,18 @@ class Synthesizer(nn.Module):
def __init__( def __init__(
self, self,
*, *,
tts_checkpoint: str = "", tts_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None,
tts_config_path: str = "", tts_config_path: Optional[Union[str, os.PathLike[Any]]] = None,
tts_speakers_file: str = "", tts_speakers_file: Optional[Union[str, os.PathLike[Any]]] = None,
tts_languages_file: str = "", tts_languages_file: Optional[Union[str, os.PathLike[Any]]] = None,
vocoder_checkpoint: str = "", vocoder_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None,
vocoder_config: str = "", vocoder_config: Optional[Union[str, os.PathLike[Any]]] = None,
encoder_checkpoint: str = "", encoder_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None,
encoder_config: str = "", encoder_config: Optional[Union[str, os.PathLike[Any]]] = None,
vc_checkpoint: str = "", vc_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None,
vc_config: str = "", vc_config: Optional[Union[str, os.PathLike[Any]]] = None,
model_dir: str = "", model_dir: Optional[Union[str, os.PathLike[Any]]] = None,
voice_dir: str = None, voice_dir: Optional[Union[str, os.PathLike[Any]]] = None,
use_cuda: bool = False, use_cuda: bool = False,
) -> None: ) -> None:
"""General 🐸 TTS interface for inference. It takes a tts and a vocoder """General 🐸 TTS interface for inference. It takes a tts and a vocoder
@ -66,16 +67,17 @@ class Synthesizer(nn.Module):
use_cuda (bool, optional): enable/disable cuda. Defaults to False. use_cuda (bool, optional): enable/disable cuda. Defaults to False.
""" """
super().__init__() super().__init__()
self.tts_checkpoint = tts_checkpoint self.tts_checkpoint = optional_to_str(tts_checkpoint)
self.tts_config_path = tts_config_path self.tts_config_path = optional_to_str(tts_config_path)
self.tts_speakers_file = tts_speakers_file self.tts_speakers_file = optional_to_str(tts_speakers_file)
self.tts_languages_file = tts_languages_file self.tts_languages_file = optional_to_str(tts_languages_file)
self.vocoder_checkpoint = vocoder_checkpoint self.vocoder_checkpoint = optional_to_str(vocoder_checkpoint)
self.vocoder_config = vocoder_config self.vocoder_config = optional_to_str(vocoder_config)
self.encoder_checkpoint = encoder_checkpoint self.encoder_checkpoint = optional_to_str(encoder_checkpoint)
self.encoder_config = encoder_config self.encoder_config = optional_to_str(encoder_config)
self.vc_checkpoint = vc_checkpoint self.vc_checkpoint = optional_to_str(vc_checkpoint)
self.vc_config = vc_config self.vc_config = optional_to_str(vc_config)
model_dir = optional_to_str(model_dir)
self.use_cuda = use_cuda self.use_cuda = use_cuda
self.tts_model = None self.tts_model = None
@ -89,18 +91,18 @@ class Synthesizer(nn.Module):
self.d_vector_dim = 0 self.d_vector_dim = 0
self.seg = self._get_segmenter("en") self.seg = self._get_segmenter("en")
self.use_cuda = use_cuda self.use_cuda = use_cuda
self.voice_dir = voice_dir self.voice_dir = optional_to_str(voice_dir)
if self.use_cuda: if self.use_cuda:
assert torch.cuda.is_available(), "CUDA is not availabe on this machine." assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
if tts_checkpoint: if tts_checkpoint:
self._load_tts(tts_checkpoint, tts_config_path, use_cuda) self._load_tts(self.tts_checkpoint, self.tts_config_path, use_cuda)
if vocoder_checkpoint: if vocoder_checkpoint:
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) self._load_vocoder(self.vocoder_checkpoint, self.vocoder_config, use_cuda)
if vc_checkpoint and model_dir is None: if vc_checkpoint and model_dir == "":
self._load_vc(vc_checkpoint, vc_config, use_cuda) self._load_vc(self.vc_checkpoint, self.vc_config, use_cuda)
if model_dir: if model_dir:
if "fairseq" in model_dir: if "fairseq" in model_dir:

View File

@ -1,5 +1,4 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import glob
import os import os
import shutil import shutil
@ -30,22 +29,22 @@ def run_models(offset=0, step=1):
print(f"\n > Run - {model_name}") print(f"\n > Run - {model_name}")
model_path, _, _ = manager.download_model(model_name) model_path, _, _ = manager.download_model(model_name)
if "tts_models" in model_name: if "tts_models" in model_name:
local_download_dir = os.path.dirname(model_path) local_download_dir = model_path.parent
# download and run the model # download and run the model
speaker_files = glob.glob(local_download_dir + "/speaker*") speaker_files = list(local_download_dir.glob("speaker*"))
language_files = glob.glob(local_download_dir + "/language*") language_files = list(local_download_dir.glob("language*"))
speaker_arg = "" speaker_arg = ""
language_arg = "" language_arg = ""
if len(speaker_files) > 0: if len(speaker_files) > 0:
# multi-speaker model # multi-speaker model
if "speaker_ids" in speaker_files[0]: if "speaker_ids" in speaker_files[0].stem:
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0]) speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
elif "speakers" in speaker_files[0]: elif "speakers" in speaker_files[0].stem:
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0]) speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
speakers = list(speaker_manager.name_to_id.keys()) speakers = list(speaker_manager.name_to_id.keys())
if len(speakers) > 1: if len(speakers) > 1:
speaker_arg = f'--speaker_idx "{speakers[0]}"' speaker_arg = f'--speaker_idx "{speakers[0]}"'
if len(language_files) > 0 and "language_ids" in language_files[0]: if len(language_files) > 0 and "language_ids" in language_files[0].stem:
# multi-lingual model # multi-lingual model
language_manager = LanguageManager(language_ids_file_path=language_files[0]) language_manager = LanguageManager(language_ids_file_path=language_files[0])
languages = language_manager.language_names languages = language_manager.language_names