From 1f77cc6cac22a7dc7b15de0e22b9600fd015fe1a Mon Sep 17 00:00:00 2001
From: manmay-nakhashi
Date: Mon, 24 Apr 2023 01:10:52 +0530
Subject: [PATCH] added tortoise config and updated config and args,
 refactoring the code

---
 TTS/tts/configs/tortoise.py        |   0
 TTS/tts/configs/tortoise_config.py |  46 ++++
 TTS/tts/models/tortoise.py         | 347 +++++++++++++++++------
 3 files changed, 253 insertions(+), 140 deletions(-)
 delete mode 100644 TTS/tts/configs/tortoise.py
 create mode 100644 TTS/tts/configs/tortoise_config.py

diff --git a/TTS/tts/configs/tortoise.py b/TTS/tts/configs/tortoise.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/TTS/tts/configs/tortoise_config.py b/TTS/tts/configs/tortoise_config.py
new file mode 100644
index 00000000..2dd1b8a1
--- /dev/null
+++ b/TTS/tts/configs/tortoise_config.py
@@ -0,0 +1,46 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from TTS.tts.configs.shared_configs import BaseTTSConfig
+from TTS.tts.models.tortoise import TortoiseArgs, TortoiseAudioConfig
+
+
+@dataclass
+class TortoiseConfig(BaseTTSConfig):
+    """Defines parameters for the Tortoise TTS model.
+
+    Args:
+        model (str):
+            Model name. Do not change unless you know what you are doing.
+
+        model_args (TortoiseArgs):
+            Model architecture arguments. Defaults to `TortoiseArgs()`.
+
+        audio (TortoiseAudioConfig):
+            Audio processing configuration. Defaults to `TortoiseAudioConfig()`.
+    Note:
+        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
+
+    Example:
+
+        >>> from TTS.tts.configs.tortoise_config import TortoiseConfig
+        >>> config = TortoiseConfig()
+    """
+
+    model: str = "tortoise"
+    # model specific params
+    model_args: TortoiseArgs = field(default_factory=TortoiseArgs)
+    audio: TortoiseAudioConfig = field(default_factory=TortoiseAudioConfig)
+
+    # settings
+    temperature: float = 0.2
+    length_penalty: float = 1.0
+    repetition_penalty: float = 2.0
+    top_p: float = 0.8
+    cond_free_k: float = 2.0
+    diffusion_temperature: float = 1.0
+
+    # inference params
+    num_autoregressive_samples: int = 16
+    diffusion_iterations: int = 30
+    sampler: str = "ddim"
diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py
index aae02764..5fb07ffb 100644
--- a/TTS/tts/models/tortoise.py
+++ b/TTS/tts/models/tortoise.py
@@ -1,17 +1,23 @@
-# ## AGPL: a notification must be added stating that changes have been made to that file.
-
 import os
 import random
 from contextlib import contextmanager
+from dataclasses import dataclass, field
 from time import time
 
 import torch
 import torch.nn.functional as F
 import torchaudio
+from coqpit import Coqpit
 from tqdm import tqdm
 
 from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
-from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
+from TTS.tts.layers.tortoise.audio_utils import (
+    denormalize_tacotron_mel,
+    load_audio,
+    load_voice,
+    load_voices,
+    wav_to_univnet_mel,
+)
 from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
 from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
 from TTS.tts.layers.tortoise.clvp import CLVP
@@ -22,6 +28,7 @@ from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
 from TTS.tts.layers.tortoise.vocoder import VocConf
 from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
+from TTS.tts.models.base_tts import BaseTTS
 
 
 def pad_or_truncate(t, length):
@@ -191,21 +198,74 @@ def pick_best_batch_size_for_gpu():
     return batch_size
 
 
-class TextToSpeech:
+@dataclass
+class TortoiseAudioConfig(Coqpit):
+    sample_rate: int = 22050
+    diffusion_sample_rate: int = 24000
+
+
+@dataclass
+class TortoiseArgs(Coqpit):
+    autoregressive_batch_size: int = 1
+    enable_redaction: bool = True
+    high_vram: bool = False
+    kv_cache: bool = True
+    ar_checkpoint: str = None
+    clvp_checkpoint: str = None
+    diff_checkpoint: str = None
+    num_chars: int = 255
+
+    # UnifiedVoice params
+    ar_max_mel_tokens: int = 604
+    ar_max_text_tokens: int = 402
+    ar_max_conditioning_inputs: int = 2
+    ar_layers: int = 30
+    ar_model_dim: int = 1024
+    ar_heads: int = 16
+    ar_number_text_tokens: int = 255
+    ar_start_text_token: int = 255
+    ar_checkpointing: bool = False
+    ar_train_solo_embeddings: bool = False
+
+    # DiffTTS params
+    diff_model_channels: int = 1024
+    diff_num_layers: int = 10
+    diff_in_channels: int = 100
+    diff_out_channels: int = 200
+    diff_in_latent_channels: int = 1024
+    diff_in_tokens: int = 8193
+    diff_dropout: int = 0
+    diff_use_fp16: bool = False
+    diff_num_heads: int = 16
+    diff_layer_drop: int = 0
+    diff_unconditioned_percentage: int = 0
+
+    # clvp params
+    clvp_dim_text: int = 768
+    clvp_dim_speech: int = 768
+    clvp_dim_latent: int = 768
+    clvp_num_text_tokens: int = 256
+    clvp_text_enc_depth: int = 20
+    clvp_text_seq_len: int = 350
+    clvp_text_heads: int = 12
+    clvp_num_speech_tokens: int = 8192
+    clvp_speech_enc_depth: int = 20
+    clvp_speech_heads: int = 12
+    clvp_speech_seq_len: int = 430
+    clvp_use_xformers: bool = True
+    # constants
+    duration_const: int = 102400
+
+
+class Tortoise(BaseTTS):
     """
     Main entry point into Tortoise.
     """
 
     def __init__(
         self,
-        autoregressive_batch_size=None,
+        config: Coqpit,
         models_dir=MODELS_DIR,
-        enable_redaction=True,
-        high_vram=False,
-        kv_cache=True,
-        ar_checkpoint=None,
-        clvp_checkpoint=None,
-        diff_checkpoint=None,
         vocoder=VocConf.Univnet,
     ):
         """
@@ -224,82 +284,82 @@ class TextToSpeech:
         :param clvp_checkpoint: Path to a checkpoint file for the CLVP model. If omitted, uses default
         :param diff_checkpoint: Path to a checkpoint file for the diffusion model. If omitted, uses default
         """
-        self.ar_checkpoint = ar_checkpoint
-        self.diff_checkpoint = diff_checkpoint  # TODO: check if this is even needed
+        super().__init__(config, ap=None, tokenizer=None)
+        self.config = config
+        self.ar_checkpoint = self.args.ar_checkpoint
+        self.diff_checkpoint = self.args.diff_checkpoint  # TODO: check if this is even needed
         self.models_dir = models_dir
         self.autoregressive_batch_size = (
-            pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
+            pick_best_batch_size_for_gpu()
+            if self.args.autoregressive_batch_size is None
+            else self.args.autoregressive_batch_size
         )
-        self.enable_redaction = enable_redaction
+        self.enable_redaction = self.args.enable_redaction
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         if self.enable_redaction:
             self.aligner = Wav2VecAlignment()
 
         self.tokenizer = VoiceBpeTokenizer()
 
-        if os.path.exists(f"{models_dir}/autoregressive.ptt"):
-            # Assume this is a traced directory.
-            self.autoregressive = torch.jit.load(f"{models_dir}/autoregressive.ptt")
-            self.diffusion = torch.jit.load(f"{models_dir}/diffusion_decoder.ptt")
-        else:
-            self.autoregressive = (
-                UnifiedVoice(
-                    max_mel_tokens=604,
-                    max_text_tokens=402,
-                    max_conditioning_inputs=2,
-                    layers=30,
-                    model_dim=1024,
-                    heads=16,
-                    number_text_tokens=255,
-                    start_text_token=255,
-                    checkpointing=False,
-                    train_solo_embeddings=False,
-                )
-                .cpu()
-                .eval()
-            )
-            ar_path = ar_checkpoint or get_model_path("autoregressive.pth", models_dir)
-            self.autoregressive.load_state_dict(torch.load(ar_path))
-            self.autoregressive.post_init_gpt2_config(kv_cache)
-
-            diff_path = diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir)
-            self.diffusion = (
-                DiffusionTts(
-                    model_channels=1024,
-                    num_layers=10,
-                    in_channels=100,
-                    out_channels=200,
-                    in_latent_channels=1024,
-                    in_tokens=8193,
-                    dropout=0,
-                    use_fp16=False,
-                    num_heads=16,
-                    layer_drop=0,
-                    unconditioned_percentage=0,
-                )
-                .cpu()
-                .eval()
-            )
-            self.diffusion.load_state_dict(torch.load(diff_path))
-        self.clvp = (
-            CLVP(
-                dim_text=768,
-                dim_speech=768,
-                dim_latent=768,
-                num_text_tokens=256,
-                text_enc_depth=20,
-                text_seq_len=350,
-                text_heads=12,
-                num_speech_tokens=8192,
-                speech_enc_depth=20,
-                speech_heads=12,
-                speech_seq_len=430,
-                use_xformers=True,
+        self.autoregressive = (
+            UnifiedVoice(
+                max_mel_tokens=self.args.ar_max_mel_tokens,
+                max_text_tokens=self.args.ar_max_text_tokens,
+                max_conditioning_inputs=self.args.ar_max_conditioning_inputs,
+                layers=self.args.ar_layers,
+                model_dim=self.args.ar_model_dim,
+                heads=self.args.ar_heads,
+                number_text_tokens=self.args.ar_number_text_tokens,
+                start_text_token=self.args.ar_start_text_token,
+                checkpointing=self.args.ar_checkpointing,
+                train_solo_embeddings=self.args.ar_train_solo_embeddings,
             )
             .cpu()
             .eval()
         )
-        clvp_path = clvp_checkpoint or get_model_path("clvp2.pth", models_dir)
+        ar_path = self.args.ar_checkpoint or get_model_path("autoregressive.pth", models_dir)
+        self.autoregressive.load_state_dict(torch.load(ar_path))
+        self.autoregressive.post_init_gpt2_config(self.args.kv_cache)
+
+        diff_path = self.args.diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir)
+        self.diffusion = (
+            DiffusionTts(
+                model_channels=self.args.diff_model_channels,
+                num_layers=self.args.diff_num_layers,
+                in_channels=self.args.diff_in_channels,
+                out_channels=self.args.diff_out_channels,
+                in_latent_channels=self.args.diff_in_latent_channels,
+                in_tokens=self.args.diff_in_tokens,
+                dropout=self.args.diff_dropout,
+                use_fp16=self.args.diff_use_fp16,
+                num_heads=self.args.diff_num_heads,
+                layer_drop=self.args.diff_layer_drop,
+                unconditioned_percentage=self.args.diff_unconditioned_percentage,
+            )
+            .cpu()
+            .eval()
+        )
+        self.diffusion.load_state_dict(torch.load(diff_path))
+
+        self.clvp = (
+            CLVP(
+                dim_text=self.args.clvp_dim_text,
+                dim_speech=self.args.clvp_dim_speech,
+                dim_latent=self.args.clvp_dim_latent,
+                num_text_tokens=self.args.clvp_num_text_tokens,
+                text_enc_depth=self.args.clvp_text_enc_depth,
+                text_seq_len=self.args.clvp_text_seq_len,
+                text_heads=self.args.clvp_text_heads,
+                num_speech_tokens=self.args.clvp_num_speech_tokens,
+                speech_enc_depth=self.args.clvp_speech_enc_depth,
+                speech_heads=self.args.clvp_speech_heads,
+                speech_seq_len=self.args.clvp_speech_seq_len,
+                use_xformers=self.args.clvp_use_xformers,
+            )
+            .cpu()
+            .eval()
+        )
+        clvp_path = self.args.clvp_checkpoint or get_model_path("clvp2.pth", models_dir)
         self.clvp.load_state_dict(torch.load(clvp_path))
 
         self.vocoder = vocoder.value.constructor().cpu()
@@ -317,12 +377,12 @@ class TextToSpeech:
         self.rlg_auto = None
         self.rlg_diffusion = None
 
-        if high_vram:
+        if self.args.high_vram:
             self.autoregressive = self.autoregressive.to(self.device)
             self.diffusion = self.diffusion.to(self.device)
             self.clvp = self.clvp.to(self.device)
             self.vocoder = self.vocoder.to(self.device)
-        self.high_vram = high_vram
+        self.high_vram = self.args.high_vram
 
     @contextmanager
     def temporary_cuda(self, model):
@@ -355,7 +415,7 @@ class TextToSpeech:
             1,
             2,
         ], "latent_averaging mode has to be one of (0, 1, 2)"
-        print("mode", latent_averaging_mode)
+
         with torch.no_grad():
             voice_samples = [[v.to(self.device) for v in ls] for ls in voice_samples]
 
@@ -368,7 +428,7 @@ class TextToSpeech:
 
             diffusion_conds = []
 
-            DURS_CONST = 102400
+            DURS_CONST = self.args.duration_const
             for ls in voice_samples:
                 # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
                 sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1]
@@ -428,67 +488,51 @@ class TextToSpeech:
         with torch.no_grad():
             return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
 
-    def tts_with_preset(self, text, preset="fast", **kwargs):
+    def synthesis(
+        self,
+        model,
+        text,
+        config,
+        preset,  # NOTE: currently unused
+        speaker_id="lj",
+    ):
+        voice_samples, conditioning_latents = load_voice(speaker_id)  # reference clips and/or cached latents
+
+        outputs = model.inference_with_config(
+            text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents
+        )
+
+        return_dict = {
+            "wav": outputs["wav"],
+            "deterministic_seed": outputs["deterministic_seed"],
+            "text_inputs": outputs["text"],
+            "voice_samples": outputs["voice_samples"],
+            "conditioning_latents": outputs["conditioning_latents"],
+        }
+
+        return return_dict
+
+    def inference_with_config(self, text, config, **kwargs):
         """
-        Calls TTS with one of a set of preset generation parameters. Options:
-            'single_sample': Produces speech even faster, but only produces 1 sample.
-            'ultra_fast': Produces speech much faster than the original tortoise repo.
-            'ultra_fast_old': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest).
-            'fast': Decent quality speech at a decent inference rate. A good choice for mass inference.
-            'standard': Very good quality. This is generally about as good as you are going to get.
-            'high_quality': Use if you want the absolute best. This is not really worth the compute, though.
+        Run inference with the generation parameters taken from the config;
+        any of them can be overridden per call via kwargs.
         """
         # Use generally found best tuning knobs for generation.
         settings = {
-            "temperature": 0.2,
-            "length_penalty": 1.0,
-            "repetition_penalty": 2.0,
-            "top_p": 0.8,
-            "cond_free_k": 2.0,
-            "diffusion_temperature": 1.0,
+            "temperature": config.temperature,
+            "length_penalty": config.length_penalty,
+            "repetition_penalty": config.repetition_penalty,
+            "top_p": config.top_p,
+            "cond_free_k": config.cond_free_k,
+            "diffusion_temperature": config.diffusion_temperature,
+            "num_autoregressive_samples": config.num_autoregressive_samples,
+            "diffusion_iterations": config.diffusion_iterations,
+            "sampler": config.sampler,
         }
-        # Presets are defined here.
-        presets = {
-            "single_sample": {
-                "num_autoregressive_samples": 8,
-                "diffusion_iterations": 10,
-                "sampler": "ddim",
-            },
-            "ultra_fast": {
-                "num_autoregressive_samples": 16,
-                "diffusion_iterations": 10,
-                "sampler": "ddim",
-            },
-            "ultra_fast_old": {
-                "num_autoregressive_samples": 16,
-                "diffusion_iterations": 30,
-                "cond_free": False,
-            },
-            "very_fast": {
-                "num_autoregressive_samples": 32,
-                "diffusion_iterations": 30,
-                "sampler": "dpm++2m",
-            },
-            "fast": {
-                "num_autoregressive_samples": 16,
-                "diffusion_iterations": 50,
-                "sampler": "ddim",
-            },
-            "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
-            "standard": {
-                "num_autoregressive_samples": 256,
-                "diffusion_iterations": 200,
-            },
-            "high_quality": {
-                "num_autoregressive_samples": 256,
-                "diffusion_iterations": 400,
-            },
-        }
-        settings.update(presets[preset])
-        settings.update(kwargs)  # allow overriding of preset settings with kwargs
-        return self.tts(text, **settings)
+        settings.update(kwargs)  # allow overriding of config settings with kwargs
+        return self.inference(text, **settings)
 
-    def tts(
+    def inference(
         self,
         text,
         voice_samples=None,
@@ -709,11 +753,34 @@ class TextToSpeech:
         else:
             res = wav_candidates[0]
 
-        if return_deterministic_state:
-            return res, (
-                deterministic_seed,
-                text,
-                voice_samples,
-                conditioning_latents,
-            )
-        return res
+        return_dict = {
+            "wav": res,
+            "deterministic_seed": None,
+            "text": None,
+            "voice_samples": None,
+            "conditioning_latents": None,
+        }
+        if return_deterministic_state:
+            return_dict = {
+                "wav": res,
+                "deterministic_seed": deterministic_seed,
+                "text": text,
+                "voice_samples": voice_samples,
+                "conditioning_latents": conditioning_latents,
+            }
+        return return_dict
+
+    def forward(self):
+        raise NotImplementedError("Tortoise training is not implemented")
+
+    def eval_step(self):
+        raise NotImplementedError("Tortoise training is not implemented")
+
+    def init_from_config(self):
+        raise NotImplementedError("init_from_config is not implemented for Tortoise")
+
+    def load_checkpoint(self):
+        raise NotImplementedError("load_checkpoint is not implemented for Tortoise")
+
+    def train_step(self):
+        raise NotImplementedError("Tortoise training is not implemented")
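
Reviewer note (not part of the patch): a minimal usage sketch of the refactored API introduced above. It only uses names from this diff (TortoiseConfig, Tortoise, synthesis, inference_with_config, load_voice); the assumptions are that the pretrained checkpoints are resolvable through get_model_path() under MODELS_DIR and that "lj" is one of the bundled example voices.

# Usage sketch, assumptions flagged above: generation knobs now live on the
# config object instead of the removed string presets.
from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.layers.tortoise.audio_utils import load_voice
from TTS.tts.models.tortoise import Tortoise

config = TortoiseConfig(num_autoregressive_samples=16, diffusion_iterations=30, sampler="ddim")
model = Tortoise(config)  # loads the AR, diffusion, CLVP, and vocoder weights

# High-level path: synthesis() loads the voice and forwards the config settings.
outputs = model.synthesis(model, "Hello world.", config, preset=None, speaker_id="lj")
wav = outputs["wav"]  # generated waveform

# Lower-level path: load a voice manually and override individual settings per call.
voice_samples, conditioning_latents = load_voice("lj")
outputs = model.inference_with_config(
    "Hello again.",
    config,
    temperature=0.5,  # kwargs take precedence over the config values
    voice_samples=voice_samples,
    conditioning_latents=conditioning_latents,
)

The removed presets map directly onto config values; e.g. the old "standard" preset corresponds to num_autoregressive_samples=256 and diffusion_iterations=200.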