Add ability to pass Tortoise presets through the Coqui API

This commit is contained in:
manmay-nakhashi 2023-05-02 00:24:49 +05:30
parent 8c739fd5f2
commit fede89ac0d
6 changed files with 53 additions and 26 deletions

View File

@ -504,6 +504,7 @@ class TTS:
speaker_wav: str = None, speaker_wav: str = None,
emotion: str = None, emotion: str = None,
speed: float = None, speed: float = None,
**kwargs,
): ):
"""Convert text to speech. """Convert text to speech.
@ -540,6 +541,7 @@ class TTS:
style_wav=None, style_wav=None,
style_text=None, style_text=None,
reference_speaker_name=None, reference_speaker_name=None,
**kwargs,
) )
return wav return wav
@ -552,6 +554,7 @@ class TTS:
emotion: str = "Neutral", emotion: str = "Neutral",
speed: float = 1.0, speed: float = 1.0,
file_path: str = "output.wav", file_path: str = "output.wav",
**kwargs,
): ):
"""Convert text to speech. """Convert text to speech.
@ -580,7 +583,7 @@ class TTS:
return self.tts_coqui_studio( return self.tts_coqui_studio(
text=text, speaker_name=speaker, language=language, emotion=emotion, speed=speed, file_path=file_path text=text, speaker_name=speaker, language=language, emotion=emotion, speed=speed, file_path=file_path
) )
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav) wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
self.synthesizer.save_wav(wav=wav, path=file_path) self.synthesizer.save_wav(wav=wav, path=file_path)
return file_path return file_path

View File

@ -1,8 +1,8 @@
import functools import functools
import math import math
import os import os
import fsspec
import fsspec
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F

View File

@ -2,8 +2,8 @@ import os
import torch import torch
from tokenizers import Tokenizer from tokenizers import Tokenizer
from TTS.tts.utils.text.cleaners import english_cleaners
from TTS.tts.utils.text.cleaners import english_cleaners
DEFAULT_VOCAB_FILE = os.path.join( DEFAULT_VOCAB_FILE = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json" os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json"

View File

@ -5,20 +5,17 @@ from tqdm import tqdm
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models") DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR) MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
MODELS_DIR = "/data/speech_synth/models/"
MODELS = { MODELS = {
"autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth", "autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth",
"classifier.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth", "classifier.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth",
"clvp2.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth", "clvp2.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth",
"cvvp.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth",
"diffusion_decoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth", "diffusion_decoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth",
"vocoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth", "vocoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth",
"rlg_auto.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth", "rlg_auto.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth",
"rlg_diffuser.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth", "rlg_diffuser.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth",
} }
pbar = None
def download_models(specific_models=None): def download_models(specific_models=None):
""" """

View File

@ -11,13 +11,7 @@ from coqpit import Coqpit
from tqdm import tqdm from tqdm import tqdm
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from TTS.tts.layers.tortoise.audio_utils import ( from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, load_voice, wav_to_univnet_mel
denormalize_tacotron_mel,
load_audio,
load_voice,
load_voices,
wav_to_univnet_mel,
)
from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
from TTS.tts.layers.tortoise.clvp import CLVP from TTS.tts.layers.tortoise.clvp import CLVP
@ -25,7 +19,7 @@ from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_sc
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path from TTS.tts.layers.tortoise.utils import get_model_path
from TTS.tts.layers.tortoise.vocoder import VocConf from TTS.tts.layers.tortoise.vocoder import VocConf
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
@ -498,16 +492,11 @@ class Tortoise(BaseTTS):
with torch.no_grad(): with torch.no_grad():
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0])) return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
def synthesis( def synthesis(self, text, config, speaker_id="lj", **kwargs):
self,
text,
config,
speaker_id="lj",
):
voice_samples, conditioning_latents = load_voice(speaker_id) voice_samples, conditioning_latents = load_voice(speaker_id)
outputs = self.inference_with_config( outputs = self.inference_with_config(
text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents, **kwargs
) )
return_dict = { return_dict = {
@ -533,10 +522,47 @@ class Tortoise(BaseTTS):
"top_p": config.top_p, "top_p": config.top_p,
"cond_free_k": config.cond_free_k, "cond_free_k": config.cond_free_k,
"diffusion_temperature": config.diffusion_temperature, "diffusion_temperature": config.diffusion_temperature,
"num_autoregressive_samples": config.num_autoregressive_samples,
"diffusion_iterations": config.diffusion_iterations,
"sampler": config.sampler, "sampler": config.sampler,
} }
# Presets are defined here.
presets = {
"single_sample": {
"num_autoregressive_samples": 8,
"diffusion_iterations": 10,
"sampler": "ddim",
},
"ultra_fast": {
"num_autoregressive_samples": 16,
"diffusion_iterations": 10,
"sampler": "ddim",
},
"ultra_fast_old": {
"num_autoregressive_samples": 16,
"diffusion_iterations": 30,
"cond_free": False,
},
"very_fast": {
"num_autoregressive_samples": 32,
"diffusion_iterations": 30,
"sampler": "dpm++2m",
},
"fast": {
"num_autoregressive_samples": 5,
"diffusion_iterations": 50,
"sampler": "ddim",
},
"fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
"standard": {
"num_autoregressive_samples": 5,
"diffusion_iterations": 200,
},
"high_quality": {
"num_autoregressive_samples": 256,
"diffusion_iterations": 400,
},
}
settings.update(presets[kwargs["preset"]])
kwargs.pop("preset")
settings.update(kwargs) # allow overriding of preset settings with kwargs settings.update(kwargs) # allow overriding of preset settings with kwargs
return self.inference(text, **settings) return self.inference(text, **settings)
@ -551,7 +577,7 @@ class Tortoise(BaseTTS):
return_deterministic_state=False, return_deterministic_state=False,
latent_averaging_mode=0, latent_averaging_mode=0,
# autoregressive generation parameters follow # autoregressive generation parameters follow
num_autoregressive_samples=512, num_autoregressive_samples=16,
temperature=0.8, temperature=0.8,
length_penalty=1, length_penalty=1,
repetition_penalty=2.0, repetition_penalty=2.0,

View File

@ -228,6 +228,7 @@ class Synthesizer(object):
style_text=None, style_text=None,
reference_wav=None, reference_wav=None,
reference_speaker_name=None, reference_speaker_name=None,
**kwargs,
) -> List[int]: ) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech. """🐸 TTS magic. Run all the models and generate speech.
@ -328,7 +329,7 @@ class Synthesizer(object):
if not reference_wav: if not reference_wav:
for sen in sens: for sen in sens:
if self.tts_config.model == "tortoise": if self.tts_config.model == "tortoise":
outputs = self.tts_model.synthesis(text=sen, config=self.tts_config) outputs = self.tts_model.synthesis(text=sen, config=self.tts_config, **kwargs)
else: else:
# synthesize voice # synthesize voice
outputs = synthesis( outputs = synthesis(