coqui-tts/TTS/utils/generic_utils.py

151 lines
4.3 KiB
Python

# -*- coding: utf-8 -*-
import datetime
import importlib
import logging
import os
import re
from pathlib import Path
from typing import Any, Callable, Dict, Optional, TextIO, TypeVar, Union
import torch
from packaging.version import Version
from typing_extensions import TypeIs
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Generic type variable shared by the ``exists`` and ``default`` helpers.
_T = TypeVar("_T")
def exists(val: Union[_T, None]) -> TypeIs[_T]:
    """Return True when ``val`` is anything other than None.

    The ``TypeIs`` return annotation lets static type checkers narrow ``val``
    to ``_T`` in the branch where this returns True.
    """
    if val is None:
        return False
    return True
def default(val: Union[_T, None], d: Union[_T, Callable[[], _T]]) -> _T:
    """Return ``val`` unless it is None, otherwise fall back to ``d``.

    When falling back, ``d`` is called with no arguments if it is callable
    (allowing a lazily-computed default); otherwise ``d`` itself is returned.
    """
    if not exists(val):
        if callable(d):
            return d()
        return d
    return val
def to_camel(text):
    """Convert a snake_case module name into its CamelCase class name.

    The string is capitalized, each ``_<letter>`` after the first character
    becomes the upper-cased letter, and then the substrings ``"Tts"`` and
    ``"vc"`` are rewritten to ``"TTS"`` and ``"VC"`` (in that order).
    Example: ``"glow_tts"`` -> ``"GlowTTS"``, ``"freevc"`` -> ``"FreeVC"``.
    """
    camel = re.sub(r"(?!^)_([a-zA-Z])", lambda match: match.group(1).upper(), text.capitalize())
    for old, new in (("Tts", "TTS"), ("vc", "VC")):
        camel = camel.replace(old, new)
    return camel
def find_module(module_path: str, module_name: str) -> object:
    """Import ``<module_path>.<module_name>`` and return its main class.

    ``module_name`` is lower-cased before importing, and the attribute looked
    up in the imported module is ``to_camel`` of that lower-cased name.
    """
    lowered = module_name.lower()
    target = importlib.import_module(".".join((module_path, lowered)))
    return getattr(target, to_camel(lowered))
def import_class(module_path: str) -> object:
    """Import a class (or any attribute) given its dotted path.

    Everything before the last ``.`` is imported as a module, and the
    component after it is fetched from that module.

    Args:
        module_path (str): Dotted path such as ``"package.module.ClassName"``.

    Returns:
        object: The imported class.
    """
    parent_module, _, class_name = module_path.rpartition(".")
    return getattr(importlib.import_module(parent_module), class_name)
def get_import_path(obj: object) -> str:
    """Get the import path of an object's class.

    The path is built from ``type(obj)``, so ``obj`` should be an *instance*:
    passing a class object returns the path of its metaclass (e.g.
    ``"builtins.type"``), not of the class itself.

    Args:
        obj (object): An instance whose class path is wanted.

    Returns:
        str: ``"<module>.<ClassName>"`` for ``type(obj)``.
    """
    cls = type(obj)
    # Uses ``__name__`` (not ``__qualname__``), matching ``import_class``,
    # which only resolves attributes directly on a module.
    return ".".join([cls.__module__, cls.__name__])
def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
    """Fill in default values for auxiliary model inputs.

    A shallow copy of ``kwargs`` is returned (the input is not mutated). Every
    key of ``def_args`` that is missing from ``kwargs`` or maps to None is set
    to the value from ``def_args``; keys only in ``kwargs`` are kept as-is.

    Args:
        def_args (Dict): Argument names mapped to their default values.
        kwargs (Dict): The auxiliary inputs passed to the model.

    Returns:
        Dict: The auxiliary inputs with defaults applied.
    """
    merged = kwargs.copy()
    for key, fallback in def_args.items():
        if merged.get(key) is None:
            merged[key] = fallback
    return merged
def get_timestamp() -> str:
    """Return the current local time formatted as ``YYMMDD-HHMMSS``."""
    now = datetime.datetime.now()
    return now.strftime("%y%m%d-%H%M%S")
class ConsoleFormatter(logging.Formatter):
    """Formatter that omits the level name for ``logging.INFO`` records.

    INFO records are rendered as just the message; records of any other
    level are prefixed with ``"<LEVELNAME>: "``.

    Source: https://stackoverflow.com/a/62488520
    """

    def format(self, record):
        # The underlying style's format string is replaced on every call, so any
        # ``fmt`` passed to the constructor never reaches the output.
        # NOTE(review): this mutates formatter state per record; confirm one
        # instance is not shared by handlers emitting concurrently.
        is_info = record.levelno == logging.INFO
        self._style._fmt = "%(message)s" if is_info else "%(levelname)s: %(message)s"
        return super().format(record)
def setup_logger(
    logger_name: str,
    level: int = logging.INFO,
    *,
    formatter: Optional[logging.Formatter] = None,
    stream: Optional[TextIO] = None,
    log_dir: Optional[Union[str, os.PathLike[Any]]] = None,
    log_name: str = "log",
) -> None:
    """Set up a logger.

    Sets the level of the named logger and attaches the requested handlers.
    Handlers are added on every call, so calling this repeatedly for the same
    logger adds additional (possibly duplicate) handlers.

    Args:
        logger_name: Name of the logger to set up.
        level: Logging level.
        formatter: Formatter for every handler added here. Defaults to
            ``"<asctime>.<msecs> - <LEVEL> - <name>: <message>"``.
        stream: Add a StreamHandler for the given stream, e.g. sys.stderr or sys.stdout.
        log_dir: Folder to write the log file (no file created if None); it is
            created, including parents, if missing.
        log_name: Prefix of the log file name; the file is
            ``<log_dir>/<log_name>_<timestamp>.log`` and is overwritten if present.
    """
    lg = logging.getLogger(logger_name)
    if formatter is None:
        formatter = logging.Formatter(
            "%(asctime)s.%(msecs)03d - %(levelname)-8s - %(name)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S"
        )
    lg.setLevel(level)

    if log_dir is not None:
        log_path = Path(log_dir)
        log_path.mkdir(exist_ok=True, parents=True)
        log_file = log_path / f"{log_name}_{get_timestamp()}.log"
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
        # cannot represent arbitrary non-ASCII text such as multilingual input.
        fh = logging.FileHandler(log_file, mode="w", encoding="utf-8")
        fh.setFormatter(formatter)
        lg.addHandler(fh)

    if stream is not None:
        sh = logging.StreamHandler(stream)
        sh.setFormatter(formatter)
        lg.addHandler(sh)
def is_pytorch_at_least_2_4() -> bool:
    """Return True when the installed PyTorch version is 2.4 or newer.

    Versions are compared with ``packaging.version.Version`` semantics, under
    which pre-releases such as ``2.4.0a0`` sort below ``2.4``.
    """
    installed = Version(torch.__version__)
    minimum = Version("2.4")
    return installed >= minimum
def optional_to_str(x: Optional[Any]) -> str:
    """Return ``str(x)``, or the empty string when ``x`` is None."""
    if x is None:
        return ""
    return str(x)