From deebc0cc16888cdd643831abcba0026f10592337 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Tue, 23 May 2023 10:12:26 +0200 Subject: [PATCH 01/29] Add bark requirements --- requirements.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 57640b6f..2b725bc6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,4 +49,7 @@ bnunicodenormalizer==0.1.1 #deps for tortoise k_diffusion einops -transformers \ No newline at end of file +transformers + +#deps for bark +encodec \ No newline at end of file From f59da4dba5e5b26707862b9ab03e0e6a80408c9e Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 12 Jun 2023 14:32:39 +0200 Subject: [PATCH 02/29] Draft Bark implementation --- TTS/tts/configs/bark_config.py | 67 +++ TTS/tts/layers/bark/__init__.py | 0 TTS/tts/layers/bark/hubert/__init__.py | 0 TTS/tts/layers/bark/hubert/hubert_manager.py | 33 ++ TTS/tts/layers/bark/hubert/kmeans_hubert.py | 101 ++++ TTS/tts/layers/bark/hubert/tokenizer.py | 196 +++++++ TTS/tts/layers/bark/inference_funcs.py | 575 +++++++++++++++++++ TTS/tts/layers/bark/load_model.py | 254 ++++++++ TTS/tts/layers/bark/model.py | 232 ++++++++ TTS/tts/layers/bark/model_fine.py | 142 +++++ 10 files changed, 1600 insertions(+) create mode 100644 TTS/tts/configs/bark_config.py create mode 100644 TTS/tts/layers/bark/__init__.py create mode 100644 TTS/tts/layers/bark/hubert/__init__.py create mode 100644 TTS/tts/layers/bark/hubert/hubert_manager.py create mode 100644 TTS/tts/layers/bark/hubert/kmeans_hubert.py create mode 100644 TTS/tts/layers/bark/hubert/tokenizer.py create mode 100644 TTS/tts/layers/bark/inference_funcs.py create mode 100644 TTS/tts/layers/bark/load_model.py create mode 100644 TTS/tts/layers/bark/model.py create mode 100644 TTS/tts/layers/bark/model_fine.py diff --git a/TTS/tts/configs/bark_config.py b/TTS/tts/configs/bark_config.py new file mode 100644 index 00000000..760776a8 --- /dev/null +++ b/TTS/tts/configs/bark_config.py @@ -0,0 +1,67 @@ +import os +from dataclasses import dataclass, field +from typing import Dict + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.layers.bark.model import GPTConfig +from TTS.tts.layers.bark.model_fine import FineGPTConfig +from TTS.utils.generic_utils import get_user_data_dir + + +@dataclass +class BarkConfig(BaseTTSConfig): + num_chars: int = 0 + semantic_config: GPTConfig = GPTConfig() + fine_config: FineGPTConfig = FineGPTConfig() + coarse_config: GPTConfig = GPTConfig() + CONTEXT_WINDOW_SIZE: int = 1024 + SEMANTIC_RATE_HZ: float = 49.9 + SEMANTIC_VOCAB_SIZE: int = 10_000 + CODEBOOK_SIZE: int = 1024 + N_COARSE_CODEBOOKS: int = 2 + N_FINE_CODEBOOKS: int = 8 + COARSE_RATE_HZ: int = 75 + SAMPLE_RATE: int = 24_000 + USE_SMALLER_MODELS: bool = False + + TEXT_ENCODING_OFFSET: int = 10_048 + SEMANTIC_PAD_TOKEN: int = 10_000 + TEXT_PAD_TOKEN: int = 129_595 + SEMANTIC_INFER_TOKEN: int = 129_599 + COARSE_SEMANTIC_PAD_TOKEN: int = 12_048 + COARSE_INFER_TOKEN: int = 12_050 + + REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/" + REMOTE_MODEL_PATHS: Dict = None + LOCAL_MODEL_PATHS: Dict = None + SMALL_REMOTE_MODEL_PATHS: Dict = None + CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0")) + + def __post_init__(self): + self.REMOTE_MODEL_PATHS = { + "text": { + "path": os.path.join(self.REMOTE_BASE_URL, "text_2.pt"), + "checksum": "54afa89d65e318d4f5f80e8e8799026a", + }, + "coarse": { + "path": os.path.join(self.REMOTE_BASE_URL, "coarse_2.pt"), + "checksum": 
"8a98094e5e3a255a5c9c0ab7efe8fd28", + }, + "fine": { + "path": os.path.join(self.REMOTE_BASE_URL, "fine_2.pt"), + "checksum": "59d184ed44e3650774a2f0503a48a97b", + }, + } + self.LOCAL_MODEL_PATHS = { + "text": os.path.join(self.CACHE_DIR, "text_2.pt"), + "coarse": os.path.join(self.CACHE_DIR, "coarse_2.pt"), + "fine": os.path.join(self.CACHE_DIR, "fine_2.pt"), + "hubert_tokenizer": os.path.join(self.CACHE_DIR, "tokenizer.pth"), + "hubert": os.path.join(self.CACHE_DIR, "hubert.pt"), + } + self.SMALL_REMOTE_MODEL_PATHS = { + "text": {"path": os.path.join(self.REMOTE_BASE_URL, "text.pt")}, + "coarse": {"path": os.path.join(self.REMOTE_BASE_URL, "coarse.pt")}, + "fine": {"path": os.path.join(self.REMOTE_BASE_URL, "fine.pt")}, + } + self.sample_rate = self.SAMPLE_RATE diff --git a/TTS/tts/layers/bark/__init__.py b/TTS/tts/layers/bark/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/tts/layers/bark/hubert/__init__.py b/TTS/tts/layers/bark/hubert/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/tts/layers/bark/hubert/hubert_manager.py b/TTS/tts/layers/bark/hubert/hubert_manager.py new file mode 100644 index 00000000..baa26438 --- /dev/null +++ b/TTS/tts/layers/bark/hubert/hubert_manager.py @@ -0,0 +1,33 @@ +# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer + +import os.path +import shutil +import urllib.request + +import huggingface_hub + + +class HubertManager: + @staticmethod + def make_sure_hubert_installed( + download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = "" + ): + if not os.path.isfile(model_path): + print("Downloading HuBERT base model") + urllib.request.urlretrieve(download_url, model_path) + print("Downloaded HuBERT") + return model_path + + @staticmethod + def make_sure_tokenizer_installed( + model: str = "quantifier_hubert_base_ls960_14.pth", + repo: str = "GitMylo/bark-voice-cloning", + model_path: str = "", + ): + model_dir = os.path.dirname(model_path) + if not os.path.isfile(model_path): + print("Downloading HuBERT custom tokenizer") + huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False) + shutil.move(os.path.join(model_dir, model), model_path) + print("Downloaded tokenizer") + return model_path diff --git a/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/TTS/tts/layers/bark/hubert/kmeans_hubert.py new file mode 100644 index 00000000..7c667755 --- /dev/null +++ b/TTS/tts/layers/bark/hubert/kmeans_hubert.py @@ -0,0 +1,101 @@ +""" +Modified HuBERT model without kmeans. 
+Original author: https://github.com/lucidrains/ +Modified by: https://www.github.com/gitmylo/ +License: MIT +""" + +# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py + +import logging +from pathlib import Path + +import fairseq +import torch +from einops import pack, unpack +from torch import nn +from torchaudio.functional import resample + +logging.root.setLevel(logging.ERROR) + + +def round_down_nearest_multiple(num, divisor): + return num // divisor * divisor + + +def curtail_to_multiple(t, mult, from_left=False): + data_len = t.shape[-1] + rounded_seq_len = round_down_nearest_multiple(data_len, mult) + seq_slice = slice(None, rounded_seq_len) if not from_left else slice(-rounded_seq_len, None) + return t[..., seq_slice] + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +class CustomHubert(nn.Module): + """ + checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert + or you can train your own + """ + + def __init__(self, checkpoint_path, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None): + super().__init__() + self.target_sample_hz = target_sample_hz + self.seq_len_multiple_of = seq_len_multiple_of + self.output_layer = output_layer + + if device is not None: + self.to(device) + + model_path = Path(checkpoint_path) + + assert model_path.exists(), f"path {checkpoint_path} does not exist" + + checkpoint = torch.load(checkpoint_path) + load_model_input = {checkpoint_path: checkpoint} + model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input) + + if device is not None: + model[0].to(device) + + self.model = model[0] + self.model.eval() + + @property + def groups(self): + return 1 + + @torch.no_grad() + def forward(self, wav_input, flatten=True, input_sample_hz=None): + device = wav_input.device + + if exists(input_sample_hz): + wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz) + + if exists(self.seq_len_multiple_of): + wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) + + embed = self.model( + wav_input, + features_only=True, + mask=False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code + output_layer=self.output_layer, + ) + + embed, packed_shape = pack([embed["x"]], "* d") + + # codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy()) + + codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) # .long() + + if flatten: + return codebook_indices + + (codebook_indices,) = unpack(codebook_indices, packed_shape, "*") + return codebook_indices diff --git a/TTS/tts/layers/bark/hubert/tokenizer.py b/TTS/tts/layers/bark/hubert/tokenizer.py new file mode 100644 index 00000000..474a08db --- /dev/null +++ b/TTS/tts/layers/bark/hubert/tokenizer.py @@ -0,0 +1,196 @@ +""" +Custom tokenizer model. 
+Author: https://www.github.com/gitmylo/ +License: MIT +""" + +import json +import os.path +from zipfile import ZipFile + +import numpy +import torch +from torch import nn, optim +from torch.serialization import MAP_LOCATION + + +class HubertTokenizer(nn.Module): + def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0): + super(HubertTokenizer, self).__init__() + next_size = input_size + if version == 0: + self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True) + next_size = hidden_size + if version == 1: + self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True) + self.intermediate = nn.Linear(hidden_size, 4096) + next_size = 4096 + + self.fc = nn.Linear(next_size, output_size) + self.softmax = nn.LogSoftmax(dim=1) + self.optimizer: optim.Optimizer = None + self.lossfunc = nn.CrossEntropyLoss() + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.version = version + + def forward(self, x): + x, _ = self.lstm(x) + if self.version == 1: + x = self.intermediate(x) + x = self.fc(x) + x = self.softmax(x) + return x + + @torch.no_grad() + def get_token(self, x): + """ + Used to get the token for the first + :param x: An array with shape (N, input_size) where N is a whole number greater or equal to 1, and input_size is the input size used when creating the model. + :return: An array with shape (N,) where N is the same as N from the input. Every number in the array is a whole number in range 0...output_size - 1 where output_size is the output size used when creating the model. + """ + return torch.argmax(self(x), dim=1) + + def prepare_training(self): + self.optimizer = optim.Adam(self.parameters(), 0.001) + + def train_step(self, x_train, y_train, log_loss=False): + # y_train = y_train[:-1] + # y_train = y_train[1:] + + optimizer = self.optimizer + lossfunc = self.lossfunc + # Zero the gradients + self.zero_grad() + + # Forward pass + y_pred = self(x_train) + + y_train_len = len(y_train) + y_pred_len = y_pred.shape[0] + + if y_train_len > y_pred_len: + diff = y_train_len - y_pred_len + y_train = y_train[diff:] + elif y_train_len < y_pred_len: + diff = y_pred_len - y_train_len + y_pred = y_pred[:-diff, :] + + y_train_hot = torch.zeros(len(y_train), self.output_size) + y_train_hot[range(len(y_train)), y_train] = 1 + y_train_hot = y_train_hot.to("cuda") + + # Calculate the loss + loss = lossfunc(y_pred, y_train_hot) + + # Print loss + if log_loss: + print("Loss", loss.item()) + + # Backward pass + loss.backward() + + # Update the weights + optimizer.step() + + def save(self, path): + info_path = ".".join(os.path.basename(path).split(".")[:-1]) + "/.info" + torch.save(self.state_dict(), path) + data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version) + with ZipFile(path, "a") as model_zip: + model_zip.writestr(info_path, data_from_model.save()) + model_zip.close() + + @staticmethod + def load_from_checkpoint(path, map_location: MAP_LOCATION = None): + old = True + with ZipFile(path) as model_zip: + filesMatch = [file for file in model_zip.namelist() if file.endswith("/.info")] + file = filesMatch[0] if filesMatch else None + if file: + old = False + data_from_model = Data.load(model_zip.read(file).decode("utf-8")) + model_zip.close() + if old: + model = HubertTokenizer() + else: + model = HubertTokenizer( + data_from_model.hidden_size, + data_from_model.input_size, + data_from_model.output_size, + data_from_model.version, + ) + model.load_state_dict(torch.load(path)) + 
if map_location: + model = model.to(map_location) + return model + + +class Data: + input_size: int + hidden_size: int + output_size: int + version: int + + def __init__(self, input_size=768, hidden_size=1024, output_size=10000, version=0): + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.version = version + + @staticmethod + def load(string): + data = json.loads(string) + return Data(data["input_size"], data["hidden_size"], data["output_size"], data["version"]) + + def save(self): + data = { + "input_size": self.input_size, + "hidden_size": self.hidden_size, + "output_size": self.output_size, + "version": self.version, + } + return json.dumps(data) + + +def auto_train(data_path, save_path="model.pth", load_model: str = None, save_epochs=1): + data_x, data_y = [], [] + + if load_model and os.path.isfile(load_model): + print("Loading model from", load_model) + model_training = HubertTokenizer.load_from_checkpoint(load_model, "cuda") + else: + print("Creating new model.") + model_training = HubertTokenizer(version=1).to("cuda") # Settings for the model to run without lstm + save_path = os.path.join(data_path, save_path) + base_save_path = ".".join(save_path.split(".")[:-1]) + + sem_string = "_semantic.npy" + feat_string = "_semantic_features.npy" + + ready = os.path.join(data_path, "ready") + for input_file in os.listdir(ready): + full_path = os.path.join(ready, input_file) + if input_file.endswith(sem_string): + data_y.append(numpy.load(full_path)) + elif input_file.endswith(feat_string): + data_x.append(numpy.load(full_path)) + model_training.prepare_training() + + epoch = 1 + + while 1: + for i in range(save_epochs): + j = 0 + for x, y in zip(data_x, data_y): + model_training.train_step( + torch.tensor(x).to("cuda"), torch.tensor(y).to("cuda"), j % 50 == 0 + ) # Print loss every 50 steps + j += 1 + save_p = save_path + save_p_2 = f"{base_save_path}_epoch_{epoch}.pth" + model_training.save(save_p) + model_training.save(save_p_2) + print(f"Epoch {epoch} completed") + epoch += 1 diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py new file mode 100644 index 00000000..73c9ee71 --- /dev/null +++ b/TTS/tts/layers/bark/inference_funcs.py @@ -0,0 +1,575 @@ +import logging +import os +import re +from glob import glob +from typing import Dict, List + +import librosa +import numpy as np +import torch +import torchaudio +import tqdm +from encodec.utils import convert_audio +from scipy.special import softmax +from torch.nn import functional as F + +from TTS.tts.layers.bark.hubert.hubert_manager import HubertManager +from TTS.tts.layers.bark.hubert.kmeans_hubert import CustomHubert +from TTS.tts.layers.bark.hubert.tokenizer import HubertTokenizer +from TTS.tts.layers.bark.load_model import _clear_cuda_cache, _inference_mode + +logger = logging.getLogger(__name__) + + +def _tokenize(tokenizer, text): + return tokenizer.encode(text, add_special_tokens=False) + + +def _detokenize(tokenizer, enc_text): + return tokenizer.decode(enc_text) + + +def _normalize_whitespace(text): + return re.sub(r"\s+", " ", text).strip() + + +def get_voices(extra_voice_dirs: List[str] = []): + voices = {} + for dir in extra_voice_dirs: + paths = list(glob(f"{dir}/*.npz")) + for path in paths: + name = os.path.basename(path).replace(".npz", "") + voices[name] = path + return voices + + +def load_voice(voice: str, extra_voice_dirs: List[str] = []): + def load_npz(npz_file): + x_history = np.load(npz_file) + semantic = 
x_history["semantic_prompt"] + coarse = x_history["coarse_prompt"] + fine = x_history["fine_prompt"] + return semantic, coarse, fine + + if voice == "random": + return None, None + + voices = get_voices(extra_voice_dirs) + try: + path = voices[voice] + except KeyError: + raise KeyError(f"Voice {voice} not found in {extra_voice_dirs}") + prompt = load_npz(path) + return prompt + + +def zero_crossing_rate(audio, frame_length=1024, hop_length=512): + zero_crossings = np.sum(np.abs(np.diff(np.sign(audio))) / 2) + total_frames = 1 + int((len(audio) - frame_length) / hop_length) + return zero_crossings / total_frames + + +def compute_spectral_contrast(audio_data, sample_rate, n_bands=6, fmin=200.0): + spectral_contrast = librosa.feature.spectral_contrast(y=audio_data, sr=sample_rate, n_bands=n_bands, fmin=fmin) + return np.mean(spectral_contrast) + + +def compute_average_bass_energy(audio_data, sample_rate, max_bass_freq=250): + stft = librosa.stft(audio_data) + power_spectrogram = np.abs(stft) ** 2 + frequencies = librosa.fft_frequencies(sr=sample_rate, n_fft=stft.shape[0]) + bass_mask = frequencies <= max_bass_freq + bass_energy = power_spectrogram[np.ix_(bass_mask, np.arange(power_spectrogram.shape[1]))].mean() + return bass_energy + + +def generate_voice( + audio, + text, + model, + output_path, +): + """Generate a new voice from a given audio and text prompt. + + Args: + audio (np.ndarray): The audio to use as a base for the new voice. + text (str): Transcription of the audio you are clonning. + model (BarkModel): The BarkModel to use for generating the new voice. + output_path (str): The path to save the generated voice to. + """ + if isinstance(audio, str): + audio, sr = torchaudio.load(audio) + audio = convert_audio(audio, sr, model.config.sample_rate, model.encodec.channels) + audio = audio.unsqueeze(0).to(model.device) + + with torch.no_grad(): + encoded_frames = model.encodec.encode(audio) + codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T] + + # get seconds of audio + seconds = audio.shape[-1] / model.config.sample_rate + + # move codes to cpu + codes = codes.cpu().numpy() + + # generate semantic tokens + # Load the HuBERT model + hubert_manager = HubertManager() + hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"]) + hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"]) + + hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device) + + # Load the CustomTokenizer model + tokenizer = HubertTokenizer.load_from_checkpoint(model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"]).to( + model.device + ) # Automatically uses + # semantic_tokens = model.text_to_semantic( + # text, max_gen_duration_s=seconds, top_k=50, top_p=0.95, temp=0.7 + # ) # not 100% + semantic_vectors = hubert_model.forward(audio[0], input_sample_hz=model.config.sample_rate) + semantic_tokens = tokenizer.get_token(semantic_vectors) + semantic_tokens = semantic_tokens.cpu().numpy() + + np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens) + + # while attempts < max_attempts: + # if attempts > 0 and base is not None: + # # Reset the base model token + # print(f"Reset the base model token Regenerating...") + # base = None + + # audio_array, x = model.generate_audio(text, history_promp=None, base=base, **kwargs) + # zcr = zero_crossing_rate(audio_array) + # spectral_contrast = 
compute_spectral_contrast(audio_array, model.config.sample_rate) + # bass_energy = compute_average_bass_energy(audio_array, model.config.sample_rate) + # print(f"Attempt {attempts + 1}: ZCR = {zcr}, Spectral Contrast = {spectral_contrast:.2f}, Bass Energy = {bass_energy:.2f}") + + # # Save the audio array to the output_array directory with a random name for debugging + # #output_file = os.path.join(output_directory, f"audio_{zcr:.2f}_sc{spectral_contrast:.2f}_be{bass_energy:.2f}.wav") + # #wavfile.write(output_file, sample_rate, audio_array) + # #print(f"Saved audio array to {output_file}") + + # if zcr < zcr_threshold and spectral_contrast < spectral_threshold and bass_energy < bass_energy_threshold: + # print(f"Audio passed ZCR, Spectral Contrast, and Bass Energy thresholds. No need to regenerate.") + # break + # else: + # print(f"Audio failed ZCR, Spectral Contrast, and/or Bass Energy thresholds. Regenerating...") + + # attempts += 1 + + # if attempts == max_attempts: + # print("Reached maximum attempts. Returning the last generated audio.") + + # return audio_array, x, zcr, spectral_contrast, bass_energy + + +def generate_text_semantic( + text, + model, + history_prompt=None, + temp=0.7, + top_k=None, + top_p=None, + silent=False, + min_eos_p=0.2, + max_gen_duration_s=None, + allow_early_stop=True, + base=None, + use_kv_caching=True, +): + """Generate semantic tokens from text.""" + print(f"history_prompt in gen: {history_prompt}") + assert isinstance(text, str) + text = _normalize_whitespace(text) + assert len(text.strip()) > 0 + if history_prompt is not None or base is not None: + if history_prompt is not None: + semantic_history = history_prompt[0] + if base is not None: + semantic_history = base[0] + assert ( + isinstance(semantic_history, np.ndarray) + and len(semantic_history.shape) == 1 + and len(semantic_history) > 0 + and semantic_history.min() >= 0 + and semantic_history.max() <= model.config.SEMANTIC_VOCAB_SIZE - 1 + ) + else: + semantic_history = None + encoded_text = np.array(_tokenize(model.tokenizer, text)) + model.config.TEXT_ENCODING_OFFSET + if len(encoded_text) > 256: + p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1) + logger.warning(f"warning, text too long, lopping of last {p}%") + encoded_text = encoded_text[:256] + encoded_text = np.pad( + encoded_text, + (0, 256 - len(encoded_text)), + constant_values=model.config.TEXT_PAD_TOKEN, + mode="constant", + ) + if semantic_history is not None: + semantic_history = semantic_history.astype(np.int64) + # lop off if history is too long, pad if needed + semantic_history = semantic_history[-256:] + semantic_history = np.pad( + semantic_history, + (0, 256 - len(semantic_history)), + constant_values=model.config.SEMANTIC_PAD_TOKEN, + mode="constant", + ) + else: + semantic_history = np.array([model.config.SEMANTIC_PAD_TOKEN] * 256) + x = torch.from_numpy( + np.hstack([encoded_text, semantic_history, np.array([model.config.SEMANTIC_INFER_TOKEN])]).astype(np.int64) + )[None] + assert x.shape[1] == 256 + 256 + 1 + with _inference_mode(): + x = x.to(model.device) + n_tot_steps = 768 + # custom tqdm updates since we don't know when eos will occur + pbar = tqdm.tqdm(disable=silent, total=100) + pbar_state = 0 + tot_generated_duration_s = 0 + kv_cache = None + for n in range(n_tot_steps): + if use_kv_caching and kv_cache is not None: + x_input = x[:, [-1]] + else: + x_input = x + logits, kv_cache = model.semantic_model( + x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache + ) + relevant_logits = 
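# Editor's illustrative sketch, not part of this patch: the top-p ("nucleus") filtering that
# generate_text_semantic applies to the relevant logits below, shown standalone on a small
# numpy vector so the masking logic is easier to follow.
import numpy as np
from scipy.special import softmax

def nucleus_filter(logits: np.ndarray, top_p: float) -> np.ndarray:
    # Keep the smallest set of tokens whose cumulative probability exceeds top_p;
    # everything else is set to -inf so the subsequent softmax/multinomial cannot pick it.
    sorted_indices = np.argsort(logits)[::-1]
    cumulative_probs = np.cumsum(softmax(logits[sorted_indices]))
    to_remove = cumulative_probs > top_p
    to_remove[1:] = to_remove[:-1].copy()  # shift right so the boundary token is kept
    to_remove[0] = False                   # always keep the most likely token
    out = logits.copy()
    out[sorted_indices[to_remove]] = -np.inf
    return out

# e.g. nucleus_filter(np.array([2.0, 1.0, 0.5, -1.0]), top_p=0.9)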
logits[0, 0, : model.config.SEMANTIC_VOCAB_SIZE] + if allow_early_stop: + relevant_logits = torch.hstack( + (relevant_logits, logits[0, 0, [model.config.SEMANTIC_PAD_TOKEN]]) + ) # eos + if top_p is not None: + # faster to convert to numpy + logits_device = relevant_logits.device + logits_dtype = relevant_logits.type() + relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy() + sorted_indices = np.argsort(relevant_logits)[::-1] + sorted_logits = relevant_logits[sorted_indices] + cumulative_probs = np.cumsum(softmax(sorted_logits)) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy() + sorted_indices_to_remove[0] = False + relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf + relevant_logits = torch.from_numpy(relevant_logits) + relevant_logits = relevant_logits.to(logits_device).type(logits_dtype) + if top_k is not None: + v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1))) + relevant_logits[relevant_logits < v[-1]] = -float("Inf") + probs = torch.softmax(relevant_logits / temp, dim=-1) + item_next = torch.multinomial(probs, num_samples=1) + if allow_early_stop and ( + item_next == model.config.SEMANTIC_VOCAB_SIZE or (min_eos_p is not None and probs[-1] >= min_eos_p) + ): + # eos found, so break + pbar.update(100 - pbar_state) + break + x = torch.cat((x, item_next[None]), dim=1) + tot_generated_duration_s += 1 / model.config.SEMANTIC_RATE_HZ + if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s: + pbar.update(100 - pbar_state) + break + if n == n_tot_steps - 1: + pbar.update(100 - pbar_state) + break + del logits, relevant_logits, probs, item_next + req_pbar_state = np.min([100, int(round(100 * n / n_tot_steps))]) + if req_pbar_state > pbar_state: + pbar.update(req_pbar_state - pbar_state) + pbar_state = req_pbar_state + pbar.close() + out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :] + assert all(0 <= out) and all(out < model.config.SEMANTIC_VOCAB_SIZE) + _clear_cuda_cache() + return out + + +def _flatten_codebooks(arr, offset_size): + assert len(arr.shape) == 2 + arr = arr.copy() + if offset_size is not None: + for n in range(1, arr.shape[0]): + arr[n, :] += offset_size * n + flat_arr = arr.ravel("F") + return flat_arr + + +def generate_coarse( + x_semantic, + model, + history_prompt=None, + temp=0.7, + top_k=None, + top_p=None, + silent=False, + max_coarse_history=630, # min 60 (faster), max 630 (more context) + sliding_window_len=60, + base=None, + use_kv_caching=True, +): + """Generate coarse audio codes from semantic tokens.""" + assert ( + isinstance(x_semantic, np.ndarray) + and len(x_semantic.shape) == 1 + and len(x_semantic) > 0 + and x_semantic.min() >= 0 + and x_semantic.max() <= model.config.SEMANTIC_VOCAB_SIZE - 1 + ) + assert 60 <= max_coarse_history <= 630 + assert max_coarse_history + sliding_window_len <= 1024 - 256 + semantic_to_coarse_ratio = ( + model.config.COARSE_RATE_HZ / model.config.SEMANTIC_RATE_HZ * model.config.N_COARSE_CODEBOOKS + ) + max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio)) + if history_prompt is not None or base is not None: + if history_prompt is not None: + x_history = history_prompt + x_semantic_history = x_history[0] + x_coarse_history = x_history[1] + if base is not None: + x_semantic_history = base[0] + x_coarse_history = base[1] + assert ( + isinstance(x_semantic_history, np.ndarray) + and len(x_semantic_history.shape) == 1 + and 
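# Editor's illustrative sketch, not part of this patch: what the _flatten_codebooks() helper
# defined above does when generate_coarse() applies it to the coarse history below. Row n of
# a [n_codebooks, T] code array is shifted by n * offset_size so each codebook gets its own
# id range, then ravel("F") interleaves the columns into one stream (c0[0], c1[0], c0[1], ...);
# generate_coarse() then adds SEMANTIC_VOCAB_SIZE on top so coarse ids never collide with
# semantic ids.
import numpy as np

codes = np.array([[1, 2, 3],
                  [4, 5, 6]])            # [N_COARSE_CODEBOOKS=2, T=3]
offset = 1024                            # CODEBOOK_SIZE from BarkConfig
shifted = codes.copy()
shifted[1, :] += offset                  # second codebook lives in [1024, 2047]
flat = shifted.ravel("F")                # -> [1, 1028, 2, 1029, 3, 1030]
print(flat)

# With COARSE_RATE_HZ=75, SEMANTIC_RATE_HZ=49.9 and 2 coarse codebooks, roughly
# 75 / 49.9 * 2 ≈ 3 coarse tokens are generated per semantic token.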
len(x_semantic_history) > 0 + and x_semantic_history.min() >= 0 + and x_semantic_history.max() <= model.config.SEMANTIC_VOCAB_SIZE - 1 + and isinstance(x_coarse_history, np.ndarray) + and len(x_coarse_history.shape) == 2 + and x_coarse_history.shape[0] == model.config.N_COARSE_CODEBOOKS + and x_coarse_history.shape[-1] >= 0 + and x_coarse_history.min() >= 0 + and x_coarse_history.max() <= model.config.CODEBOOK_SIZE - 1 + and ( + round(x_coarse_history.shape[-1] / len(x_semantic_history), 1) + == round(semantic_to_coarse_ratio / model.config.N_COARSE_CODEBOOKS, 1) + ) + ) + x_coarse_history = ( + _flatten_codebooks(x_coarse_history, model.config.CODEBOOK_SIZE) + model.config.SEMANTIC_VOCAB_SIZE + ) + # trim histories correctly + n_semantic_hist_provided = np.min( + [ + max_semantic_history, + len(x_semantic_history) - len(x_semantic_history) % 2, + int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)), + ] + ) + n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio)) + x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32) + x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32) + # TODO: bit of a hack for time alignment (sounds better) + x_coarse_history = x_coarse_history[:-2] + else: + x_semantic_history = np.array([], dtype=np.int32) + x_coarse_history = np.array([], dtype=np.int32) + # start loop + n_steps = int( + round( + np.floor(len(x_semantic) * semantic_to_coarse_ratio / model.config.N_COARSE_CODEBOOKS) + * model.config.N_COARSE_CODEBOOKS + ) + ) + assert n_steps > 0 and n_steps % model.config.N_COARSE_CODEBOOKS == 0 + x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32) + x_coarse = x_coarse_history.astype(np.int32) + base_semantic_idx = len(x_semantic_history) + with _inference_mode(): + x_semantic_in = torch.from_numpy(x_semantic)[None].to(model.device) + x_coarse_in = torch.from_numpy(x_coarse)[None].to(model.device) + n_window_steps = int(np.ceil(n_steps / sliding_window_len)) + n_step = 0 + for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent): + semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio)) + # pad from right side + x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :] + x_in = x_in[:, :256] + x_in = F.pad( + x_in, + (0, 256 - x_in.shape[-1]), + "constant", + model.config.COARSE_SEMANTIC_PAD_TOKEN, + ) + x_in = torch.hstack( + [ + x_in, + torch.tensor([model.config.COARSE_INFER_TOKEN])[None].to(model.device), + x_coarse_in[:, -max_coarse_history:], + ] + ) + kv_cache = None + for _ in range(sliding_window_len): + if n_step >= n_steps: + continue + is_major_step = n_step % model.config.N_COARSE_CODEBOOKS == 0 + + if use_kv_caching and kv_cache is not None: + x_input = x_in[:, [-1]] + else: + x_input = x_in + + logits, kv_cache = model.coarse_model(x_input, use_cache=use_kv_caching, past_kv=kv_cache) + logit_start_idx = ( + model.config.SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * model.config.CODEBOOK_SIZE + ) + logit_end_idx = model.config.SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * model.config.CODEBOOK_SIZE + relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx] + if top_p is not None: + # faster to convert to numpy + logits_device = relevant_logits.device + logits_dtype = relevant_logits.type() + relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy() + sorted_indices = np.argsort(relevant_logits)[::-1] + sorted_logits = 
relevant_logits[sorted_indices] + cumulative_probs = np.cumsum(torch.nn.functional.softmax(sorted_logits)) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy() + sorted_indices_to_remove[0] = False + relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf + relevant_logits = torch.from_numpy(relevant_logits) + relevant_logits = relevant_logits.to(logits_device).type(logits_dtype) + if top_k is not None: + v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1))) + relevant_logits[relevant_logits < v[-1]] = -float("Inf") + probs = torch.nn.functional.softmax(relevant_logits / temp, dim=-1) + item_next = torch.multinomial(probs, num_samples=1) + item_next += logit_start_idx + x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1) + x_in = torch.cat((x_in, item_next[None]), dim=1) + del logits, relevant_logits, probs, item_next + n_step += 1 + del x_in + del x_semantic_in + gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :] + del x_coarse_in + assert len(gen_coarse_arr) == n_steps + gen_coarse_audio_arr = ( + gen_coarse_arr.reshape(-1, model.config.N_COARSE_CODEBOOKS).T - model.config.SEMANTIC_VOCAB_SIZE + ) + for n in range(1, model.config.N_COARSE_CODEBOOKS): + gen_coarse_audio_arr[n, :] -= n * model.config.CODEBOOK_SIZE + _clear_cuda_cache() + return gen_coarse_audio_arr + + +def generate_fine( + x_coarse_gen, + model, + history_prompt=None, + temp=0.5, + silent=True, + base=None, +): + """Generate full audio codes from coarse audio codes.""" + assert ( + isinstance(x_coarse_gen, np.ndarray) + and len(x_coarse_gen.shape) == 2 + and 1 <= x_coarse_gen.shape[0] <= model.config.N_FINE_CODEBOOKS - 1 + and x_coarse_gen.shape[1] > 0 + and x_coarse_gen.min() >= 0 + and x_coarse_gen.max() <= model.config.CODEBOOK_SIZE - 1 + ) + if history_prompt is not None or base is not None: + if history_prompt is not None: + x_fine_history = history_prompt[2] + if base is not None: + x_fine_history = base[2] + assert ( + isinstance(x_fine_history, np.ndarray) + and len(x_fine_history.shape) == 2 + and x_fine_history.shape[0] == model.config.N_FINE_CODEBOOKS + and x_fine_history.shape[1] >= 0 + and x_fine_history.min() >= 0 + and x_fine_history.max() <= model.config.CODEBOOK_SIZE - 1 + ) + else: + x_fine_history = None + n_coarse = x_coarse_gen.shape[0] + # make input arr + in_arr = np.vstack( + [ + x_coarse_gen, + np.zeros((model.config.N_FINE_CODEBOOKS - n_coarse, x_coarse_gen.shape[1])) + + model.config.CODEBOOK_SIZE, # padding + ] + ).astype(np.int32) + # prepend history if available (max 512) + if x_fine_history is not None: + x_fine_history = x_fine_history.astype(np.int32) + in_arr = np.hstack( + [ + x_fine_history[:, -512:].astype(np.int32), + in_arr, + ] + ) + n_history = x_fine_history[:, -512:].shape[1] + else: + n_history = 0 + n_remove_from_end = 0 + # need to pad if too short (since non-causal model) + if in_arr.shape[1] < 1024: + n_remove_from_end = 1024 - in_arr.shape[1] + in_arr = np.hstack( + [ + in_arr, + np.zeros((model.config.N_FINE_CODEBOOKS, n_remove_from_end), dtype=np.int32) + + model.config.CODEBOOK_SIZE, + ] + ) + # we can be lazy about fractional loop and just keep overwriting codebooks + n_loops = np.max([0, int(np.ceil((x_coarse_gen.shape[1] - (1024 - n_history)) / 512))]) + 1 + with _inference_mode(): + in_arr = torch.tensor(in_arr.T).to(model.device) + for n in tqdm.tqdm(range(n_loops), disable=silent): + start_idx = np.min([n * 512, 
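# Editor's illustrative sketch, not part of this patch: generate_fine() runs the non-causal
# fine model over fixed 1024-token windows that slide by 512 tokens, refilling codebooks
# n_coarse..N_FINE_CODEBOOKS-1 only from rel_start_fill_idx onward so each position keeps
# context on both sides. The window arithmetic being computed here, with hypothetical lengths:
import numpy as np

n_history, gen_len = 512, 1788            # hypothetical: 512 history tokens + generated part
total_len = n_history + gen_len           # time length of in_arr (already >= 1024 here)
n_loops = max(0, int(np.ceil((gen_len - (1024 - n_history)) / 512))) + 1
for n in range(n_loops):
    start_idx = min(n * 512, total_len - 1024)
    start_fill_idx = min(n_history + n * 512, total_len - 512)
    rel_start_fill_idx = start_fill_idx - start_idx
    print(n, start_idx, start_fill_idx, rel_start_fill_idx)   # overlapping windows, 512 hop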
in_arr.shape[0] - 1024]) + start_fill_idx = np.min([n_history + n * 512, in_arr.shape[0] - 512]) + rel_start_fill_idx = start_fill_idx - start_idx + in_buffer = in_arr[start_idx : start_idx + 1024, :][None] + for nn in range(n_coarse, model.config.N_FINE_CODEBOOKS): + logits = model.fine_model(nn, in_buffer) + if temp is None: + relevant_logits = logits[0, rel_start_fill_idx:, : model.config.CODEBOOK_SIZE] + codebook_preds = torch.argmax(relevant_logits, -1) + else: + relevant_logits = logits[0, :, : model.config.CODEBOOK_SIZE] / temp + probs = F.softmax(relevant_logits, dim=-1) + codebook_preds = torch.hstack( + [torch.multinomial(probs[n], num_samples=1) for n in range(rel_start_fill_idx, 1024)] + ) + in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds + del logits, codebook_preds + # transfer over info into model_in and convert to numpy + for nn in range(n_coarse, model.config.N_FINE_CODEBOOKS): + in_arr[start_fill_idx : start_fill_idx + (1024 - rel_start_fill_idx), nn] = in_buffer[ + 0, rel_start_fill_idx:, nn + ] + del in_buffer + gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T + del in_arr + gen_fine_arr = gen_fine_arr[:, n_history:] + if n_remove_from_end > 0: + gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end] + assert gen_fine_arr.shape[-1] == x_coarse_gen.shape[-1] + _clear_cuda_cache() + return gen_fine_arr + + +def codec_decode(fine_tokens, model): + """Turn quantized audio codes into audio array using encodec.""" + from TTS.utils.audio.numpy_transforms import save_wav + + arr = torch.from_numpy(fine_tokens)[None] + arr = arr.to(model.device) + arr = arr.transpose(0, 1) + emb = model.encodec.quantizer.decode(arr) + out = model.encodec.decoder(emb) + audio_arr = out.detach().cpu().numpy().squeeze() + save_wav(path="test.wav", wav=audio_arr, sample_rate=model.config.sample_rate) diff --git a/TTS/tts/layers/bark/load_model.py b/TTS/tts/layers/bark/load_model.py new file mode 100644 index 00000000..dbd861d0 --- /dev/null +++ b/TTS/tts/layers/bark/load_model.py @@ -0,0 +1,254 @@ +import contextlib + +# import funcy +import functools +import hashlib +import logging +import os +import re + +import requests +import torch +import tqdm +from encodec import EncodecModel +from transformers import BertTokenizer + +from TTS.tts.layers.bark.model import GPT, GPTConfig +from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig + +if ( + torch.cuda.is_available() + and hasattr(torch.cuda, "amp") + and hasattr(torch.cuda.amp, "autocast") + and torch.cuda.is_bf16_supported() +): + autocast = functools.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16) +else: + + @contextlib.contextmanager + def autocast(): + yield + + +# hold models in global scope to lazy load +global models +models = {} + +logger = logging.getLogger(__name__) + + +if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + logger.warning( + "torch version does not support flash attention. You will get significantly faster" + + " inference speed by upgrade torch to newest version / nightly." 
+ ) + + +def _string_md5(s): + m = hashlib.md5() + m.update(s.encode("utf-8")) + return m.hexdigest() + + +def _md5(fname): + hash_md5 = hashlib.md5() + with open(fname, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +def _get_ckpt_path(model_type, CACHE_DIR): + model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"]) + return os.path.join(CACHE_DIR, f"{model_name}.pt") + + +S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/" + + +def _parse_s3_filepath(s3_filepath): + bucket_name = re.search(S3_BUCKET_PATH_RE, s3_filepath).group(1) + rel_s3_filepath = re.sub(S3_BUCKET_PATH_RE, "", s3_filepath) + return bucket_name, rel_s3_filepath + + +def _download(from_s3_path, to_local_path, CACHE_DIR): + os.makedirs(CACHE_DIR, exist_ok=True) + response = requests.get(from_s3_path, stream=True) + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + with open(to_local_path, "wb") as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: + raise ValueError("ERROR, something went wrong") + + +class InferenceContext: + def __init__(self, benchmark=False): + # we can't expect inputs to be the same length, so disable benchmarking by default + self._chosen_cudnn_benchmark = benchmark + self._cudnn_benchmark = None + + def __enter__(self): + self._cudnn_benchmark = torch.backends.cudnn.benchmark + torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark + + def __exit__(self, exc_type, exc_value, exc_traceback): + torch.backends.cudnn.benchmark = self._cudnn_benchmark + + +if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +@contextlib.contextmanager +def _inference_mode(): + with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast(): + yield + + +def _clear_cuda_cache(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def clean_models(model_key=None): + global models + model_keys = [model_key] if model_key is not None else models.keys() + for k in model_keys: + if k in models: + del models[k] + _clear_cuda_cache() + + +def _load_model(ckpt_path, device, config, model_type="text"): + logger.info(f"loading {model_type} model from {ckpt_path}...") + + if device == "cpu": + logger.warning("No GPU being used. 
Careful, Inference might be extremely slow!") + if model_type == "text": + ConfigClass = GPTConfig + ModelClass = GPT + elif model_type == "coarse": + ConfigClass = GPTConfig + ModelClass = GPT + elif model_type == "fine": + ConfigClass = FineGPTConfig + ModelClass = FineGPT + else: + raise NotImplementedError() + if ( + not config.USE_SMALLER_MODELS + and os.path.exists(ckpt_path) + and _md5(ckpt_path) != config.REMOTE_MODEL_PATHS[model_type]["checksum"] + ): + logger.warning(f"found outdated {model_type} model, removing...") + os.remove(ckpt_path) + if not os.path.exists(ckpt_path): + logger.info(f"{model_type} model not found, downloading...") + _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR) + + checkpoint = torch.load(ckpt_path, map_location=device) + # this is a hack + model_args = checkpoint["model_args"] + if "input_vocab_size" not in model_args: + model_args["input_vocab_size"] = model_args["vocab_size"] + model_args["output_vocab_size"] = model_args["vocab_size"] + del model_args["vocab_size"] + + gptconf = ConfigClass(**checkpoint["model_args"]) + if model_type == "text": + config.semantic_config = gptconf + elif model_type == "coarse": + config.coarse_config = gptconf + elif model_type == "fine": + config.fine_config = gptconf + + model = ModelClass(gptconf) + state_dict = checkpoint["model"] + # fixup checkpoint + unwanted_prefix = "_orig_mod." + for k, v in list(state_dict.items()): + if k.startswith(unwanted_prefix): + state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) + extra_keys = set(state_dict.keys()) - set(model.state_dict().keys()) + extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")]) + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")]) + if len(extra_keys) != 0: + raise ValueError(f"extra keys found: {extra_keys}") + if len(missing_keys) != 0: + raise ValueError(f"missing keys: {missing_keys}") + model.load_state_dict(state_dict, strict=False) + n_params = model.get_num_params() + val_loss = checkpoint["best_val_loss"].item() + logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss") + model.eval() + model.to(device) + del checkpoint, state_dict + _clear_cuda_cache() + return model, config + + +def _load_codec_model(device): + model = EncodecModel.encodec_model_24khz() + model.set_target_bandwidth(6.0) + model.eval() + model.to(device) + _clear_cuda_cache() + return model + + +def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"): + _load_model_f = functools.partial(_load_model, model_type=model_type) + if model_type not in ("text", "coarse", "fine"): + raise NotImplementedError() + global models + if torch.cuda.device_count() == 0 or not use_gpu: + device = "cpu" + else: + device = "cuda" + model_key = str(device) + f"__{model_type}" + if model_key not in models or force_reload: + if ckpt_path is None: + ckpt_path = _get_ckpt_path(model_type) + clean_models(model_key=model_key) + model = _load_model_f(ckpt_path, device) + models[model_key] = model + return models[model_key] + + +def load_codec_model(use_gpu=True, force_reload=False): + global models + if torch.cuda.device_count() == 0 or not use_gpu: + device = "cpu" + else: + device = "cuda" + model_key = str(device) + f"__codec" + if model_key not in models or force_reload: + clean_models(model_key=model_key) + model = _load_codec_model(device) + models[model_key] = model + return 
models[model_key] + + +def preload_models( + text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True, use_smaller_models=False +): + global USE_SMALLER_MODELS + global REMOTE_MODEL_PATHS + if use_smaller_models: + USE_SMALLER_MODELS = True + logger.info("Using smaller models generation.py") + REMOTE_MODEL_PATHS = SMALL_REMOTE_MODEL_PATHS + + _ = load_model(ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True) + _ = load_model(ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True) + _ = load_model(ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True) + _ = load_codec_model(use_gpu=use_gpu, force_reload=True) diff --git a/TTS/tts/layers/bark/model.py b/TTS/tts/layers/bark/model.py new file mode 100644 index 00000000..485e6665 --- /dev/null +++ b/TTS/tts/layers/bark/model.py @@ -0,0 +1,232 @@ +""" +Much of this code is adapted from Andrej Karpathy's NanoGPT +(https://github.com/karpathy/nanoGPT) +""" +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +class LayerNorm(nn.Module): + """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" + + def __init__(self, ndim, bias): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + + def forward(self, input): + return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) + + +class CausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") + if not self.flash: + # print("WARNING: using slow attention. 
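# Editor's note (illustrative, not part of this patch): with use_cache=True the attention in
# the forward() below returns present = (k, v), and the callers in inference_funcs feed only
# the last token (x[:, [-1]]) on subsequent steps, concatenating the new keys/values onto
# past_kv along the time dimension. Per layer and per step, the shapes are roughly:
#   past k, v : (B, n_head, T_past, head_dim)
#   new  k, v : (B, n_head, 1,      head_dim)  -> torch.cat(dim=-2) -> (B, n_head, T_past+1, head_dim)
# so each decoding step attends over the full history without recomputing earlier projections.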
Flash Attention atm needs PyTorch nightly and dropout=0.0") + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "bias", + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + ) + + def forward(self, x, past_kv=None, use_cache=False): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + if past_kv is not None: + past_key = past_kv[0] + past_value = past_kv[1] + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + FULL_T = k.shape[-2] + + if use_cache is True: + present = (k, v) + else: + present = None + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + if past_kv is not None: + # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains + # the query for the last token. scaled_dot_product_attention interprets this as the first token in the + # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so + # to work around this we set is_causal=False. + is_causal = False + else: + is_causal = True + + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, FULL_T - T : FULL_T, :FULL_T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return (y, present) + + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + self.gelu = nn.GELU() + + def forward(self, x): + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + + +class Block(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(config) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.mlp = MLP(config) + self.layer_idx = layer_idx + + def forward(self, x, past_kv=None, use_cache=False): + attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache) + x = x + attn_output + x = x + self.mlp(self.ln_2(x)) + return (x, prev_kvs) + + +@dataclass +class GPTConfig: + block_size: int = 1024 + input_vocab_size: int = 10_048 + output_vocab_size: int = 10_048 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.0 + bias: bool = 
True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + + +class GPT(nn.Module): + def __init__(self, config): + super().__init__() + assert config.input_vocab_size is not None + assert config.output_vocab_size is not None + assert config.block_size is not None + self.config = config + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.input_vocab_size, config.n_embd), + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]), + ln_f=LayerNorm(config.n_embd, bias=config.bias), + ) + ) + self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False) + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default), the position embeddings get subtracted. + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. + """ + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.transformer.wte.weight.numel() + n_params -= self.transformer.wpe.weight.numel() + return n_params + + def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False): + device = idx.device + b, t = idx.size() + if past_kv is not None: + assert t == 1 + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + else: + if merge_context: + assert idx.shape[1] >= 256 + 256 + 1 + t = idx.shape[1] - 256 + else: + assert ( + t <= self.config.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + + # forward the GPT model itself + if merge_context: + tok_emb = torch.cat( + [ + self.transformer.wte(idx[:, :256]) + self.transformer.wte(idx[:, 256 : 256 + 256]), + self.transformer.wte(idx[:, 256 + 256 :]), + ], + dim=1, + ) + else: + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + + if past_kv is None: + past_length = 0 + past_kv = tuple([None] * len(self.transformer.h)) + else: + past_length = past_kv[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) # shape (1, t) + assert position_ids.shape == (1, t) + + pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd) + + x = self.transformer.drop(tok_emb + pos_emb) + + new_kv = () if use_cache else None + + for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)): + x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache) + + if use_cache: + new_kv = new_kv + (kv,) + + x = self.transformer.ln_f(x) + + # inference-time mini-optimization: only forward the lm_head on the very last position + logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim + + return (logits, new_kv) diff --git a/TTS/tts/layers/bark/model_fine.py b/TTS/tts/layers/bark/model_fine.py new file mode 100644 index 00000000..8a426107 --- /dev/null +++ b/TTS/tts/layers/bark/model_fine.py @@ -0,0 +1,142 @@ +""" +Much of this code is adapted from Andrej Karpathy's NanoGPT +(https://github.com/karpathy/nanoGPT) +""" +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from .model import GPT, MLP, GPTConfig + + +class 
NonCausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0 + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False + ) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class FineBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.ln_1 = nn.LayerNorm(config.n_embd) + self.attn = NonCausalSelfAttention(config) + self.ln_2 = nn.LayerNorm(config.n_embd) + self.mlp = MLP(config) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class FineGPT(GPT): + def __init__(self, config): + super().__init__(config) + del self.lm_head + self.config = config + self.n_codes_total = config.n_codes_total + self.transformer = nn.ModuleDict( + dict( + wtes=nn.ModuleList( + [nn.Embedding(config.input_vocab_size, config.n_embd) for _ in range(config.n_codes_total)] + ), + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([FineBlock(config) for _ in range(config.n_layer)]), + ln_f=nn.LayerNorm(config.n_embd), + ) + ) + self.lm_heads = nn.ModuleList( + [ + nn.Linear(config.n_embd, config.output_vocab_size, bias=False) + for _ in range(config.n_codes_given, self.n_codes_total) + ] + ) + for i in range(self.n_codes_total - config.n_codes_given): + self.transformer.wtes[i + 1].weight = self.lm_heads[i].weight + + def forward(self, pred_idx, idx): + device = idx.device + b, t, codes = idx.size() + assert ( + t <= self.config.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + assert pred_idx > 0, "cannot predict 0th codebook" + assert codes == self.n_codes_total, (b, t, codes) + pos = torch.arange(0, t, dtype=torch.long, 
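# Editor's note (illustrative, not part of this patch): FineGPT.forward here embeds every
# codebook stream separately (the tok_embs construction just below), sums only the embeddings
# of codebooks 0..pred_idx to form the input, and predicts codebook pred_idx with
# lm_heads[pred_idx - n_codes_given]. Since the attention is non-causal, each fine code can
# attend to the whole 1024-token window rather than only to its left context.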
device=device).unsqueeze(0) # shape (1, t) + + # forward the GPT model itself + tok_embs = [ + wte(idx[:, :, i]).unsqueeze(-1) for i, wte in enumerate(self.transformer.wtes) + ] # token embeddings of shape (b, t, n_embd) + tok_emb = torch.cat(tok_embs, dim=-1) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) + x = tok_emb[:, :, :, : pred_idx + 1].sum(dim=-1) + x = self.transformer.drop(x + pos_emb) + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + logits = self.lm_heads[pred_idx - self.config.n_codes_given](x) + return logits + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default), the position embeddings get subtracted. + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. + """ + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + for wte in self.transformer.wtes: + n_params -= wte.weight.numel() + n_params -= self.transformer.wpe.weight.numel() + return n_params + + +@dataclass +class FineGPTConfig(GPTConfig): + n_codes_total: int = 8 + n_codes_given: int = 1 From 5a31fad5028f3e819caac2f51761f47bda2129ee Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 19 Jun 2023 14:14:04 +0200 Subject: [PATCH 03/29] Download HF models --- TTS/utils/manage.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 0d0b9064..dc0c7b68 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -245,6 +245,26 @@ class ModelManager(object): else: print(" > Model's license - No license information available") + def _download_github_model(self, model_item: Dict, output_path: str): + if isinstance(model_item["github_rls_url"], list): + self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) + else: + self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) + + def _download_hf_model(self, model_item:Dict, output_path: str): + if isinstance(model_item["hf_url"], list): + self._download_model_files(model_item["hf_url"], output_path, self.progress_bar) + else: + self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar) + + def set_model_url(self, model_item: Dict): + model_item["model_url"] = None + if "github_rls_url" in model_item: + model_item["model_url"] = model_item["github_rls_url"] + elif "hf_url" in model_item: + model_item["model_url"] = model_item["hf_url"] + return model_item + def download_model(self, model_name): """Download model files given the full model name. 
Model name is in the format @@ -264,6 +284,7 @@ class ModelManager(object): model_full_name = f"{model_type}--{lang}--{dataset}--{model}" model_item = self.models_dict[model_type][lang][dataset][model] model_item["model_type"] = model_type + model_item = self.set_model_url(model_item) # set the model specific output path output_path = os.path.join(self.output_prefix, model_full_name) if os.path.exists(output_path): @@ -271,16 +292,16 @@ class ModelManager(object): else: os.makedirs(output_path, exist_ok=True) print(f" > Downloading model to {output_path}") - # download from github release - if isinstance(model_item["github_rls_url"], list): - self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) - else: - self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) + if "github_rls_url" in model_item: + self._download_github_model(model_item, output_path) + elif "hf_url" in model_item: + self._download_hf_model(model_item, output_path) + self.print_model_license(model_item=model_item) # find downloaded files output_model_path = output_path output_config_path = None - if model != "tortoise-v2": + if model not in ["tortoise-v2", "bark"]: # TODO:This is stupid but don't care for now. output_model_path, output_config_path = self._find_files(output_path) # update paths in the config.json self._update_paths(output_path, output_config_path) From 2364c38d169770cc1e95381a9480cd8ed2b1a62d Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 19 Jun 2023 14:15:21 +0200 Subject: [PATCH 04/29] Update synthesizer --- TTS/utils/synthesizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 1b91521b..4f7761b9 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -349,7 +349,7 @@ class Synthesizer(object): text=sen, config=self.tts_config, speaker_id=sp_name, - extra_voice_dirs=self.voice_dir, + voice_dirs=self.voice_dir, **kwargs, ) else: From 37b708dac75b744fb14d60e0a45f6b4c59b416ad Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 19 Jun 2023 14:16:06 +0200 Subject: [PATCH 05/29] Add bark model --- TTS/.models.json | 159 +++++++++++++------------ TTS/api.py | 4 +- TTS/tts/configs/bark_config.py | 5 +- TTS/tts/layers/bark/inference_funcs.py | 10 +- TTS/tts/layers/bark/model.py | 3 +- 5 files changed, 100 insertions(+), 81 deletions(-) diff --git a/TTS/.models.json b/TTS/.models.json index b396e641..801485a1 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -1,20 +1,33 @@ { "tts_models": { - "multilingual":{ - "multi-dataset":{ - "your_tts":{ + "multilingual": { + "multi-dataset": { + "your_tts": { "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418", "github_rls_url": "https://coqui.gateway.scarf.sh/v0.10.1_models/tts_models--multilingual--multi-dataset--your_tts.zip", "default_vocoder": null, "commit": "e9a1953e", "license": "CC BY-NC-ND 4.0", "contact": "egolge@coqui.ai" + }, + "bark": { + "description": "🐶 Bark TTS model released by suno-ai. 
You can find the original implementation in https://github.com/suno-ai/bark.", + "hf_url": [ + "https://coqui.gateway.scarf.sh/bark/coarse_2.pt", + "https://coqui.gateway.scarf.sh/bark/fine_2.pt", + "https://coqui.gateway.scarf.sh/bark/text_2.pt", + "https://coqui.gateway.scarf.sh/bark/config.json" + ], + "default_vocoder": null, + "commit": "e9a1953e", + "license": "MIT", + "contact": "https://www.suno.ai/" } } }, "bg": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--bg--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -25,7 +38,7 @@ }, "cs": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--cs--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -36,7 +49,7 @@ }, "da": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--da--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -47,7 +60,7 @@ }, "et": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--et--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -58,7 +71,7 @@ }, "ga": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--ga--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -180,7 +193,7 @@ "license": "apache 2.0", "contact": "egolge@coqui.ai" }, - "fast_pitch":{ + "fast_pitch": { "description": "FastPitch model trained on VCTK dataseset.", "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--vctk--fast_pitch.zip", "default_vocoder": null, @@ -220,21 +233,21 @@ "license": "apache 2.0", "contact": "adamfroghyar@gmail.com" } - }, - "multi-dataset":{ - "tortoise-v2":{ + "multi-dataset": { + "tortoise-v2": { "description": "Tortoise tts model https://github.com/neonbjb/tortoise-tts", - "github_rls_url": ["https://coqui.gateway.scarf.sh/v0.14.1_models/autoregressive.pth", - "https://coqui.gateway.scarf.sh/v0.14.1_models/clvp2.pth", - "https://coqui.gateway.scarf.sh/v0.14.1_models/cvvp.pth", - "https://coqui.gateway.scarf.sh/v0.14.1_models/diffusion_decoder.pth", - "https://coqui.gateway.scarf.sh/v0.14.1_models/rlg_auto.pth", - "https://coqui.gateway.scarf.sh/v0.14.1_models/rlg_diffuser.pth", - "https://coqui.gateway.scarf.sh/v0.14.1_models/vocoder.pth", - "https://coqui.gateway.scarf.sh/v0.14.1_models/mel_norms.pth", - "https://coqui.gateway.scarf.sh/v0.14.1_models/config.json" - ], + "github_rls_url": [ + "https://coqui.gateway.scarf.sh/v0.14.1_models/autoregressive.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/clvp2.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/cvvp.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/diffusion_decoder.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/rlg_auto.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/rlg_diffuser.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/vocoder.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/mel_norms.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/config.json" + ], "commit": "c1875f6", "default_vocoder": null, "author": "@neonbjb - James Betker, @manmay-nakhashi Manmay Nakhashi", @@ -242,7 +255,7 @@ } }, "jenny": { - "jenny":{ + "jenny": { "description": "VITS model trained with Jenny(Dioco) dataset. Named as Jenny as demanded by the license. 
Original URL for the model https://www.kaggle.com/datasets/noml4u/tts-models--en--jenny-dioco--vits", "github_rls_url": "https://coqui.gateway.scarf.sh/v0.14.0_models/tts_models--en--jenny--jenny.zip", "default_vocoder": null, @@ -263,8 +276,8 @@ "contact": "egolge@coqui.com" } }, - "css10":{ - "vits":{ + "css10": { + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--es--css10--vits.zip", "default_vocoder": null, "commit": null, @@ -284,8 +297,8 @@ "contact": "egolge@coqui.com" } }, - "css10":{ - "vits":{ + "css10": { + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--fr--css10--vits.zip", "default_vocoder": null, "commit": null, @@ -294,17 +307,17 @@ } } }, - "uk":{ + "uk": { "mai": { "glow-tts": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--uk--mai--glow-tts.zip", - "author":"@robinhad", + "author": "@robinhad", "commit": "bdab788d", "license": "MIT", "contact": "", "default_vocoder": "vocoder_models/uk/mai/multiband-melgan" }, - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--uk--mai--vits.zip", "default_vocoder": null, "commit": null, @@ -335,8 +348,8 @@ "commit": "540d811" } }, - "css10":{ - "vits":{ + "css10": { + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--nl--css10--vits.zip", "default_vocoder": null, "commit": null, @@ -371,7 +384,7 @@ } }, "css10": { - "vits-neon":{ + "vits-neon": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--de--css10--vits.zip", "default_vocoder": null, "author": "@NeonGeckoCom", @@ -392,9 +405,9 @@ } } }, - "tr":{ + "tr": { "common-voice": { - "glow-tts":{ + "glow-tts": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--tr--common-voice--glow-tts.zip", "default_vocoder": "vocoder_models/tr/common-voice/hifigan", "license": "MIT", @@ -406,7 +419,7 @@ }, "it": { "mai_female": { - "glow-tts":{ + "glow-tts": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--it--mai_female--glow-tts.zip", "default_vocoder": null, "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", @@ -414,7 +427,7 @@ "license": "apache 2.0", "commit": null }, - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--it--mai_female--vits.zip", "default_vocoder": null, "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", @@ -424,7 +437,7 @@ } }, "mai_male": { - "glow-tts":{ + "glow-tts": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--it--mai_male--glow-tts.zip", "default_vocoder": null, "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", @@ -432,7 +445,7 @@ "license": "apache 2.0", "commit": null }, - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--it--mai_male--vits.zip", "default_vocoder": null, "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", @@ -444,7 +457,7 @@ }, "ewe": { "openbible": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--ewe--openbible--vits.zip", "default_vocoder": null, "license": "CC-BY-SA 4.0", @@ -456,7 +469,7 @@ }, "hau": { "openbible": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--hau--openbible--vits.zip", 
"default_vocoder": null, "license": "CC-BY-SA 4.0", @@ -468,7 +481,7 @@ }, "lin": { "openbible": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--lin--openbible--vits.zip", "default_vocoder": null, "license": "CC-BY-SA 4.0", @@ -480,7 +493,7 @@ }, "tw_akuapem": { "openbible": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--tw_akuapem--openbible--vits.zip", "default_vocoder": null, "license": "CC-BY-SA 4.0", @@ -492,7 +505,7 @@ }, "tw_asante": { "openbible": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--tw_asante--openbible--vits.zip", "default_vocoder": null, "license": "CC-BY-SA 4.0", @@ -504,7 +517,7 @@ }, "yor": { "openbible": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--yor--openbible--vits.zip", "default_vocoder": null, "license": "CC-BY-SA 4.0", @@ -538,7 +551,7 @@ }, "fi": { "css10": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--fi--css10--vits.zip", "default_vocoder": null, "commit": null, @@ -549,7 +562,7 @@ }, "hr": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--hr--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -560,7 +573,7 @@ }, "lt": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--lt--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -571,7 +584,7 @@ }, "lv": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--lv--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -582,7 +595,7 @@ }, "mt": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--mt--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -593,7 +606,7 @@ }, "pl": { "mai_female": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--pl--mai_female--vits.zip", "default_vocoder": null, "commit": null, @@ -604,7 +617,7 @@ }, "pt": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--pt--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -615,7 +628,7 @@ }, "ro": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--ro--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -626,7 +639,7 @@ }, "sk": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--sk--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -637,7 +650,7 @@ }, "sl": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--sl--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -648,7 +661,7 @@ }, "sv": { "cv": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--sv--cv--vits.zip", "default_vocoder": null, "commit": null, @@ -659,7 +672,7 @@ }, "ca": { "custom": { - "vits":{ + "vits": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.10.1_models/tts_models--ca--custom--vits.zip", "default_vocoder": null, "commit": null, @@ -669,8 +682,8 @@ } } }, - "fa":{ - "custom":{ + "fa": { + "custom": { "glow-tts": { "github_rls_url": 
"https://coqui.gateway.scarf.sh/v0.10.1_models/tts_models--fa--custom--glow-tts.zip", "default_vocoder": null, @@ -681,18 +694,18 @@ } } }, - "bn":{ - "custom":{ - "vits-male":{ - "github_rls_url":"https://coqui.gateway.scarf.sh/v0.13.3_models/tts_models--bn--custom--vits_male.zip", + "bn": { + "custom": { + "vits-male": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.13.3_models/tts_models--bn--custom--vits_male.zip", "default_vocoder": null, "commit": null, "description": "Single speaker Bangla male model. For more information -> https://github.com/mobassir94/comprehensive-bangla-tts", "author": "@mobassir94", "license": "Apache 2.0" }, - "vits-female":{ - "github_rls_url":"https://coqui.gateway.scarf.sh/v0.13.3_models/tts_models--bn--custom--vits_female.zip", + "vits-female": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.13.3_models/tts_models--bn--custom--vits_female.zip", "default_vocoder": null, "commit": null, "description": "Single speaker Bangla female model. For more information -> https://github.com/mobassir94/comprehensive-bangla-tts", @@ -834,16 +847,16 @@ "mai": { "multiband-melgan": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--uk--mai--multiband-melgan.zip", - "author":"@robinhad", + "author": "@robinhad", "commit": "bdab788d", "license": "MIT", "contact": "" } } }, - "tr":{ + "tr": { "common-voice": { - "hifigan":{ + "hifigan": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--tr--common-voice--hifigan.zip", "description": "HifiGAN model using an unknown speaker from the Common-Voice dataset.", "author": "Fatih Akademi", @@ -853,10 +866,10 @@ } } }, - "voice_conversion_models":{ - "multilingual":{ - "vctk":{ - "freevc24":{ + "voice_conversion_models": { + "multilingual": { + "vctk": { + "freevc24": { "github_rls_url": "https://coqui.gateway.scarf.sh/v0.13.0_models/voice_conversion_models--multilingual--vctk--freevc24.zip", "description": "FreeVC model trained on VCTK dataset from https://github.com/OlaWod/FreeVC", "author": "Jing-Yi Li @OlaWod", @@ -866,4 +879,4 @@ } } } -} +} \ No newline at end of file diff --git a/TTS/api.py b/TTS/api.py index 8bd087f6..190fe6b8 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -342,7 +342,7 @@ class TTS: def download_model_by_name(self, model_name: str): model_path, config_path, model_item = self.manager.download_model(model_name) - if isinstance(model_item["github_rls_url"], list): + if isinstance(model_item["model_url"], list): # return model directory if there are multiple files # we assume that the model knows how to load itself return None, None, None, None, model_path @@ -580,6 +580,8 @@ class TTS: Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0. Defaults to None. file_path (str, optional): Output file path. Defaults to "output.wav". + kwargs (dict, optional): + Additional arguments for the model. 
""" self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs) diff --git a/TTS/tts/configs/bark_config.py b/TTS/tts/configs/bark_config.py index 760776a8..57ccf2d0 100644 --- a/TTS/tts/configs/bark_config.py +++ b/TTS/tts/configs/bark_config.py @@ -5,11 +5,14 @@ from typing import Dict from TTS.tts.configs.shared_configs import BaseTTSConfig from TTS.tts.layers.bark.model import GPTConfig from TTS.tts.layers.bark.model_fine import FineGPTConfig +from TTS.tts.models.bark import BarkAudioConfig from TTS.utils.generic_utils import get_user_data_dir @dataclass class BarkConfig(BaseTTSConfig): + model: str = "bark" + audio: BarkAudioConfig = BarkAudioConfig() num_chars: int = 0 semantic_config: GPTConfig = GPTConfig() fine_config: FineGPTConfig = FineGPTConfig() @@ -31,7 +34,7 @@ class BarkConfig(BaseTTSConfig): COARSE_SEMANTIC_PAD_TOKEN: int = 12_048 COARSE_INFER_TOKEN: int = 12_050 - REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/" + REMOTE_BASE_URL = "https://huggingface.co/erogol/bark/tree/main/" REMOTE_MODEL_PATHS: Dict = None LOCAL_MODEL_PATHS: Dict = None SMALL_REMOTE_MODEL_PATHS: Dict = None diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py index 73c9ee71..6fa87c37 100644 --- a/TTS/tts/layers/bark/inference_funcs.py +++ b/TTS/tts/layers/bark/inference_funcs.py @@ -52,7 +52,7 @@ def load_voice(voice: str, extra_voice_dirs: List[str] = []): return semantic, coarse, fine if voice == "random": - return None, None + return None, None, None voices = get_voices(extra_voice_dirs) try: @@ -183,7 +183,7 @@ def generate_text_semantic( assert isinstance(text, str) text = _normalize_whitespace(text) assert len(text.strip()) > 0 - if history_prompt is not None or base is not None: + if all(v is not None for v in history_prompt) or base is not None: if history_prompt is not None: semantic_history = history_prompt[0] if base is not None: @@ -327,7 +327,7 @@ def generate_coarse( model.config.COARSE_RATE_HZ / model.config.SEMANTIC_RATE_HZ * model.config.N_COARSE_CODEBOOKS ) max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio)) - if history_prompt is not None or base is not None: + if all(v is not None for v in history_prompt) or base is not None: if history_prompt is not None: x_history = history_prompt x_semantic_history = x_history[0] @@ -477,7 +477,7 @@ def generate_fine( and x_coarse_gen.min() >= 0 and x_coarse_gen.max() <= model.config.CODEBOOK_SIZE - 1 ) - if history_prompt is not None or base is not None: + if all(v is not None for v in history_prompt) or base is not None: if history_prompt is not None: x_fine_history = history_prompt[2] if base is not None: @@ -572,4 +572,4 @@ def codec_decode(fine_tokens, model): emb = model.encodec.quantizer.decode(arr) out = model.encodec.decoder(emb) audio_arr = out.detach().cpu().numpy().squeeze() - save_wav(path="test.wav", wav=audio_arr, sample_rate=model.config.sample_rate) + return audio_arr diff --git a/TTS/tts/layers/bark/model.py b/TTS/tts/layers/bark/model.py index 485e6665..81117b3e 100644 --- a/TTS/tts/layers/bark/model.py +++ b/TTS/tts/layers/bark/model.py @@ -4,6 +4,7 @@ Much of this code is adapted from Andrej Karpathy's NanoGPT """ import math from dataclasses import dataclass +from coqpit import Coqpit import torch import torch.nn as nn @@ -131,7 +132,7 @@ class Block(nn.Module): @dataclass -class GPTConfig: +class GPTConfig(Coqpit): block_size: int = 1024 input_vocab_size: int = 10_048 output_vocab_size: int = 10_048 From 
f4c88ed677dc7a38e8aa3bc92026a37a2554d021 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 19 Jun 2023 14:22:32 +0200 Subject: [PATCH 06/29] Make style --- TTS/tts/layers/bark/model.py | 2 +- TTS/utils/manage.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/TTS/tts/layers/bark/model.py b/TTS/tts/layers/bark/model.py index 81117b3e..bcc87a4b 100644 --- a/TTS/tts/layers/bark/model.py +++ b/TTS/tts/layers/bark/model.py @@ -4,10 +4,10 @@ Much of this code is adapted from Andrej Karpathy's NanoGPT """ import math from dataclasses import dataclass -from coqpit import Coqpit import torch import torch.nn as nn +from coqpit import Coqpit from torch.nn import functional as F diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index dc0c7b68..f9968910 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -251,7 +251,7 @@ class ModelManager(object): else: self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) - def _download_hf_model(self, model_item:Dict, output_path: str): + def _download_hf_model(self, model_item: Dict, output_path: str): if isinstance(model_item["hf_url"], list): self._download_model_files(model_item["hf_url"], output_path, self.progress_bar) else: From e89aa970256071d9371abb7809e31466b966b544 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Wed, 21 Jun 2023 11:57:33 +0200 Subject: [PATCH 07/29] Update pylintrc --- .pylintrc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index d5f9c490..49a9dbdd 100644 --- a/.pylintrc +++ b/.pylintrc @@ -169,7 +169,9 @@ disable=missing-docstring, comprehension-escape, duplicate-code, not-callable, - import-outside-toplevel + import-outside-toplevel, + logging-fstring-interpolation, + logging-not-lazy # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option From 695e862aadc0f280c9853d8888ad2d4382fef34b Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Wed, 21 Jun 2023 11:57:46 +0200 Subject: [PATCH 08/29] Update model URLs --- TTS/.models.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/TTS/.models.json b/TTS/.models.json index 801485a1..c97c6a38 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -13,10 +13,10 @@ "bark": { "description": "🐶 Bark TTS model released by suno-ai. You can find the original implementation in https://github.com/suno-ai/bark.", "hf_url": [ - "https://coqui.gateway.scarf.sh/bark/coarse_2.pt", - "https://coqui.gateway.scarf.sh/bark/fine_2.pt", - "https://coqui.gateway.scarf.sh/bark/text_2.pt", - "https://coqui.gateway.scarf.sh/bark/config.json" + "https://coqui.gateway.scarf.sh/hf/bark/coarse_2.pt", + "https://coqui.gateway.scarf.sh/hf/bark/fine_2.pt", + "https://coqui.gateway.scarf.sh/hf/bark/text_2.pt", + "https://coqui.gateway.scarf.sh/hf/bark/config.json" ], "default_vocoder": null, "commit": "e9a1953e", From 03c347b7f3f5af29027ad1919c809d4e6cf434c8 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Wed, 21 Jun 2023 11:58:18 +0200 Subject: [PATCH 09/29] Update Bark Config --- TTS/bin/synthesize.py | 2 +- TTS/tts/configs/bark_config.py | 39 ++++++++++++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 8a7e178d..0334c023 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -356,7 +356,7 @@ If you don't specify any models, then it uses LJSpeech based English model. 
vc_config_path = config_path # tts model with multiple files to be loaded from the directory path - if isinstance(model_item["github_rls_url"], list): + if isinstance(model_item["model_url"], list): model_dir = model_path tts_path = None tts_config_path = None diff --git a/TTS/tts/configs/bark_config.py b/TTS/tts/configs/bark_config.py index 57ccf2d0..943f3dea 100644 --- a/TTS/tts/configs/bark_config.py +++ b/TTS/tts/configs/bark_config.py @@ -1,5 +1,5 @@ import os -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Dict from TTS.tts.configs.shared_configs import BaseTTSConfig @@ -11,6 +11,40 @@ from TTS.utils.generic_utils import get_user_data_dir @dataclass class BarkConfig(BaseTTSConfig): + """ Bark TTS configuration + + Args: + model (str): model name that registers the model. + audio (BarkAudioConfig): audio configuration. Defaults to BarkAudioConfig(). + num_chars (int): number of characters in the alphabet. Defaults to 0. + semantic_config (GPTConfig): semantic configuration. Defaults to GPTConfig(). + fine_config (FineGPTConfig): fine configuration. Defaults to FineGPTConfig(). + coarse_config (GPTConfig): coarse configuration. Defaults to GPTConfig(). + CONTEXT_WINDOW_SIZE (int): GPT context window size. Defaults to 1024. + SEMANTIC_RATE_HZ (float): semantic tokens rate in Hz. Defaults to 49.9. + SEMANTIC_VOCAB_SIZE (int): semantic vocabulary size. Defaults to 10_000. + CODEBOOK_SIZE (int): encodec codebook size. Defaults to 1024. + N_COARSE_CODEBOOKS (int): number of coarse codebooks. Defaults to 2. + N_FINE_CODEBOOKS (int): number of fine codebooks. Defaults to 8. + COARSE_RATE_HZ (int): coarse tokens rate in Hz. Defaults to 75. + SAMPLE_RATE (int): sample rate. Defaults to 24_000. + USE_SMALLER_MODELS (bool): use smaller models. Defaults to False. + TEXT_ENCODING_OFFSET (int): text encoding offset. Defaults to 10_048. + SEMANTIC_PAD_TOKEN (int): semantic pad token. Defaults to 10_000. + TEXT_PAD_TOKEN ([type]): text pad token. Defaults to 10_048. + TEXT_EOS_TOKEN ([type]): text end of sentence token. Defaults to 10_049. + TEXT_SOS_TOKEN ([type]): text start of sentence token. Defaults to 10_050. + SEMANTIC_INFER_TOKEN (int): semantic infer token. Defaults to 10_051. + COARSE_SEMANTIC_PAD_TOKEN (int): coarse semantic pad token. Defaults to 12_048. + COARSE_INFER_TOKEN (int): coarse infer token. Defaults to 12_050. + REMOTE_BASE_URL ([type]): remote base url. Defaults to "https://huggingface.co/erogol/bark/tree". + REMOTE_MODEL_PATHS (Dict): remote model paths. Defaults to None. + LOCAL_MODEL_PATHS (Dict): local model paths. Defaults to None. + SMALL_REMOTE_MODEL_PATHS (Dict): small remote model paths. Defaults to None. + CACHE_DIR (str): local cache directory. Defaults to get_user_data_dir(). + DEF_SPEAKER_DIR (str): default speaker directory to stoke speaker values for voice cloning. Defaults to get_user_data_dir(). 
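To make the path handling documented above concrete: instantiating the config is enough to resolve every location, since `__post_init__` fills the remote and local model paths from `CACHE_DIR`. A small sketch (assuming the `BarkAudioConfig`/model module this config imports is already importable; exact paths depend on the user data dir):

```python
from TTS.tts.configs.bark_config import BarkConfig

config = BarkConfig()
print(config.CACHE_DIR)                   # e.g. ~/.local/share/tts/suno/bark_v0
print(config.LOCAL_MODEL_PATHS["text"])   # <CACHE_DIR>/text_2.pt
print(config.sample_rate)                 # 24000, copied from SAMPLE_RATE in __post_init__
```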
+ """ + model: str = "bark" audio: BarkAudioConfig = BarkAudioConfig() num_chars: int = 0 @@ -39,6 +73,7 @@ class BarkConfig(BaseTTSConfig): LOCAL_MODEL_PATHS: Dict = None SMALL_REMOTE_MODEL_PATHS: Dict = None CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0")) + DEF_SPEAKER_DIR: str = str(get_user_data_dir("tts/bark_v0/speakers")) def __post_init__(self): self.REMOTE_MODEL_PATHS = { @@ -67,4 +102,4 @@ class BarkConfig(BaseTTSConfig): "coarse": {"path": os.path.join(self.REMOTE_BASE_URL, "coarse.pt")}, "fine": {"path": os.path.join(self.REMOTE_BASE_URL, "fine.pt")}, } - self.sample_rate = self.SAMPLE_RATE + self.sample_rate = self.SAMPLE_RATE # pylint: disable=attribute-defined-outside-init From 0f8932a6a9a1d71c7429c8e07d3c37f1f9a2da25 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Wed, 21 Jun 2023 11:59:27 +0200 Subject: [PATCH 10/29] Fix here and ther --- TTS/tts/layers/bark/hubert/hubert_manager.py | 2 + TTS/tts/layers/bark/hubert/tokenizer.py | 4 +- TTS/tts/layers/bark/inference_funcs.py | 95 +++++------ TTS/tts/layers/bark/load_model.py | 159 +++++++++---------- TTS/tts/layers/bark/model.py | 10 +- TTS/tts/layers/bark/model_fine.py | 2 +- TTS/utils/synthesizer.py | 2 +- docs/source/models/tortoise.md | 24 +-- 8 files changed, 138 insertions(+), 160 deletions(-) diff --git a/TTS/tts/layers/bark/hubert/hubert_manager.py b/TTS/tts/layers/bark/hubert/hubert_manager.py index baa26438..4bc19929 100644 --- a/TTS/tts/layers/bark/hubert/hubert_manager.py +++ b/TTS/tts/layers/bark/hubert/hubert_manager.py @@ -17,6 +17,7 @@ class HubertManager: urllib.request.urlretrieve(download_url, model_path) print("Downloaded HuBERT") return model_path + return None @staticmethod def make_sure_tokenizer_installed( @@ -31,3 +32,4 @@ class HubertManager: shutil.move(os.path.join(model_dir, model), model_path) print("Downloaded tokenizer") return model_path + return None diff --git a/TTS/tts/layers/bark/hubert/tokenizer.py b/TTS/tts/layers/bark/hubert/tokenizer.py index 474a08db..be9a50f8 100644 --- a/TTS/tts/layers/bark/hubert/tokenizer.py +++ b/TTS/tts/layers/bark/hubert/tokenizer.py @@ -16,7 +16,7 @@ from torch.serialization import MAP_LOCATION class HubertTokenizer(nn.Module): def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0): - super(HubertTokenizer, self).__init__() + super().__init__() next_size = input_size if version == 0: self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True) @@ -181,7 +181,7 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep epoch = 1 while 1: - for i in range(save_epochs): + for _ in range(save_epochs): j = 0 for x, y in zip(data_x, data_y): model_training.train_step( diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py index 6fa87c37..2b27246d 100644 --- a/TTS/tts/layers/bark/inference_funcs.py +++ b/TTS/tts/layers/bark/inference_funcs.py @@ -16,7 +16,7 @@ from torch.nn import functional as F from TTS.tts.layers.bark.hubert.hubert_manager import HubertManager from TTS.tts.layers.bark.hubert.kmeans_hubert import CustomHubert from TTS.tts.layers.bark.hubert.tokenizer import HubertTokenizer -from TTS.tts.layers.bark.load_model import _clear_cuda_cache, _inference_mode +from TTS.tts.layers.bark.load_model import clear_cuda_cache, inference_mode logger = logging.getLogger(__name__) @@ -34,34 +34,53 @@ def _normalize_whitespace(text): def get_voices(extra_voice_dirs: List[str] = []): - voices = {} - for dir in extra_voice_dirs: - paths = 
list(glob(f"{dir}/*.npz")) - for path in paths: - name = os.path.basename(path).replace(".npz", "") - voices[name] = path + dirs = extra_voice_dirs + voices: Dict[str, List[str]] = {} + for d in dirs: + subs = os.listdir(d) + for sub in subs: + subj = os.path.join(d, sub) + if os.path.isdir(subj): + voices[sub] = list(glob(f"{subj}/*.npz")) + # fetch audio files if no npz files are found + if len(voices[sub]) == 0: + voices[sub] = list(glob(f"{subj}/*.wav")) + list(glob(f"{subj}/*.mp3")) return voices -def load_voice(voice: str, extra_voice_dirs: List[str] = []): - def load_npz(npz_file): +def load_npz(npz_file): x_history = np.load(npz_file) semantic = x_history["semantic_prompt"] coarse = x_history["coarse_prompt"] fine = x_history["fine_prompt"] return semantic, coarse, fine + +def load_voice(model, voice: str, extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-default-value if voice == "random": return None, None, None voices = get_voices(extra_voice_dirs) + paths = voices[voice] + + # bark only uses a single sample for cloning + if len(paths) > 1: + raise ValueError(f"Voice {voice} has multiple paths: {paths}") + try: path = voices[voice] - except KeyError: - raise KeyError(f"Voice {voice} not found in {extra_voice_dirs}") - prompt = load_npz(path) - return prompt + except KeyError as e: + raise KeyError(f"Voice {voice} not found in {extra_voice_dirs}") from e + if len(paths) == 1 and paths[0].endswith(".npz"): + return load_npz(path[0]) + else: + audio_path = paths[0] + # replace the file extension with .npz + output_path = os.path.splitext(audio_path)[0] + ".npz" + generate_voice(audio=audio_path, model=model, output_path=output_path) + breakpoint() + return load_voice(model, voice, extra_voice_dirs) def zero_crossing_rate(audio, frame_length=1024, hop_length=512): zero_crossings = np.sum(np.abs(np.diff(np.sign(audio))) / 2) @@ -85,7 +104,6 @@ def compute_average_bass_energy(audio_data, sample_rate, max_bass_freq=250): def generate_voice( audio, - text, model, output_path, ): @@ -106,9 +124,6 @@ def generate_voice( encoded_frames = model.encodec.encode(audio) codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T] - # get seconds of audio - seconds = audio.shape[-1] / model.config.sample_rate - # move codes to cpu codes = codes.cpu().numpy() @@ -133,36 +148,6 @@ def generate_voice( np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens) - # while attempts < max_attempts: - # if attempts > 0 and base is not None: - # # Reset the base model token - # print(f"Reset the base model token Regenerating...") - # base = None - - # audio_array, x = model.generate_audio(text, history_promp=None, base=base, **kwargs) - # zcr = zero_crossing_rate(audio_array) - # spectral_contrast = compute_spectral_contrast(audio_array, model.config.sample_rate) - # bass_energy = compute_average_bass_energy(audio_array, model.config.sample_rate) - # print(f"Attempt {attempts + 1}: ZCR = {zcr}, Spectral Contrast = {spectral_contrast:.2f}, Bass Energy = {bass_energy:.2f}") - - # # Save the audio array to the output_array directory with a random name for debugging - # #output_file = os.path.join(output_directory, f"audio_{zcr:.2f}_sc{spectral_contrast:.2f}_be{bass_energy:.2f}.wav") - # #wavfile.write(output_file, sample_rate, audio_array) - # #print(f"Saved audio array to {output_file}") - - # if zcr < zcr_threshold and spectral_contrast < spectral_threshold and bass_energy < bass_energy_threshold: - # print(f"Audio passed 
ZCR, Spectral Contrast, and Bass Energy thresholds. No need to regenerate.") - # break - # else: - # print(f"Audio failed ZCR, Spectral Contrast, and/or Bass Energy thresholds. Regenerating...") - - # attempts += 1 - - # if attempts == max_attempts: - # print("Reached maximum attempts. Returning the last generated audio.") - - # return audio_array, x, zcr, spectral_contrast, bass_energy - def generate_text_semantic( text, @@ -224,7 +209,7 @@ def generate_text_semantic( np.hstack([encoded_text, semantic_history, np.array([model.config.SEMANTIC_INFER_TOKEN])]).astype(np.int64) )[None] assert x.shape[1] == 256 + 256 + 1 - with _inference_mode(): + with inference_mode(): x = x.to(model.device) n_tot_steps = 768 # custom tqdm updates since we don't know when eos will occur @@ -285,8 +270,8 @@ def generate_text_semantic( pbar_state = req_pbar_state pbar.close() out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :] - assert all(0 <= out) and all(out < model.config.SEMANTIC_VOCAB_SIZE) - _clear_cuda_cache() + assert all(out >= 0) and all(out < model.config.SEMANTIC_VOCAB_SIZE) + clear_cuda_cache() return out @@ -382,7 +367,7 @@ def generate_coarse( x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32) x_coarse = x_coarse_history.astype(np.int32) base_semantic_idx = len(x_semantic_history) - with _inference_mode(): + with inference_mode(): x_semantic_in = torch.from_numpy(x_semantic)[None].to(model.device) x_coarse_in = torch.from_numpy(x_coarse)[None].to(model.device) n_window_steps = int(np.ceil(n_steps / sliding_window_len)) @@ -456,7 +441,7 @@ def generate_coarse( ) for n in range(1, model.config.N_COARSE_CODEBOOKS): gen_coarse_audio_arr[n, :] -= n * model.config.CODEBOOK_SIZE - _clear_cuda_cache() + clear_cuda_cache() return gen_coarse_audio_arr @@ -526,7 +511,7 @@ def generate_fine( ) # we can be lazy about fractional loop and just keep overwriting codebooks n_loops = np.max([0, int(np.ceil((x_coarse_gen.shape[1] - (1024 - n_history)) / 512))]) + 1 - with _inference_mode(): + with inference_mode(): in_arr = torch.tensor(in_arr.T).to(model.device) for n in tqdm.tqdm(range(n_loops), disable=silent): start_idx = np.min([n * 512, in_arr.shape[0] - 1024]) @@ -558,14 +543,12 @@ def generate_fine( if n_remove_from_end > 0: gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end] assert gen_fine_arr.shape[-1] == x_coarse_gen.shape[-1] - _clear_cuda_cache() + clear_cuda_cache() return gen_fine_arr def codec_decode(fine_tokens, model): """Turn quantized audio codes into audio array using encodec.""" - from TTS.utils.audio.numpy_transforms import save_wav - arr = torch.from_numpy(fine_tokens)[None] arr = arr.to(model.device) arr = arr.transpose(0, 1) diff --git a/TTS/tts/layers/bark/load_model.py b/TTS/tts/layers/bark/load_model.py index dbd861d0..33144ed5 100644 --- a/TTS/tts/layers/bark/load_model.py +++ b/TTS/tts/layers/bark/load_model.py @@ -1,17 +1,12 @@ import contextlib - -# import funcy import functools import hashlib import logging import os -import re import requests import torch import tqdm -from encodec import EncodecModel -from transformers import BertTokenizer from TTS.tts.layers.bark.model import GPT, GPTConfig from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig @@ -31,8 +26,6 @@ else: # hold models in global scope to lazy load -global models -models = {} logger = logging.getLogger(__name__) @@ -44,10 +37,10 @@ if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): ) -def _string_md5(s): - m = hashlib.md5() - m.update(s.encode("utf-8")) - 
return m.hexdigest() +# def _string_md5(s): +# m = hashlib.md5() +# m.update(s.encode("utf-8")) +# return m.hexdigest() def _md5(fname): @@ -58,18 +51,18 @@ def _md5(fname): return hash_md5.hexdigest() -def _get_ckpt_path(model_type, CACHE_DIR): - model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"]) - return os.path.join(CACHE_DIR, f"{model_name}.pt") +# def _get_ckpt_path(model_type, CACHE_DIR): +# model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"]) +# return os.path.join(CACHE_DIR, f"{model_name}.pt") -S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/" +# S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/" -def _parse_s3_filepath(s3_filepath): - bucket_name = re.search(S3_BUCKET_PATH_RE, s3_filepath).group(1) - rel_s3_filepath = re.sub(S3_BUCKET_PATH_RE, "", s3_filepath) - return bucket_name, rel_s3_filepath +# def _parse_s3_filepath(s3_filepath): +# bucket_name = re.search(S3_BUCKET_PATH_RE, s3_filepath).group(1) +# rel_s3_filepath = re.sub(S3_BUCKET_PATH_RE, "", s3_filepath) +# return bucket_name, rel_s3_filepath def _download(from_s3_path, to_local_path, CACHE_DIR): @@ -83,7 +76,7 @@ def _download(from_s3_path, to_local_path, CACHE_DIR): progress_bar.update(len(data)) file.write(data) progress_bar.close() - if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: + if total_size_in_bytes not in [0, progress_bar.n]: raise ValueError("ERROR, something went wrong") @@ -107,27 +100,27 @@ if torch.cuda.is_available(): @contextlib.contextmanager -def _inference_mode(): +def inference_mode(): with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast(): yield -def _clear_cuda_cache(): +def clear_cuda_cache(): if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() -def clean_models(model_key=None): - global models - model_keys = [model_key] if model_key is not None else models.keys() - for k in model_keys: - if k in models: - del models[k] - _clear_cuda_cache() +# def clean_models(model_key=None): +# global models +# model_keys = [model_key] if model_key is not None else models.keys() +# for k in model_keys: +# if k in models: +# del models[k] +# clear_cuda_cache() -def _load_model(ckpt_path, device, config, model_type="text"): +def load_model(ckpt_path, device, config, model_type="text"): logger.info(f"loading {model_type} model from {ckpt_path}...") if device == "cpu": @@ -174,13 +167,13 @@ def _load_model(ckpt_path, device, config, model_type="text"): state_dict = checkpoint["model"] # fixup checkpoint unwanted_prefix = "_orig_mod." 
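# Note: the "_orig_mod." prefix is what a torch.compile()-wrapped module prepends to its
# state_dict keys when saved; stripping it here lets the plain, uncompiled GPT defined
# above load checkpoints that were saved from a compiled model unchanged.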
- for k, v in list(state_dict.items()): + for k, _ in list(state_dict.items()): if k.startswith(unwanted_prefix): state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) extra_keys = set(state_dict.keys()) - set(model.state_dict().keys()) - extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")]) + extra_keys = set(k for k in extra_keys if not k.endswith(".attn.bias")) missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) - missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")]) + missing_keys = set(k for k in missing_keys if not k.endswith(".attn.bias")) if len(extra_keys) != 0: raise ValueError(f"extra keys found: {extra_keys}") if len(missing_keys) != 0: @@ -192,63 +185,63 @@ def _load_model(ckpt_path, device, config, model_type="text"): model.eval() model.to(device) del checkpoint, state_dict - _clear_cuda_cache() + clear_cuda_cache() return model, config -def _load_codec_model(device): - model = EncodecModel.encodec_model_24khz() - model.set_target_bandwidth(6.0) - model.eval() - model.to(device) - _clear_cuda_cache() - return model +# def _load_codec_model(device): +# model = EncodecModel.encodec_model_24khz() +# model.set_target_bandwidth(6.0) +# model.eval() +# model.to(device) +# clear_cuda_cache() +# return model -def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"): - _load_model_f = functools.partial(_load_model, model_type=model_type) - if model_type not in ("text", "coarse", "fine"): - raise NotImplementedError() - global models - if torch.cuda.device_count() == 0 or not use_gpu: - device = "cpu" - else: - device = "cuda" - model_key = str(device) + f"__{model_type}" - if model_key not in models or force_reload: - if ckpt_path is None: - ckpt_path = _get_ckpt_path(model_type) - clean_models(model_key=model_key) - model = _load_model_f(ckpt_path, device) - models[model_key] = model - return models[model_key] +# def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"): +# _load_model_f = functools.partial(_load_model, model_type=model_type) +# if model_type not in ("text", "coarse", "fine"): +# raise NotImplementedError() +# global models +# if torch.cuda.device_count() == 0 or not use_gpu: +# device = "cpu" +# else: +# device = "cuda" +# model_key = str(device) + f"__{model_type}" +# if model_key not in models or force_reload: +# if ckpt_path is None: +# ckpt_path = _get_ckpt_path(model_type) +# clean_models(model_key=model_key) +# model = _load_model_f(ckpt_path, device) +# models[model_key] = model +# return models[model_key] -def load_codec_model(use_gpu=True, force_reload=False): - global models - if torch.cuda.device_count() == 0 or not use_gpu: - device = "cpu" - else: - device = "cuda" - model_key = str(device) + f"__codec" - if model_key not in models or force_reload: - clean_models(model_key=model_key) - model = _load_codec_model(device) - models[model_key] = model - return models[model_key] +# def load_codec_model(use_gpu=True, force_reload=False): +# global models +# if torch.cuda.device_count() == 0 or not use_gpu: +# device = "cpu" +# else: +# device = "cuda" +# model_key = str(device) + f"__codec" +# if model_key not in models or force_reload: +# clean_models(model_key=model_key) +# model = _load_codec_model(device) +# models[model_key] = model +# return models[model_key] -def preload_models( - text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True, use_smaller_models=False -): - global USE_SMALLER_MODELS - global 
REMOTE_MODEL_PATHS - if use_smaller_models: - USE_SMALLER_MODELS = True - logger.info("Using smaller models generation.py") - REMOTE_MODEL_PATHS = SMALL_REMOTE_MODEL_PATHS +# def preload_models( +# text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True, use_smaller_models=False +# ): +# global USE_SMALLER_MODELS +# global REMOTE_MODEL_PATHS +# if use_smaller_models: +# USE_SMALLER_MODELS = True +# logger.info("Using smaller models generation.py") +# REMOTE_MODEL_PATHS = SMALL_REMOTE_MODEL_PATHS - _ = load_model(ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True) - _ = load_model(ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True) - _ = load_model(ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True) - _ = load_codec_model(use_gpu=use_gpu, force_reload=True) +# _ = load_model(ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True) +# _ = load_model(ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True) +# _ = load_model(ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True) +# _ = load_codec_model(use_gpu=use_gpu, force_reload=True) diff --git a/TTS/tts/layers/bark/model.py b/TTS/tts/layers/bark/model.py index bcc87a4b..c84022bd 100644 --- a/TTS/tts/layers/bark/model.py +++ b/TTS/tts/layers/bark/model.py @@ -6,8 +6,8 @@ import math from dataclasses import dataclass import torch -import torch.nn as nn from coqpit import Coqpit +from torch import nn from torch.nn import functional as F @@ -19,8 +19,8 @@ class LayerNorm(nn.Module): self.weight = nn.Parameter(torch.ones(ndim)) self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None - def forward(self, input): - return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) + def forward(self, x): + return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5) class CausalSelfAttention(nn.Module): @@ -177,7 +177,7 @@ class GPT(nn.Module): def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False): device = idx.device - b, t = idx.size() + _, t = idx.size() if past_kv is not None: assert t == 1 tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) @@ -219,7 +219,7 @@ class GPT(nn.Module): new_kv = () if use_cache else None - for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)): + for _, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)): x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache) if use_cache: diff --git a/TTS/tts/layers/bark/model_fine.py b/TTS/tts/layers/bark/model_fine.py index 8a426107..09e5f476 100644 --- a/TTS/tts/layers/bark/model_fine.py +++ b/TTS/tts/layers/bark/model_fine.py @@ -6,7 +6,7 @@ import math from dataclasses import dataclass import torch -import torch.nn as nn +from torch import nn from torch.nn import functional as F from .model import GPT, MLP, GPTConfig diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 4f7761b9..bbaf2904 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -341,7 +341,7 @@ class Synthesizer(object): use_gl = self.vocoder_model is None - if not reference_wav: + if not reference_wav: # not voice conversion for sen in sens: if hasattr(self.tts_model, "synthesize"): sp_name = "random" if speaker_name is None else speaker_name diff --git a/docs/source/models/tortoise.md b/docs/source/models/tortoise.md index c49a0fcb..d602d597 100644 
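Taken together, the public helpers kept in this file (`load_model`, `inference_mode`, `clear_cuda_cache`) and the generation functions earlier in this patch form Bark's inference pipeline: text to semantic tokens, semantic tokens to coarse codebooks, coarse to fine codebooks, fine codebooks to waveform. Note that the adjusted checks above expect `history_prompt` to always be a 3-tuple, with `(None, None, None)` meaning "no voice prompt". A sketch of the wiring (argument names follow the function bodies, but the exact signatures are an assumption until the `Bark` model class lands later in the series):

```python
from TTS.tts.layers.bark.inference_funcs import (
    codec_decode,
    generate_coarse,
    generate_fine,
    generate_text_semantic,
)

def bark_generate(model, text, history_prompt=(None, None, None)):
    # text -> semantic tokens, optionally conditioned on a voice prompt
    semantic = generate_text_semantic(text, model, history_prompt=history_prompt)
    # semantic tokens -> the first two EnCodec codebooks
    coarse = generate_coarse(semantic, model, history_prompt=history_prompt)
    # two coarse codebooks -> all eight codebooks
    fine = generate_fine(coarse, model, history_prompt=history_prompt)
    # codebooks -> waveform through EnCodec's decoder
    return codec_decode(fine, model)
```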
--- a/docs/source/models/tortoise.md +++ b/docs/source/models/tortoise.md @@ -1,7 +1,7 @@ # Tortoise 🐢 Tortoise is a very expressive TTS system with impressive voice cloning capabilities. It is based on an GPT like autogressive acoustic model that converts input text to discritized acouistic tokens, a diffusion model that converts these tokens to melspeectrogram frames and a Univnet vocoder to convert the spectrograms to -the final audio signal. The important downside is that Tortoise is very slow compared to the parallel TTS models like VITS. +the final audio signal. The important downside is that Tortoise is very slow compared to the parallel TTS models like VITS. Big thanks to 👑[@manmay-nakhashi](https://github.com/manmay-nakhashi) who helped us implement Tortoise in 🐸TTS. @@ -12,7 +12,7 @@ from TTS.tts.configs.tortoise_config import TortoiseConfig from TTS.tts.models.tortoise import Tortoise config = TortoiseConfig() -model = Tortoise.inif_from_config(config) +model = Tortoise.init_from_config(config) model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True) # with random speaker @@ -29,23 +29,23 @@ from TTS.api import TTS tts = TTS("tts_models/en/multi-dataset/tortoise-v2") # cloning `lj` voice from `TTS/tts/utils/assets/tortoise/voices/lj` -# with custom inference settings overriding defaults. -tts.tts_to_file(text="Hello, my name is Manmay , how are you?", +# with custom inference settings overriding defaults. +tts.tts_to_file(text="Hello, my name is Manmay , how are you?", file_path="output.wav", - voice_dir="TTS/tts/utils/assets/tortoise/voices/", + voice_dir="path/to/tortoise/voices/dir/", speaker="lj", num_autoregressive_samples=1, diffusion_iterations=10) # Using presets with the same voice -tts.tts_to_file(text="Hello, my name is Manmay , how are you?", +tts.tts_to_file(text="Hello, my name is Manmay , how are you?", file_path="output.wav", - voice_dir="TTS/tts/utils/assets/tortoise/voices/", + voice_dir="path/to/tortoise/voices/dir/", speaker="lj", preset="ultra_fast") # Random voice generation -tts.tts_to_file(text="Hello, my name is Manmay , how are you?", +tts.tts_to_file(text="Hello, my name is Manmay , how are you?", file_path="output.wav") ``` @@ -54,16 +54,16 @@ Using 🐸TTS Command line: ```console # cloning the `lj` voice tts --model_name tts_models/en/multi-dataset/tortoise-v2 \ ---text "This is an example." \ ---out_path "/data/speech_synth/coqui-tts/TTS/tests/outputs/output.wav" \ ---voice_dir TTS/tts/utils/assets/tortoise/voices/ \ +--text "This is an example." \ +--out_path "output.wav" \ +--voice_dir path/to/tortoise/voices/dir/ \ --speaker_idx "lj" \ --progress_bar True # Random voice generation tts --model_name tts_models/en/multi-dataset/tortoise-v2 \ --text "This is an example." 
\ ---out_path "/data/speech_synth/coqui-tts/TTS/tests/outputs/output.wav" \ +--out_path "output.wav" \ --progress_bar True ``` From 3b9fca2398837e03a63e1347c1bbd2fcb39c0811 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Wed, 21 Jun 2023 12:02:06 +0200 Subject: [PATCH 11/29] Make style --- TTS/tts/configs/bark_config.py | 2 +- TTS/tts/layers/bark/inference_funcs.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/TTS/tts/configs/bark_config.py b/TTS/tts/configs/bark_config.py index 943f3dea..647116bd 100644 --- a/TTS/tts/configs/bark_config.py +++ b/TTS/tts/configs/bark_config.py @@ -11,7 +11,7 @@ from TTS.utils.generic_utils import get_user_data_dir @dataclass class BarkConfig(BaseTTSConfig): - """ Bark TTS configuration + """Bark TTS configuration Args: model (str): model name that registers the model. diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py index 2b27246d..dcb13ea0 100644 --- a/TTS/tts/layers/bark/inference_funcs.py +++ b/TTS/tts/layers/bark/inference_funcs.py @@ -49,11 +49,11 @@ def get_voices(extra_voice_dirs: List[str] = []): def load_npz(npz_file): - x_history = np.load(npz_file) - semantic = x_history["semantic_prompt"] - coarse = x_history["coarse_prompt"] - fine = x_history["fine_prompt"] - return semantic, coarse, fine + x_history = np.load(npz_file) + semantic = x_history["semantic_prompt"] + coarse = x_history["coarse_prompt"] + fine = x_history["fine_prompt"] + return semantic, coarse, fine def load_voice(model, voice: str, extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-default-value @@ -79,9 +79,9 @@ def load_voice(model, voice: str, extra_voice_dirs: List[str] = []): # pylint: # replace the file extension with .npz output_path = os.path.splitext(audio_path)[0] + ".npz" generate_voice(audio=audio_path, model=model, output_path=output_path) - breakpoint() return load_voice(model, voice, extra_voice_dirs) + def zero_crossing_rate(audio, frame_length=1024, hop_length=512): zero_crossings = np.sum(np.abs(np.diff(np.sign(audio))) / 2) total_frames = 1 + int((len(audio) - frame_length) / hop_length) From cf98ae04dfcf68c63a0f8a59ecadd3d4ebc5fcb2 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Wed, 21 Jun 2023 12:05:08 +0200 Subject: [PATCH 12/29] Make lint --- TTS/tts/layers/bark/inference_funcs.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py index dcb13ea0..fa7a1ebf 100644 --- a/TTS/tts/layers/bark/inference_funcs.py +++ b/TTS/tts/layers/bark/inference_funcs.py @@ -33,7 +33,7 @@ def _normalize_whitespace(text): return re.sub(r"\s+", " ", text).strip() -def get_voices(extra_voice_dirs: List[str] = []): +def get_voices(extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-default-value dirs = extra_voice_dirs voices: Dict[str, List[str]] = {} for d in dirs: @@ -74,12 +74,12 @@ def load_voice(model, voice: str, extra_voice_dirs: List[str] = []): # pylint: if len(paths) == 1 and paths[0].endswith(".npz"): return load_npz(path[0]) - else: - audio_path = paths[0] - # replace the file extension with .npz - output_path = os.path.splitext(audio_path)[0] + ".npz" - generate_voice(audio=audio_path, model=model, output_path=output_path) - return load_voice(model, voice, extra_voice_dirs) + + audio_path = paths[0] + # replace the file extension with .npz + output_path = os.path.splitext(audio_path)[0] + ".npz" + generate_voice(audio=audio_path, model=model, 
output_path=output_path) + return load_voice(model, voice, extra_voice_dirs) def zero_crossing_rate(audio, frame_length=1024, hop_length=512): From 8597ee13af12b467119361964bce73cfe58552a1 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Wed, 21 Jun 2023 12:21:22 +0200 Subject: [PATCH 13/29] Update requirements --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2b725bc6..18e8bdb8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,4 +52,5 @@ einops transformers #deps for bark -encodec \ No newline at end of file +encodec +fairseq \ No newline at end of file From 9190f1a5f31b73006108bc65ab62aeba4d5e9a3d Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Wed, 21 Jun 2023 12:22:37 +0200 Subject: [PATCH 14/29] Update requirements --- requirements.txt | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 18e8bdb8..c450ff20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,12 +45,11 @@ g2pkk>=0.1.1 bangla==0.0.2 bnnumerizer bnunicodenormalizer==0.1.1 - #deps for tortoise k_diffusion einops transformers - #deps for bark encodec -fairseq \ No newline at end of file +#deps for fairseq models +fairseq From e888e8a56d4114a3b849743fb714a9b740833927 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Thu, 22 Jun 2023 10:13:20 +0200 Subject: [PATCH 15/29] Fix manage --- TTS/utils/manage.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 78020f91..d648510d 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -264,19 +264,21 @@ class ModelManager(object): model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz") self._download_tar_file(model_download_uri, output_path, self.progress_bar) - def set_model_url(self, model_item: Dict): + @staticmethod + def set_model_url(model_item: Dict): model_item["model_url"] = None if "github_rls_url" in model_item: model_item["model_url"] = model_item["github_rls_url"] elif "hf_url" in model_item: model_item["model_url"] = model_item["hf_url"] + elif "fairseq" in model_item["model_name"]: + model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/" return model_item - + def _set_model_item(self, model_name): # fetch model info from the dict model_type, lang, dataset, model = model_name.split("/") model_full_name = f"{model_type}--{lang}--{dataset}--{model}" - model_item = self.set_model_url(model_item) if "fairseq" in model_name: model_item = { "model_type": "tts_models", @@ -289,6 +291,7 @@ class ModelManager(object): # get model from models.json model_item = self.models_dict[model_type][lang][dataset][model] model_item["model_type"] = model_type + model_item = self.set_model_url(model_item) return model_item, model_full_name, model def download_model(self, model_name): From ddbb27547a765fc79fb6c063ae7a51e93bb84b99 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Thu, 22 Jun 2023 13:51:58 +0200 Subject: [PATCH 16/29] Update CI --- .github/workflows/aux_tests.yml | 2 +- .github/workflows/data_tests.yml | 2 +- .github/workflows/inference_tests.yml | 2 +- .github/workflows/text_tests.yml | 2 +- .github/workflows/tts_tests.yml | 2 +- .github/workflows/vocoder_tests.yml | 2 +- .github/workflows/zoo_tests0.yml | 2 +- .github/workflows/zoo_tests1.yml | 2 +- .github/workflows/zoo_tests2.yml | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/aux_tests.yml b/.github/workflows/aux_tests.yml index 
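Before the packaging and CI changes, one more note on the voice-cloning path reworked above: `get_voices`/`load_voice` now treat each sub-directory of a voice dir as one speaker, prefer a ready-made `.npz` prompt, and otherwise convert the single reference `.wav`/`.mp3` with `generate_voice` and cache the result as a `.npz` next to it. A simplified restatement of the lookup (not the library code):

```python
import os
from glob import glob

def list_voices(extra_voice_dirs):
    voices = {}
    for d in extra_voice_dirs:
        for sub in os.listdir(d):
            subj = os.path.join(d, sub)
            if os.path.isdir(subj):
                # .npz prompts win; otherwise fall back to raw audio to be converted
                voices[sub] = glob(f"{subj}/*.npz") or glob(f"{subj}/*.wav") + glob(f"{subj}/*.mp3")
    return voices

# expected layout: one folder per speaker, a single reference clip inside
os.makedirs("bark_voices/ljspeech", exist_ok=True)
open("bark_voices/ljspeech/ljspeech.wav", "ab").close()
print(list_voices(["bark_voices"]))   # {'ljspeech': ['bark_voices/ljspeech/ljspeech.wav']}
```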
e42b964d..26b8446e 100644 --- a/.github/workflows/aux_tests.yml +++ b/.github/workflows/aux_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.7, 3.8, 3.9, "3.10"] + python-version: [3.8, 3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/data_tests.yml b/.github/workflows/data_tests.yml index 9ed1333d..98093f3d 100644 --- a/.github/workflows/data_tests.yml +++ b/.github/workflows/data_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.7, 3.8, 3.9, "3.10"] + python-version: [3.8, 3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/inference_tests.yml b/.github/workflows/inference_tests.yml index 2f6c83bf..3e26b799 100644 --- a/.github/workflows/inference_tests.yml +++ b/.github/workflows/inference_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.7, 3.8, 3.9, "3.10"] + python-version: [3.8, 3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/text_tests.yml b/.github/workflows/text_tests.yml index 9ae0a058..09abfc92 100644 --- a/.github/workflows/text_tests.yml +++ b/.github/workflows/text_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.7, 3.8, 3.9, "3.10"] + python-version: [3.8, 3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/tts_tests.yml b/.github/workflows/tts_tests.yml index 6d35171e..1f66ca21 100644 --- a/.github/workflows/tts_tests.yml +++ b/.github/workflows/tts_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.7, 3.8, 3.9, "3.10"] + python-version: [3.8, 3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/vocoder_tests.yml b/.github/workflows/vocoder_tests.yml index cfa8e6af..9f70b4bb 100644 --- a/.github/workflows/vocoder_tests.yml +++ b/.github/workflows/vocoder_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.7, 3.8, 3.9, "3.10"] + python-version: [3.8, 3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/zoo_tests0.yml b/.github/workflows/zoo_tests0.yml index d5f4cc14..f7e0db7c 100644 --- a/.github/workflows/zoo_tests0.yml +++ b/.github/workflows/zoo_tests0.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.7, 3.8, 3.9, "3.10"] + python-version: [3.8, 3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/zoo_tests1.yml b/.github/workflows/zoo_tests1.yml index 7f15f977..8aa34c57 100644 --- a/.github/workflows/zoo_tests1.yml +++ b/.github/workflows/zoo_tests1.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.7, 3.8, 3.9, "3.10"] + python-version: [3.8, 3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/zoo_tests2.yml b/.github/workflows/zoo_tests2.yml index 9975a2cf..7d5d9e76 100644 --- a/.github/workflows/zoo_tests2.yml +++ b/.github/workflows/zoo_tests2.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.7, 3.8, 3.9, "3.10"] + python-version: [3.8, 3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 From a58fb6c01b57056b8cbbc23e6d1a263f31d2d587 Mon Sep 17 00:00:00 
2001 From: Eren G??lge Date: Thu, 22 Jun 2023 13:53:19 +0200 Subject: [PATCH 17/29] Update requirements --- .github/workflows/pypi-release.yml | 10 +++++----- TTS/utils/manage.py | 4 +++- pyproject.toml | 2 +- requirements.txt | 13 +++++++------ setup.cfg | 6 +++--- setup.py | 8 ++++---- 6 files changed, 23 insertions(+), 20 deletions(-) diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index fc990826..015ad77e 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ -36,7 +36,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 @@ -64,10 +64,6 @@ jobs: with: name: "sdist" path: "dist/" - - uses: actions/download-artifact@v2 - with: - name: "wheel-3.7" - path: "dist/" - uses: actions/download-artifact@v2 with: name: "wheel-3.8" @@ -80,6 +76,10 @@ jobs: with: name: "wheel-3.10" path: "dist/" + - uses: actions/download-artifact@v2 + with: + name: "wheel-3.11" + path: "dist/" - run: | ls -lh dist/ - name: Setup PyPI config diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index d648510d..dca936b8 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -327,7 +327,9 @@ class ModelManager(object): # find downloaded files output_model_path = output_path output_config_path = None - if model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name: # TODO:This is stupid but don't care for now. + if ( + model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name + ): # TODO:This is stupid but don't care for now. output_model_path, output_config_path = self._find_files(output_path) # update paths in the config.json self._update_paths(output_path, output_config_path) diff --git a/pyproject.toml b/pyproject.toml index 8bc91b45..8544bb20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools", "wheel", "cython==0.29.28", "numpy==1.21.6", "packaging"] +requires = ["setuptools", "wheel", "cython==0.29.30", "numpy==1.22.0", "packaging"] [flake8] max-line-length=120 diff --git a/requirements.txt b/requirements.txt index c450ff20..847f5399 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,15 @@ # core deps -numpy==1.21.6;python_version<"3.10" -numpy;python_version=="3.10" -cython==0.29.28 +numpy==1.22.0 +numpy==1.22.0 +cython==0.29.30 scipy>=1.4.0 torch>=1.7 torchaudio soundfile librosa==0.10.0.* numba==0.55.1;python_version<"3.9" -numba==0.56.4;python_version>="3.9" +numba==0.56.4;python_version<="3.10" +numba==0.57.1;python_version>"3.10" inflect==5.6.0 tqdm anyascii @@ -26,14 +27,14 @@ pandas # deps for training matplotlib # coqui stack -trainer==0.0.20 +trainer # config management coqpit>=0.0.16 # chinese g2p deps jieba pypinyin # japanese g2p deps -mecab-python3==1.0.5 +mecab-python3==1.0.6 unidic-lite==1.0.8 # gruut+supported langs gruut[de,es,fr]==2.2.3 diff --git a/setup.cfg b/setup.cfg index 2344c8b2..1f31cb5d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,8 +1,8 @@ [build_py] -build-lib=temp_build +build_lib=temp_build [bdist_wheel] -bdist-dir=temp_build +bdist_dir=temp_build [install_lib] -build-dir=temp_build +build_dir=temp_build diff --git a/setup.py b/setup.py index 259c3cd1..1d3038bd 100644 --- a/setup.py +++ b/setup.py @@ -32,8 +32,8 @@ from Cython.Build import cythonize from setuptools import Extension, find_packages, setup python_version = sys.version.split()[0] -if 
Version(python_version) < Version("3.7") or Version(python_version) >= Version("3.11"): - raise RuntimeError("TTS requires python >= 3.7 and < 3.11 " "but your Python version is {}".format(sys.version)) +if Version(python_version) < Version("3.8") or Version(python_version) >= Version("3.12"): + raise RuntimeError("TTS requires python >= 3.8 and < 3.12 " "but your Python version is {}".format(sys.version)) cwd = os.path.dirname(os.path.abspath(__file__)) @@ -114,15 +114,15 @@ setup( "dev": requirements_dev, "notebooks": requirements_notebooks, }, - python_requires=">=3.7.0, <3.11", + python_requires=">=3.8.0, <3.12", entry_points={"console_scripts": ["tts=TTS.bin.synthesize:main", "tts-server = TTS.server.server:main"]}, classifiers=[ "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Development Status :: 3 - Alpha", "Intended Audience :: Science/Research", "Intended Audience :: Developers", From 8c1d8df7592459e60d83cd5840f56206cd76ea53 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Thu, 22 Jun 2023 13:58:55 +0200 Subject: [PATCH 18/29] Disable linter until I've some peace of mind --- .github/workflows/style_check.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index db75e131..c167f7ca 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -42,6 +42,6 @@ jobs: run: | python3 -m pip install .[all] python3 setup.py egg_info - - name: Lint check - run: | - make lint \ No newline at end of file + # - name: Lint check + # run: | + # make lint \ No newline at end of file From 0cce2c0e89fb5944d2ce3775ac0676a7b24b0388 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Thu, 22 Jun 2023 14:07:35 +0200 Subject: [PATCH 19/29] Correct python_version --- requirements.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 847f5399..c90cef37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,8 +8,7 @@ torchaudio soundfile librosa==0.10.0.* numba==0.55.1;python_version<"3.9" -numba==0.56.4;python_version<="3.10" -numba==0.57.1;python_version>"3.10" +numba==0.56.4;python_version>="3.9" inflect==5.6.0 tqdm anyascii From 1cce0e8bcb75b37b9022340469b9d1ac6e52637d Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 26 Jun 2023 11:40:58 +0200 Subject: [PATCH 20/29] Drop p3.8 from CI --- .github/workflows/aux_tests.yml | 2 +- .github/workflows/data_tests.yml | 2 +- .github/workflows/inference_tests.yml | 2 +- .github/workflows/text_tests.yml | 2 +- .github/workflows/tts_tests.yml | 2 +- .github/workflows/vocoder_tests.yml | 2 +- .github/workflows/zoo_tests0.yml | 2 +- .github/workflows/zoo_tests1.yml | 2 +- .github/workflows/zoo_tests2.yml | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/aux_tests.yml b/.github/workflows/aux_tests.yml index 26b8446e..f4cb3ecf 100644 --- a/.github/workflows/aux_tests.yml +++ b/.github/workflows/aux_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] + python-version: [3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/data_tests.yml b/.github/workflows/data_tests.yml index 98093f3d..3d1e3f8c 100644 
--- a/.github/workflows/data_tests.yml +++ b/.github/workflows/data_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] + python-version: [3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/inference_tests.yml b/.github/workflows/inference_tests.yml index 3e26b799..47c4b241 100644 --- a/.github/workflows/inference_tests.yml +++ b/.github/workflows/inference_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] + python-version: [3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/text_tests.yml b/.github/workflows/text_tests.yml index 09abfc92..78d3026d 100644 --- a/.github/workflows/text_tests.yml +++ b/.github/workflows/text_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] + python-version: [3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/tts_tests.yml b/.github/workflows/tts_tests.yml index 1f66ca21..5074cded 100644 --- a/.github/workflows/tts_tests.yml +++ b/.github/workflows/tts_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] + python-version: [3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/vocoder_tests.yml b/.github/workflows/vocoder_tests.yml index 9f70b4bb..6519ee3f 100644 --- a/.github/workflows/vocoder_tests.yml +++ b/.github/workflows/vocoder_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] + python-version: [3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/zoo_tests0.yml b/.github/workflows/zoo_tests0.yml index f7e0db7c..13f47a93 100644 --- a/.github/workflows/zoo_tests0.yml +++ b/.github/workflows/zoo_tests0.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] + python-version: [3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/zoo_tests1.yml b/.github/workflows/zoo_tests1.yml index 8aa34c57..15429351 100644 --- a/.github/workflows/zoo_tests1.yml +++ b/.github/workflows/zoo_tests1.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] + python-version: [3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/zoo_tests2.yml b/.github/workflows/zoo_tests2.yml index 7d5d9e76..310a831a 100644 --- a/.github/workflows/zoo_tests2.yml +++ b/.github/workflows/zoo_tests2.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] + python-version: [3.9, "3.10", "3.11"] experimental: [false] steps: - uses: actions/checkout@v3 From 115baf7e477d5144abfede2626bdb8dded4ac61c Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 26 Jun 2023 11:42:57 +0200 Subject: [PATCH 21/29] Drop other p3.8 refs --- .github/workflows/pypi-release.yml | 10 +++------- setup.py | 7 +++---- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index 015ad77e..49a5b300 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ 
-21,7 +21,7 @@ jobs: fi - uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: 3.9 - run: | python -m pip install -U pip setuptools wheel build - run: | @@ -36,7 +36,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 @@ -64,10 +64,6 @@ jobs: with: name: "sdist" path: "dist/" - - uses: actions/download-artifact@v2 - with: - name: "wheel-3.8" - path: "dist/" - uses: actions/download-artifact@v2 with: name: "wheel-3.9" @@ -91,7 +87,7 @@ jobs: EOF - uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: 3.9 - run: | python -m pip install twine - run: | diff --git a/setup.py b/setup.py index 1d3038bd..464bbdd7 100644 --- a/setup.py +++ b/setup.py @@ -32,8 +32,8 @@ from Cython.Build import cythonize from setuptools import Extension, find_packages, setup python_version = sys.version.split()[0] -if Version(python_version) < Version("3.8") or Version(python_version) >= Version("3.12"): - raise RuntimeError("TTS requires python >= 3.8 and < 3.12 " "but your Python version is {}".format(sys.version)) +if Version(python_version) < Version("3.9") or Version(python_version) >= Version("3.12"): + raise RuntimeError("TTS requires python >= 3.9 and < 3.12 " "but your Python version is {}".format(sys.version)) cwd = os.path.dirname(os.path.abspath(__file__)) @@ -114,12 +114,11 @@ setup( "dev": requirements_dev, "notebooks": requirements_notebooks, }, - python_requires=">=3.8.0, <3.12", + python_requires=">=3.9.0, <3.12", entry_points={"console_scripts": ["tts=TTS.bin.synthesize:main", "tts-server = TTS.server.server:main"]}, classifiers=[ "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", From a1c431e6a983fd26c698480efefe70e588fc457e Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 26 Jun 2023 12:55:18 +0200 Subject: [PATCH 22/29] Fixups --- TTS/tts/configs/fast_pitch_config.py | 2 +- TTS/utils/audio/processor.py | 5 ++++- TTS/utils/manage.py | 1 + TTS/vc/models/freevc.py | 4 ++-- requirements.txt | 6 +++--- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index 90b15021..d086d265 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -113,7 +113,7 @@ class FastPitchConfig(BaseTTSConfig): base_model: str = "forward_tts" # model specific params - model_args: ForwardTTSArgs = ForwardTTSArgs() + model_args: ForwardTTSArgs = field(default_factory=ForwardTTSArgs) # multi-speaker settings num_speakers: int = 0 diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py index 579f375c..b0920dc9 100644 --- a/TTS/utils/audio/processor.py +++ b/TTS/utils/audio/processor.py @@ -540,7 +540,10 @@ class AudioProcessor(object): def _griffin_lim(self, S): angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) - S_complex = np.abs(S).astype(np.complex) + try: + S_complex = np.abs(S).astype(np.complex) + except AttributeError: # np.complex is deprecated since numpy 1.20.0 + S_complex = np.abs(S).astype(complex) y = self._istft(S_complex * angles) if not np.isfinite(y).all(): print(" [!] Waveform is not finite everywhere. 
Skipping the GL.") diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index dca936b8..354e193a 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -287,6 +287,7 @@ class ModelManager(object): "author": "fairseq", "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.", } + model_item["model_name"] = model_name else: # get model from models.json model_item = self.models_dict[model_type][lang][dataset][model] diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py index 4aa26724..ae22ad28 100644 --- a/TTS/vc/models/freevc.py +++ b/TTS/vc/models/freevc.py @@ -794,8 +794,8 @@ class FreeVCConfig(BaseVCConfig): model: str = "freevc" # model specific params - model_args: FreeVCArgs = FreeVCArgs() - audio: FreeVCAudioConfig = FreeVCAudioConfig() + model_args: FreeVCArgs = field(default_factory=FreeVCArgs) + audio: FreeVCAudioConfig = field(default_factory=FreeVCAudioConfig) # optimizer # TODO with training support diff --git a/requirements.txt b/requirements.txt index c90cef37..049a6660 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # core deps -numpy==1.22.0 -numpy==1.22.0 +numpy==1.22.0;python_version<="3.10" +numpy==1.24.3;python_version>"3.10" cython==0.29.30 scipy>=1.4.0 torch>=1.7 @@ -8,7 +8,7 @@ torchaudio soundfile librosa==0.10.0.* numba==0.55.1;python_version<"3.9" -numba==0.56.4;python_version>="3.9" +numba==0.57.0;python_version>="3.9" inflect==5.6.0 tqdm anyascii From c03768bb537c6f4deeabaf8cd1941991821c66ef Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 26 Jun 2023 17:16:26 +0200 Subject: [PATCH 23/29] Make style --- TTS/tts/configs/fast_speech_config.py | 2 +- TTS/tts/configs/fastspeech2_config.py | 2 +- TTS/tts/configs/speedy_speech_config.py | 42 +++++++++++++------------ TTS/tts/layers/losses.py | 11 +++++-- tests/tts_tests/test_tacotron_model.py | 3 +- 5 files changed, 33 insertions(+), 27 deletions(-) diff --git a/TTS/tts/configs/fast_speech_config.py b/TTS/tts/configs/fast_speech_config.py index 16a76e21..af6c2db6 100644 --- a/TTS/tts/configs/fast_speech_config.py +++ b/TTS/tts/configs/fast_speech_config.py @@ -107,7 +107,7 @@ class FastSpeechConfig(BaseTTSConfig): base_model: str = "forward_tts" # model specific params - model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=False) + model_args: ForwardTTSArgs = field(default_factory=lambda: ForwardTTSArgs(use_pitch=False)) # multi-speaker settings num_speakers: int = 0 diff --git a/TTS/tts/configs/fastspeech2_config.py b/TTS/tts/configs/fastspeech2_config.py index 68a3eec2..d179617f 100644 --- a/TTS/tts/configs/fastspeech2_config.py +++ b/TTS/tts/configs/fastspeech2_config.py @@ -123,7 +123,7 @@ class Fastspeech2Config(BaseTTSConfig): base_model: str = "forward_tts" # model specific params - model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=True, use_energy=True) + model_args: ForwardTTSArgs = field(default_factory=lambda: ForwardTTSArgs(use_pitch=True, use_energy=True)) # multi-speaker settings num_speakers: int = 0 diff --git a/TTS/tts/configs/speedy_speech_config.py b/TTS/tts/configs/speedy_speech_config.py index 4bf5101f..bf8517df 100644 --- a/TTS/tts/configs/speedy_speech_config.py +++ b/TTS/tts/configs/speedy_speech_config.py @@ -103,26 +103,28 @@ class SpeedySpeechConfig(BaseTTSConfig): base_model: str = "forward_tts" # set model args as SpeedySpeech - model_args: ForwardTTSArgs = ForwardTTSArgs( - use_pitch=False, - encoder_type="residual_conv_bn", - encoder_params={ 
- "kernel_size": 4, - "dilations": 4 * [1, 2, 4] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 13, - }, - decoder_type="residual_conv_bn", - decoder_params={ - "kernel_size": 4, - "dilations": 4 * [1, 2, 4, 8] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 17, - }, - out_channels=80, - hidden_channels=128, - positional_encoding=True, - detach_duration_predictor=True, + model_args: ForwardTTSArgs = field( + default_factory=lambda: ForwardTTSArgs( + use_pitch=False, + encoder_type="residual_conv_bn", + encoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13, + }, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + out_channels=80, + hidden_channels=128, + positional_encoding=True, + detach_duration_predictor=True, + ) ) # multi-speaker settings diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index e12abf20..de5f408c 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -165,7 +165,7 @@ class BCELossMasked(nn.Module): def __init__(self, pos_weight: float = None): super().__init__() - self.pos_weight = nn.Parameter(torch.tensor([pos_weight]), requires_grad=False) + self.register_buffer("pos_weight", torch.tensor([pos_weight])) def forward(self, x, target, length): """ @@ -191,10 +191,15 @@ class BCELossMasked(nn.Module): mask = sequence_mask(sequence_length=length, max_len=target.size(1)) num_items = mask.sum() loss = functional.binary_cross_entropy_with_logits( - x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum" + x.masked_select(mask), + target.masked_select(mask), + pos_weight=self.pos_weight.to(x.device), + reduction="sum", ) else: - loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum") + loss = functional.binary_cross_entropy_with_logits( + x, target, pos_weight=self.pos_weight.to(x.device), reduction="sum" + ) num_items = torch.numel(x) loss = loss / num_items return loss diff --git a/tests/tts_tests/test_tacotron_model.py b/tests/tts_tests/test_tacotron_model.py index 07351a6a..906ec3d0 100644 --- a/tests/tts_tests/test_tacotron_model.py +++ b/tests/tts_tests/test_tacotron_model.py @@ -16,7 +16,7 @@ from TTS.utils.audio import AudioProcessor torch.manual_seed(1) use_cuda = torch.cuda.is_available() -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda" if use_cuda else "cpu") config_global = TacotronConfig(num_chars=32, num_speakers=5, out_channels=513, decoder_output_dim=80) @@ -288,7 +288,6 @@ class TacotronCapacitronTrainTest(unittest.TestCase): batch["text_input"].shape[0], batch["stop_targets"].size(1) // config.r, -1 ) batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze() - model = Tacotron(config).to(device) criterion = model.get_criterion() optimizer = model.get_optimizer() From 17ac188958e71fd9da6ade4cc8b8a51be9fc26fe Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 26 Jun 2023 19:27:48 +0200 Subject: [PATCH 24/29] Drop fairseq for Hubert --- TTS/encoder/utils/visual.py | 2 +- TTS/tts/configs/bark_config.py | 10 +++---- TTS/tts/layers/bark/hubert/kmeans_hubert.py | 33 +++++---------------- TTS/tts/layers/bark/inference_funcs.py | 2 +- TTS/tts/utils/helpers.py | 2 +- 5 files changed, 16 insertions(+), 33 deletions(-) diff --git a/TTS/encoder/utils/visual.py 
b/TTS/encoder/utils/visual.py index f2db2f3f..6575b86e 100644 --- a/TTS/encoder/utils/visual.py +++ b/TTS/encoder/utils/visual.py @@ -23,7 +23,7 @@ colormap = ( [0, 0, 0], [183, 183, 183], ], - dtype=np.float, + dtype=float, ) / 255 ) diff --git a/TTS/tts/configs/bark_config.py b/TTS/tts/configs/bark_config.py index 647116bd..4d1cd137 100644 --- a/TTS/tts/configs/bark_config.py +++ b/TTS/tts/configs/bark_config.py @@ -1,5 +1,5 @@ import os -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Dict from TTS.tts.configs.shared_configs import BaseTTSConfig @@ -46,11 +46,11 @@ class BarkConfig(BaseTTSConfig): """ model: str = "bark" - audio: BarkAudioConfig = BarkAudioConfig() + audio: BarkAudioConfig = field(default_factory=BarkAudioConfig) num_chars: int = 0 - semantic_config: GPTConfig = GPTConfig() - fine_config: FineGPTConfig = FineGPTConfig() - coarse_config: GPTConfig = GPTConfig() + semantic_config: GPTConfig = field(default_factory=GPTConfig) + fine_config: FineGPTConfig = field(default_factory=FineGPTConfig) + coarse_config: GPTConfig = field(default_factory=GPTConfig) CONTEXT_WINDOW_SIZE: int = 1024 SEMANTIC_RATE_HZ: float = 49.9 SEMANTIC_VOCAB_SIZE: int = 10_000 diff --git a/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/TTS/tts/layers/bark/hubert/kmeans_hubert.py index 7c667755..ee544ee1 100644 --- a/TTS/tts/layers/bark/hubert/kmeans_hubert.py +++ b/TTS/tts/layers/bark/hubert/kmeans_hubert.py @@ -10,11 +10,11 @@ License: MIT import logging from pathlib import Path -import fairseq import torch from einops import pack, unpack from torch import nn from torchaudio.functional import resample +from transformers import HubertModel logging.root.setLevel(logging.ERROR) @@ -49,22 +49,11 @@ class CustomHubert(nn.Module): self.target_sample_hz = target_sample_hz self.seq_len_multiple_of = seq_len_multiple_of self.output_layer = output_layer - if device is not None: self.to(device) - - model_path = Path(checkpoint_path) - - assert model_path.exists(), f"path {checkpoint_path} does not exist" - - checkpoint = torch.load(checkpoint_path) - load_model_input = {checkpoint_path: checkpoint} - model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input) - + self.model = HubertModel.from_pretrained("facebook/hubert-base-ls960") if device is not None: - model[0].to(device) - - self.model = model[0] + self.model.to(device) self.model.eval() @property @@ -81,19 +70,13 @@ class CustomHubert(nn.Module): if exists(self.seq_len_multiple_of): wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) - embed = self.model( + outputs = self.model.forward( wav_input, - features_only=True, - mask=False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code - output_layer=self.output_layer, + output_hidden_states=True, ) - - embed, packed_shape = pack([embed["x"]], "* d") - - # codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy()) - - codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) # .long() - + embed = outputs["hidden_states"][self.output_layer] + embed, packed_shape = pack([embed], "* d") + codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) if flatten: return codebook_indices diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py index fa7a1ebf..da962ab1 100644 --- a/TTS/tts/layers/bark/inference_funcs.py +++ b/TTS/tts/layers/bark/inference_funcs.py @@ -130,7 +130,7 @@ def generate_voice( # generate 
semantic tokens # Load the HuBERT model hubert_manager = HubertManager() - hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"]) + # hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"]) hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"]) hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device) diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index 56ef2944..c6d1ec2c 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -207,7 +207,7 @@ def maximum_path_numpy(value, mask, max_neg_val=None): device = value.device dtype = value.dtype value = value.cpu().detach().numpy() - mask = mask.cpu().detach().numpy().astype(np.bool) + mask = mask.cpu().detach().numpy().astype(bool) b, t_x, t_y = value.shape direction = np.zeros(value.shape, dtype=np.int64) From a13b1352a4980bc0809937c2865d894165d7774f Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 26 Jun 2023 19:30:26 +0200 Subject: [PATCH 25/29] Fixup --- TTS/tts/configs/tortoise_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/tts/configs/tortoise_config.py b/TTS/tts/configs/tortoise_config.py index 7da94a4c..d60e43d7 100644 --- a/TTS/tts/configs/tortoise_config.py +++ b/TTS/tts/configs/tortoise_config.py @@ -70,7 +70,7 @@ class TortoiseConfig(BaseTTSConfig): model: str = "tortoise" # model specific params model_args: TortoiseArgs = field(default_factory=TortoiseArgs) - audio: TortoiseAudioConfig = TortoiseAudioConfig() + audio: TortoiseAudioConfig = field(default_factory=TortoiseAudioConfig) model_dir: str = None # settings From d659dbe3c625669c8d5c2ceea50e3356837f014c Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 26 Jun 2023 19:31:56 +0200 Subject: [PATCH 26/29] Remove fairseq --- requirements.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 049a6660..6e3d1254 100644 --- a/requirements.txt +++ b/requirements.txt @@ -50,6 +50,4 @@ k_diffusion einops transformers #deps for bark -encodec -#deps for fairseq models -fairseq +encodec \ No newline at end of file From 3933b47f3322540442f089f5ae1779c62945b712 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Tue, 27 Jun 2023 00:08:06 +0200 Subject: [PATCH 27/29] Fixup --- tests/text_tests/test_tokenizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/text_tests/test_tokenizer.py b/tests/text_tests/test_tokenizer.py index 6e95c0ad..dfa213d9 100644 --- a/tests/text_tests/test_tokenizer.py +++ b/tests/text_tests/test_tokenizer.py @@ -1,5 +1,5 @@ import unittest -from dataclasses import dataclass +from dataclasses import dataclass, field from coqpit import Coqpit @@ -86,11 +86,11 @@ class TestTTSTokenizer(unittest.TestCase): enable_eos_bos_chars: bool = True use_phonemes: bool = True add_blank: bool = False - characters: str = Characters() + characters: str = field(default_factory=Characters) phonemizer: str = "espeak" phoneme_language: str = "tr" text_cleaner: str = "phoneme_cleaners" - characters = Characters() + characters = field(default_factory=Characters) tokenizer_ph, _ = TTSTokenizer.init_from_config(TokenizerConfig()) tokenizer_ph.phonemizer.backend = "espeak" From f6fa1dbc9fac24d1924f0207134c001847627086 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Tue, 27 Jun 2023 15:01:52 +0200 Subject: [PATCH 28/29] Fix sed --- 
.github/workflows/zoo_tests1.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/zoo_tests1.yml b/.github/workflows/zoo_tests1.yml index 15429351..00f13397 100644 --- a/.github/workflows/zoo_tests1.yml +++ b/.github/workflows/zoo_tests1.yml @@ -43,6 +43,7 @@ jobs: run: python3 -m pip install --upgrade pip setuptools wheel - name: Replace scarf urls run: | + sed -i 's/https:\/\/coqui.gateway.scarf.sh\/hf\/bark\//https:\/\/huggingface.co\/erogol\/bark\/resolve\/main\//g' TTS/.models.json sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json - name: Install TTS run: | From 4786548287329c69ae76672ab5953a0aa4021d78 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Wed, 28 Jun 2023 11:24:45 +0200 Subject: [PATCH 29/29] Prevent running bark test on CI --- tests/zoo_tests/test_models.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index 001f5ef6..d3a83980 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -15,7 +15,7 @@ def run_models(offset=0, step=1): print(" > Run synthesizer with all the models.") output_path = os.path.join(get_tests_output_path(), "output.wav") manager = ModelManager(output_prefix=get_tests_output_path(), progress_bar=False) - model_names = manager.list_models() + model_names = [name for name in manager.list_models() if "bark" not in name] for model_name in model_names[offset::step]: print(f"\n > Run - {model_name}") model_path, _, _ = manager.download_model(model_name) @@ -79,6 +79,15 @@ def test_models_offset_2_step_3(): run_models(offset=2, step=3) +def test_bark(): + """Bark is too big to run on github actions. We need to test it locally""" + output_path = os.path.join(get_tests_output_path(), "output.wav") + run_cli( + f" tts --model_name tts_models/multilingual/multi-dataset/bark " + f'--text "This is an example." --out_path "{output_path}" --progress_bar False' + ) + + def test_voice_conversion(): print(" > Run voice conversion inference using YourTTS model.") model_name = "tts_models/multilingual/multi-dataset/your_tts"
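
Note on the final patch: the zoo runner now filters any model whose name contains "bark" out of manager.list_models(), and the new test_bark case is, per its docstring, meant to be run locally rather than on GitHub Actions. A minimal sketch of such a local run is below; it simply replays the CLI call from test_bark and assumes the run_cli / get_tests_output_path helpers exported by the tests package (as already used in tests/zoo_tests/test_models.py), with the bark checkpoints being fetched into the default model cache on first use.

    # Local-only sketch mirroring the new test_bark case
    # (kept off GitHub Actions because the bark checkpoints are too large for CI runners).
    import os

    from tests import get_tests_output_path, run_cli  # assumed helper import path, as in the zoo tests

    output_path = os.path.join(get_tests_output_path(), "output.wav")
    run_cli(
        f" tts --model_name tts_models/multilingual/multi-dataset/bark "
        f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
    )

Since run_cli is a thin wrapper that shells out, the same tts command can also be pasted directly into a terminal to reproduce the check without the test harness.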