import contextlib
import functools
import hashlib
import logging
import os

import requests
import torch
import tqdm

from TTS.tts.layers.bark.model import GPT, GPTConfig
from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig

# Use bfloat16 autocast on CUDA when supported; otherwise fall back to a no-op context manager.
if (
    torch.cuda.is_available()
    and hasattr(torch.cuda, "amp")
    and hasattr(torch.cuda.amp, "autocast")
    and torch.cuda.is_bf16_supported()
):
    autocast = functools.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
else:

    @contextlib.contextmanager
    def autocast():
        yield


# hold models in global scope to lazy load

logger = logging.getLogger(__name__)


if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
    logger.warning(
        "torch version does not support flash attention. You will get significantly faster"
        + " inference speed by upgrading torch to the newest version / nightly."
    )


# def _string_md5(s):
#     m = hashlib.md5()
#     m.update(s.encode("utf-8"))
#     return m.hexdigest()


def _md5(fname):
    """Compute the MD5 checksum of a file, reading it in 4 KiB chunks."""
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


# def _get_ckpt_path(model_type, CACHE_DIR):
#     model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
#     return os.path.join(CACHE_DIR, f"{model_name}.pt")


# S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/"


# def _parse_s3_filepath(s3_filepath):
#     bucket_name = re.search(S3_BUCKET_PATH_RE, s3_filepath).group(1)
#     rel_s3_filepath = re.sub(S3_BUCKET_PATH_RE, "", s3_filepath)
#     return bucket_name, rel_s3_filepath


def _download(from_s3_path, to_local_path, CACHE_DIR):
    """Stream a remote checkpoint to disk with a progress bar and verify the downloaded size."""
    os.makedirs(CACHE_DIR, exist_ok=True)
    response = requests.get(from_s3_path, stream=True)
    total_size_in_bytes = int(response.headers.get("content-length", 0))
    block_size = 1024  # 1 Kibibyte
    progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
    with open(to_local_path, "wb") as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    if total_size_in_bytes not in [0, progress_bar.n]:
        raise ValueError(
            f"Download incomplete: expected {total_size_in_bytes} bytes but received {progress_bar.n}"
        )


class InferenceContext:
    """Context manager that temporarily sets the cuDNN benchmark flag for inference."""

    def __init__(self, benchmark=False):
        # we can't expect inputs to be the same length, so disable benchmarking by default
        self._chosen_cudnn_benchmark = benchmark
        self._cudnn_benchmark = None

    def __enter__(self):
        self._cudnn_benchmark = torch.backends.cudnn.benchmark
        torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark

    def __exit__(self, exc_type, exc_value, exc_traceback):
        torch.backends.cudnn.benchmark = self._cudnn_benchmark


if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


@contextlib.contextmanager
def inference_mode():
    with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast():
        yield


def clear_cuda_cache():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


# def clean_models(model_key=None):
#     global models
#     model_keys = [model_key] if model_key is not None else models.keys()
#     for k in model_keys:
#         if k in models:
#             del models[k]
#     clear_cuda_cache()


def load_model(ckpt_path, device, config, model_type="text"):
    logger.info(f"loading {model_type} model from {ckpt_path}...")

    if device == "cpu":
        logger.warning("No GPU being used. Careful, inference might be extremely slow!")
    if model_type == "text":
        ConfigClass = GPTConfig
        ModelClass = GPT
    elif model_type == "coarse":
        ConfigClass = GPTConfig
        ModelClass = GPT
    elif model_type == "fine":
        ConfigClass = FineGPTConfig
        ModelClass = FineGPT
    else:
        raise NotImplementedError(f"Unsupported model_type: {model_type}")
    # Remove a cached checkpoint whose checksum no longer matches the expected remote model.
    if (
        not config.USE_SMALLER_MODELS
        and os.path.exists(ckpt_path)
        and _md5(ckpt_path) != config.REMOTE_MODEL_PATHS[model_type]["checksum"]
    ):
        logger.warning(f"found outdated {model_type} model, removing...")
        os.remove(ckpt_path)
    if not os.path.exists(ckpt_path):
        logger.info(f"{model_type} model not found, downloading...")
        _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)

    checkpoint = torch.load(ckpt_path, map_location=device)
    # this is a hack: older checkpoints store a single vocab_size instead of
    # separate input/output vocabulary sizes, so map the legacy field over
    model_args = checkpoint["model_args"]
    if "input_vocab_size" not in model_args:
        model_args["input_vocab_size"] = model_args["vocab_size"]
        model_args["output_vocab_size"] = model_args["vocab_size"]
        del model_args["vocab_size"]

    gptconf = ConfigClass(**checkpoint["model_args"])
    if model_type == "text":
        config.semantic_config = gptconf
    elif model_type == "coarse":
        config.coarse_config = gptconf
    elif model_type == "fine":
        config.fine_config = gptconf

    model = ModelClass(gptconf)
    state_dict = checkpoint["model"]
    # fixup checkpoint: strip the prefix that torch.compile adds to parameter names
    unwanted_prefix = "_orig_mod."
    for k, _ in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)

    extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
    extra_keys = set(k for k in extra_keys if not k.endswith(".attn.bias"))
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    missing_keys = set(k for k in missing_keys if not k.endswith(".attn.bias"))
    if len(extra_keys) != 0:
        raise ValueError(f"extra keys found: {extra_keys}")
    if len(missing_keys) != 0:
        raise ValueError(f"missing keys: {missing_keys}")
    model.load_state_dict(state_dict, strict=False)
    n_params = model.get_num_params()
    val_loss = checkpoint["best_val_loss"].item()
    logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
    model.eval()
    model.to(device)
    del checkpoint, state_dict
    clear_cuda_cache()
    return model, config


# def _load_codec_model(device):
#     model = EncodecModel.encodec_model_24khz()
#     model.set_target_bandwidth(6.0)
#     model.eval()
#     model.to(device)
#     clear_cuda_cache()
#     return model


# def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"):
#     _load_model_f = functools.partial(_load_model, model_type=model_type)
#     if model_type not in ("text", "coarse", "fine"):
#         raise NotImplementedError()
#     global models
#     if torch.cuda.device_count() == 0 or not use_gpu:
#         device = "cpu"
#     else:
#         device = "cuda"
#     model_key = str(device) + f"__{model_type}"
#     if model_key not in models or force_reload:
#         if ckpt_path is None:
#             ckpt_path = _get_ckpt_path(model_type)
#         clean_models(model_key=model_key)
#         model = _load_model_f(ckpt_path, device)
#         models[model_key] = model
#     return models[model_key]


# def load_codec_model(use_gpu=True, force_reload=False):
#     global models
#     if torch.cuda.device_count() == 0 or not use_gpu:
#         device = "cpu"
#     else:
#         device = "cuda"
#     model_key = str(device) + f"__codec"
#     if model_key not in models or force_reload:
#         clean_models(model_key=model_key)
#         model = _load_codec_model(device)
#         models[model_key] = model
#     return models[model_key]


# def preload_models(
#     text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True, use_smaller_models=False
# ):
#     global USE_SMALLER_MODELS
#     global REMOTE_MODEL_PATHS
#     if use_smaller_models:
#         USE_SMALLER_MODELS = True
#         logger.info("Using smaller models generation.py")
#         REMOTE_MODEL_PATHS = SMALL_REMOTE_MODEL_PATHS
#     _ = load_model(ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True)
#     _ = load_model(ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True)
#     _ = load_model(ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True)
#     _ = load_codec_model(use_gpu=use_gpu, force_reload=True)
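
# Usage sketch (illustrative only, not part of this module): `load_model` expects a config
# object exposing USE_SMALLER_MODELS, REMOTE_MODEL_PATHS, and CACHE_DIR. The BarkConfig import
# and the checkpoint file name below are assumptions made for the example.
#
#     from TTS.tts.configs.bark_config import BarkConfig
#
#     config = BarkConfig()
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     ckpt_path = os.path.join(config.CACHE_DIR, "text_2.pt")  # hypothetical file name
#     text_model, config = load_model(ckpt_path, device, config, model_type="text")
#     with inference_mode():
#         ...  # run generation with the loaded GPT model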