From fede89ac0d1b350c0ebdedcea052bcc9399258d5 Mon Sep 17 00:00:00 2001 From: manmay-nakhashi Date: Tue, 2 May 2023 00:24:49 +0530 Subject: [PATCH] add ability to pass tortoise presets through coqui api --- TTS/api.py | 5 ++- TTS/tts/layers/tortoise/arch_utils.py | 2 +- TTS/tts/layers/tortoise/tokenizer.py | 2 +- TTS/tts/layers/tortoise/utils.py | 5 +-- TTS/tts/models/tortoise.py | 62 +++++++++++++++++++-------- TTS/utils/synthesizer.py | 3 +- 6 files changed, 53 insertions(+), 26 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 34b29905..554fec80 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -504,6 +504,7 @@ class TTS: speaker_wav: str = None, emotion: str = None, speed: float = None, + **kwargs, ): """Convert text to speech. @@ -540,6 +541,7 @@ class TTS: style_wav=None, style_text=None, reference_speaker_name=None, + **kwargs, ) return wav @@ -552,6 +554,7 @@ class TTS: emotion: str = "Neutral", speed: float = 1.0, file_path: str = "output.wav", + **kwargs, ): """Convert text to speech. @@ -580,7 +583,7 @@ class TTS: return self.tts_coqui_studio( text=text, speaker_name=speaker, language=language, emotion=emotion, speed=speed, file_path=file_path ) - wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav) + wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs) self.synthesizer.save_wav(wav=wav, path=file_path) return file_path diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py index 399d67fe..e5cae3be 100644 --- a/TTS/tts/layers/tortoise/arch_utils.py +++ b/TTS/tts/layers/tortoise/arch_utils.py @@ -1,8 +1,8 @@ import functools import math import os -import fsspec +import fsspec import torch import torch.nn as nn import torch.nn.functional as F diff --git a/TTS/tts/layers/tortoise/tokenizer.py b/TTS/tts/layers/tortoise/tokenizer.py index 99c8c67b..3e544ee7 100644 --- a/TTS/tts/layers/tortoise/tokenizer.py +++ b/TTS/tts/layers/tortoise/tokenizer.py @@ -2,8 +2,8 @@ import os import torch from tokenizers import Tokenizer -from TTS.tts.utils.text.cleaners import english_cleaners +from TTS.tts.utils.text.cleaners import english_cleaners DEFAULT_VOCAB_FILE = os.path.join( os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json" diff --git a/TTS/tts/layers/tortoise/utils.py b/TTS/tts/layers/tortoise/utils.py index c65ae8a6..810a9e7f 100644 --- a/TTS/tts/layers/tortoise/utils.py +++ b/TTS/tts/layers/tortoise/utils.py @@ -5,20 +5,17 @@ from tqdm import tqdm DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models") MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR) - +MODELS_DIR = "/data/speech_synth/models/" MODELS = { "autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth", "classifier.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth", "clvp2.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth", - "cvvp.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth", "diffusion_decoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth", "vocoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth", "rlg_auto.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth", "rlg_diffuser.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth", } -pbar = None - def download_models(specific_models=None): """ diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py index e86ce268..841013dc 100644 --- a/TTS/tts/models/tortoise.py +++ b/TTS/tts/models/tortoise.py @@ -11,13 +11,7 @@ from coqpit import Coqpit from tqdm import tqdm from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram -from TTS.tts.layers.tortoise.audio_utils import ( - denormalize_tacotron_mel, - load_audio, - load_voice, - load_voices, - wav_to_univnet_mel, -) +from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, load_voice, wav_to_univnet_mel from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead from TTS.tts.layers.tortoise.clvp import CLVP @@ -25,7 +19,7 @@ from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_sc from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer -from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path +from TTS.tts.layers.tortoise.utils import get_model_path from TTS.tts.layers.tortoise.vocoder import VocConf from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment from TTS.tts.models.base_tts import BaseTTS @@ -498,16 +492,11 @@ class Tortoise(BaseTTS): with torch.no_grad(): return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0])) - def synthesis( - self, - text, - config, - speaker_id="lj", - ): + def synthesis(self, text, config, speaker_id="lj", **kwargs): voice_samples, conditioning_latents = load_voice(speaker_id) outputs = self.inference_with_config( - text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents + text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents, **kwargs ) return_dict = { @@ -533,10 +522,47 @@ class Tortoise(BaseTTS): "top_p": config.top_p, "cond_free_k": config.cond_free_k, "diffusion_temperature": config.diffusion_temperature, - "num_autoregressive_samples": config.num_autoregressive_samples, - "diffusion_iterations": config.diffusion_iterations, "sampler": config.sampler, } + # Presets are defined here. + presets = { + "single_sample": { + "num_autoregressive_samples": 8, + "diffusion_iterations": 10, + "sampler": "ddim", + }, + "ultra_fast": { + "num_autoregressive_samples": 16, + "diffusion_iterations": 10, + "sampler": "ddim", + }, + "ultra_fast_old": { + "num_autoregressive_samples": 16, + "diffusion_iterations": 30, + "cond_free": False, + }, + "very_fast": { + "num_autoregressive_samples": 32, + "diffusion_iterations": 30, + "sampler": "dpm++2m", + }, + "fast": { + "num_autoregressive_samples": 5, + "diffusion_iterations": 50, + "sampler": "ddim", + }, + "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80}, + "standard": { + "num_autoregressive_samples": 5, + "diffusion_iterations": 200, + }, + "high_quality": { + "num_autoregressive_samples": 256, + "diffusion_iterations": 400, + }, + } + settings.update(presets[kwargs["preset"]]) + kwargs.pop("preset") settings.update(kwargs) # allow overriding of preset settings with kwargs return self.inference(text, **settings) @@ -551,7 +577,7 @@ class Tortoise(BaseTTS): return_deterministic_state=False, latent_averaging_mode=0, # autoregressive generation parameters follow - num_autoregressive_samples=512, + num_autoregressive_samples=16, temperature=0.8, length_penalty=1, repetition_penalty=2.0, diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 97ce4cdf..8d143180 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -228,6 +228,7 @@ class Synthesizer(object): style_text=None, reference_wav=None, reference_speaker_name=None, + **kwargs, ) -> List[int]: """🐸 TTS magic. Run all the models and generate speech. @@ -328,7 +329,7 @@ class Synthesizer(object): if not reference_wav: for sen in sens: if self.tts_config.model == "tortoise": - outputs = self.tts_model.synthesis(text=sen, config=self.tts_config) + outputs = self.tts_model.synthesis(text=sen, config=self.tts_config, **kwargs) else: # synthesize voice outputs = synthesis(