add ability to pass tortoise presets through coqui api

This commit is contained in:
manmay-nakhashi 2023-05-02 00:24:49 +05:30
parent 8c739fd5f2
commit fede89ac0d
6 changed files with 53 additions and 26 deletions

View File

@ -504,6 +504,7 @@ class TTS:
speaker_wav: str = None,
emotion: str = None,
speed: float = None,
**kwargs,
):
"""Convert text to speech.
@ -540,6 +541,7 @@ class TTS:
style_wav=None,
style_text=None,
reference_speaker_name=None,
**kwargs,
)
return wav
@ -552,6 +554,7 @@ class TTS:
emotion: str = "Neutral",
speed: float = 1.0,
file_path: str = "output.wav",
**kwargs,
):
"""Convert text to speech.
@ -580,7 +583,7 @@ class TTS:
return self.tts_coqui_studio(
text=text, speaker_name=speaker, language=language, emotion=emotion, speed=speed, file_path=file_path
)
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav)
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
self.synthesizer.save_wav(wav=wav, path=file_path)
return file_path

View File

@ -1,8 +1,8 @@
import functools
import math
import os
import fsspec
import fsspec
import torch
import torch.nn as nn
import torch.nn.functional as F

View File

@ -2,8 +2,8 @@ import os
import torch
from tokenizers import Tokenizer
from TTS.tts.utils.text.cleaners import english_cleaners
from TTS.tts.utils.text.cleaners import english_cleaners
DEFAULT_VOCAB_FILE = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json"

View File

@ -5,20 +5,17 @@ from tqdm import tqdm
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
MODELS_DIR = "/data/speech_synth/models/"
MODELS = {
"autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth",
"classifier.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth",
"clvp2.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth",
"cvvp.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth",
"diffusion_decoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth",
"vocoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth",
"rlg_auto.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth",
"rlg_diffuser.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth",
}
pbar = None
def download_models(specific_models=None):
"""

View File

@ -11,13 +11,7 @@ from coqpit import Coqpit
from tqdm import tqdm
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from TTS.tts.layers.tortoise.audio_utils import (
denormalize_tacotron_mel,
load_audio,
load_voice,
load_voices,
wav_to_univnet_mel,
)
from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, load_voice, wav_to_univnet_mel
from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
from TTS.tts.layers.tortoise.clvp import CLVP
@ -25,7 +19,7 @@ from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_sc
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
from TTS.tts.layers.tortoise.utils import get_model_path
from TTS.tts.layers.tortoise.vocoder import VocConf
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.models.base_tts import BaseTTS
@ -498,16 +492,11 @@ class Tortoise(BaseTTS):
with torch.no_grad():
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
def synthesis(
self,
text,
config,
speaker_id="lj",
):
def synthesis(self, text, config, speaker_id="lj", **kwargs):
voice_samples, conditioning_latents = load_voice(speaker_id)
outputs = self.inference_with_config(
text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents
text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents, **kwargs
)
return_dict = {
@ -533,10 +522,47 @@ class Tortoise(BaseTTS):
"top_p": config.top_p,
"cond_free_k": config.cond_free_k,
"diffusion_temperature": config.diffusion_temperature,
"num_autoregressive_samples": config.num_autoregressive_samples,
"diffusion_iterations": config.diffusion_iterations,
"sampler": config.sampler,
}
# Presets are defined here.
presets = {
"single_sample": {
"num_autoregressive_samples": 8,
"diffusion_iterations": 10,
"sampler": "ddim",
},
"ultra_fast": {
"num_autoregressive_samples": 16,
"diffusion_iterations": 10,
"sampler": "ddim",
},
"ultra_fast_old": {
"num_autoregressive_samples": 16,
"diffusion_iterations": 30,
"cond_free": False,
},
"very_fast": {
"num_autoregressive_samples": 32,
"diffusion_iterations": 30,
"sampler": "dpm++2m",
},
"fast": {
"num_autoregressive_samples": 5,
"diffusion_iterations": 50,
"sampler": "ddim",
},
"fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
"standard": {
"num_autoregressive_samples": 5,
"diffusion_iterations": 200,
},
"high_quality": {
"num_autoregressive_samples": 256,
"diffusion_iterations": 400,
},
}
settings.update(presets[kwargs["preset"]])
kwargs.pop("preset")
settings.update(kwargs) # allow overriding of preset settings with kwargs
return self.inference(text, **settings)
@ -551,7 +577,7 @@ class Tortoise(BaseTTS):
return_deterministic_state=False,
latent_averaging_mode=0,
# autoregressive generation parameters follow
num_autoregressive_samples=512,
num_autoregressive_samples=16,
temperature=0.8,
length_penalty=1,
repetition_penalty=2.0,

View File

@ -228,6 +228,7 @@ class Synthesizer(object):
style_text=None,
reference_wav=None,
reference_speaker_name=None,
**kwargs,
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
@ -328,7 +329,7 @@ class Synthesizer(object):
if not reference_wav:
for sen in sens:
if self.tts_config.model == "tortoise":
outputs = self.tts_model.synthesis(text=sen, config=self.tts_config)
outputs = self.tts_model.synthesis(text=sen, config=self.tts_config, **kwargs)
else:
# synthesize voice
outputs = synthesis(