# TTS/tts/models/bark.py
"""Bark model wrapper: text -> semantic tokens -> coarse/fine codes -> waveform."""
import os
from dataclasses import dataclass
from typing import Optional

import numpy as np
from coqpit import Coqpit
from encodec import EncodecModel
from transformers import BertTokenizer

from TTS.tts.layers.bark.inference_funcs import (
    codec_decode,
    generate_coarse,
    generate_fine,
    generate_text_semantic,
    generate_voice,
    load_voice,
)
from TTS.tts.layers.bark.load_model import load_model
from TTS.tts.layers.bark.model import GPT
from TTS.tts.layers.bark.model_fine import FineGPT
from TTS.tts.models.base_tts import BaseTTS


@dataclass
class BarkAudioConfig(Coqpit):
    """Audio settings for Bark; both rates default to 24000 Hz."""

    sample_rate: int = 24000
    output_sample_rate: int = 24000


class Bark(BaseTTS):
    """Bark TTS model built from a semantic, a coarse and a fine sub-model plus an EnCodec codec."""

    def __init__(
        self,
        config: Coqpit,
        tokenizer: Optional[BertTokenizer] = None,
    ) -> None:
        """Build the (untrained) sub-models described by `config`.

        Args:
            config: Model config; must provide `semantic_config`, `coarse_config` and `fine_config`.
            tokenizer: Text tokenizer. When omitted, the pretrained
                "bert-base-multilingual-cased" tokenizer is loaded.
        """
        super().__init__(config=config, ap=None, tokenizer=None, speaker_manager=None, language_manager=None)
        # Load the default tokenizer here rather than in the signature: a default
        # argument is evaluated once at import time, which would fetch the
        # tokenizer files even when `Bark` is never instantiated.
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
        self.config.num_chars = len(tokenizer)
        self.tokenizer = tokenizer
        self.semantic_model = GPT(config.semantic_config)
        self.coarse_model = GPT(config.coarse_config)
        self.fine_model = FineGPT(config.fine_config)
        self.encodec = EncodecModel.encodec_model_24khz()
        self.encodec.set_target_bandwidth(6.0)

    @property
    def device(self):
        """Device of the first model parameter."""
        return next(self.parameters()).device

    def load_bark_models(self):
        """Load the text, coarse and fine checkpoints listed in `config.LOCAL_MODEL_PATHS`.

        Each `load_model` call also returns a config, which replaces `self.config`
        before the next checkpoint is loaded.
        """
        self.semantic_model, self.config = load_model(
            ckpt_path=self.config.LOCAL_MODEL_PATHS["text"], device=self.device, config=self.config, model_type="text"
        )
        self.coarse_model, self.config = load_model(
            ckpt_path=self.config.LOCAL_MODEL_PATHS["coarse"],
            device=self.device,
            config=self.config,
            model_type="coarse",
        )
        self.fine_model, self.config = load_model(
            ckpt_path=self.config.LOCAL_MODEL_PATHS["fine"], device=self.device, config=self.config, model_type="fine"
        )

    def train_step(
        self,
    ):
        """Not implemented; this wrapper only defines inference."""

    def text_to_semantic(
        self,
        text: str,
        history_prompt: Optional[str] = None,
        temp: float = 0.7,
        base=None,
        allow_early_stop=True,
        **kwargs,
    ):
        """Generate semantic array from text.

        Args:
            text: text to be turned into audio
            history_prompt: history choice for audio cloning
            temp: generation temperature (1.0 more diverse, 0.0 more conservative)
            base: forwarded unchanged to `generate_text_semantic`
            allow_early_stop: forwarded unchanged to `generate_text_semantic`
            **kwargs: extra keyword arguments forwarded to `generate_text_semantic`

        Returns:
            semantic array to be fed into `semantic_to_waveform`
        """
        return generate_text_semantic(
            text,
            self,
            history_prompt=history_prompt,
            temp=temp,
            base=base,
            allow_early_stop=allow_early_stop,
            **kwargs,
        )

    def semantic_to_waveform(
        self,
        semantic_tokens: np.ndarray,
        history_prompt: Optional[str] = None,
        temp: float = 0.7,
        base=None,
        fine_temp: float = 0.5,
    ):
        """Generate audio array from semantic input.

        Args:
            semantic_tokens: semantic token output from `text_to_semantic`
            history_prompt: history choice for audio cloning
            temp: generation temperature of the coarse stage (1.0 more diverse, 0.0 more conservative)
            base: forwarded unchanged to `generate_coarse` and `generate_fine`
            fine_temp: generation temperature of the fine stage; it is independent of `temp`
                and defaults to 0.5

        Returns:
            Tuple `(audio_arr, x_coarse_gen, x_fine_gen)`: the decoded audio followed by
            the coarse and fine outputs it was decoded from.
        """
        x_coarse_gen = generate_coarse(
            semantic_tokens,
            self,
            history_prompt=history_prompt,
            temp=temp,
            base=base,
        )
        x_fine_gen = generate_fine(
            x_coarse_gen,
            self,
            history_prompt=history_prompt,
            temp=fine_temp,
            base=base,
        )
        audio_arr = codec_decode(x_fine_gen, self)
        return audio_arr, x_coarse_gen, x_fine_gen

    def generate_audio(
        self,
        text: str,
        history_prompt: Optional[str] = None,
        text_temp: float = 0.7,
        waveform_temp: float = 0.7,
        base=None,
        allow_early_stop=True,
        **kwargs,
    ):
        """Generate audio array from input text.

        Args:
            text: text to be turned into audio
            history_prompt: history choice for audio cloning
            text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
            waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
            base: forwarded unchanged to both generation stages
            allow_early_stop: forwarded unchanged to `text_to_semantic`
            **kwargs: extra keyword arguments forwarded to `text_to_semantic`

        Returns:
            Tuple `(audio_arr, [x_semantic, x_coarse, x_fine])`: the decoded audio and the
            intermediate outputs of the three generation stages.
        """
        x_semantic = self.text_to_semantic(
            text,
            history_prompt=history_prompt,
            temp=text_temp,
            base=base,
            allow_early_stop=allow_early_stop,
            **kwargs,
        )
        audio_arr, x_coarse, x_fine = self.semantic_to_waveform(
            x_semantic, history_prompt=history_prompt, temp=waveform_temp, base=base
        )
        return audio_arr, [x_semantic, x_coarse, x_fine]

    def generate_voice(self, audio, speaker_id, voice_dir):
        """Create and store a voice for `speaker_id` unless one can already be loaded.

        Does nothing when `voice_dir` is None.

        Args:
            audio (str): Path to the audio file.
            speaker_id (str): Speaker name.
            voice_dir (str): Path to the directory to save the generate voice.
        """
        if voice_dir is None:
            return
        try:
            # Pass the model first, matching the `load_voice` call in `synthesize`.
            _ = load_voice(self, speaker_id, [voice_dir])
        except (KeyError, FileNotFoundError):
            # The voice is not available yet: generate it as `<voice_dir>/<speaker_id>.npz`.
            output_path = os.path.join(voice_dir, speaker_id + ".npz")
            os.makedirs(voice_dir, exist_ok=True)
            generate_voice(audio, self, output_path)

    def _set_voice_dirs(self, voice_dirs):
        """Return the voice directories to search: `voice_dirs` first, then the default one.

        The default directory `config.DEF_SPEAKER_DIR` is created if needed and is
        only included when it is a string naming a directory.

        Args:
            voice_dirs: None, a single path, or an iterable of paths.
        """
        default_dirs = []
        def_speaker_dir = self.config.DEF_SPEAKER_DIR
        if isinstance(def_speaker_dir, str):
            os.makedirs(def_speaker_dir, exist_ok=True)
            if os.path.isdir(def_speaker_dir):
                default_dirs.append(def_speaker_dir)
        if voice_dirs is None:
            return default_dirs
        if isinstance(voice_dirs, str):
            voice_dirs = [voice_dirs]
        # `list(...)` accepts tuples and other iterables, not only lists.
        return list(voice_dirs) + default_dirs

    # TODO: remove config from synthesize
    def synthesize(
        self, text, config, speaker_id="random", voice_dirs=None, **kwargs
    ):  # pylint: disable=unused-argument
        """Synthesize speech with the given input text.

        Args:
            text (str): Input text.
            config (BarkConfig): Unused here; kept in the signature (see TODO above).
            speaker_id (str): Speaker name handed to `load_voice`. Defaults to "random".
            voice_dirs (List[str]): Extra directories to search for voices, in addition to
                `config.DEF_SPEAKER_DIR`. Defaults to None.
            **kwargs: Inference settings forwarded to `generate_audio`.

        Returns:
            A dictionary with `wav` as the output waveform and `text_inputs` as the input text.
        """
        voice_dirs = self._set_voice_dirs(voice_dirs)
        history_prompt = load_voice(self, speaker_id, voice_dirs)
        outputs = self.generate_audio(text, history_prompt=history_prompt, **kwargs)
        return_dict = {
            "wav": outputs[0],
            "text_inputs": text,
        }

        return return_dict

    def eval_step(self):
        ...

    def forward(self):
        ...

    def inference(self):
        ...

    @staticmethod
    def init_from_config(config: "BarkConfig", **kwargs):  # pylint: disable=unused-argument
        """Build a `Bark` model from `config`; extra keyword arguments are ignored."""
        return Bark(config)

    # pylint: disable=unused-argument, redefined-builtin
    def load_checkpoint(
        self,
        config,
        checkpoint_dir,
        text_model_path=None,
        coarse_model_path=None,
        fine_model_path=None,
        eval=False,
        strict=True,
        **kwargs,
    ):
        """Load the model checkpoints from a directory. This model has multiple checkpoint files;
        any path that is not given explicitly is looked up under `checkpoint_dir` with its default name.
        If eval is True, set the model to eval mode.

        Args:
            config (BarkConfig): Unused here; the paths are written to `self.config`.
            checkpoint_dir (str): The directory where the checkpoints are stored.
            text_model_path (str, optional): The path to the text (semantic) checkpoint.
                Defaults to `<checkpoint_dir>/text_2.pt`.
            coarse_model_path (str, optional): The path to the coarse checkpoint.
                Defaults to `<checkpoint_dir>/coarse_2.pt`.
            fine_model_path (str, optional): The path to the fine checkpoint.
                Defaults to `<checkpoint_dir>/fine_2.pt`.
            eval (bool, optional): Whether to set the model to eval mode. Defaults to False.
            strict (bool, optional): Accepted but not used by this method. Defaults to True.
        """
        text_model_path = text_model_path or os.path.join(checkpoint_dir, "text_2.pt")
        coarse_model_path = coarse_model_path or os.path.join(checkpoint_dir, "coarse_2.pt")
        fine_model_path = fine_model_path or os.path.join(checkpoint_dir, "fine_2.pt")

        self.config.LOCAL_MODEL_PATHS["text"] = text_model_path
        self.config.LOCAL_MODEL_PATHS["coarse"] = coarse_model_path
        self.config.LOCAL_MODEL_PATHS["fine"] = fine_model_path

        self.load_bark_models()

        if eval:
            self.eval()