mirror of https://github.com/coqui-ai/TTS.git
add ability to pass tortoise presets through coqui api
This commit is contained in:
parent
8c739fd5f2
commit
fede89ac0d
|
@ -504,6 +504,7 @@ class TTS:
|
||||||
speaker_wav: str = None,
|
speaker_wav: str = None,
|
||||||
emotion: str = None,
|
emotion: str = None,
|
||||||
speed: float = None,
|
speed: float = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Convert text to speech.
|
"""Convert text to speech.
|
||||||
|
|
||||||
|
@ -540,6 +541,7 @@ class TTS:
|
||||||
style_wav=None,
|
style_wav=None,
|
||||||
style_text=None,
|
style_text=None,
|
||||||
reference_speaker_name=None,
|
reference_speaker_name=None,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
return wav
|
return wav
|
||||||
|
|
||||||
|
@ -552,6 +554,7 @@ class TTS:
|
||||||
emotion: str = "Neutral",
|
emotion: str = "Neutral",
|
||||||
speed: float = 1.0,
|
speed: float = 1.0,
|
||||||
file_path: str = "output.wav",
|
file_path: str = "output.wav",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Convert text to speech.
|
"""Convert text to speech.
|
||||||
|
|
||||||
|
@ -580,7 +583,7 @@ class TTS:
|
||||||
return self.tts_coqui_studio(
|
return self.tts_coqui_studio(
|
||||||
text=text, speaker_name=speaker, language=language, emotion=emotion, speed=speed, file_path=file_path
|
text=text, speaker_name=speaker, language=language, emotion=emotion, speed=speed, file_path=file_path
|
||||||
)
|
)
|
||||||
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav)
|
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
|
||||||
self.synthesizer.save_wav(wav=wav, path=file_path)
|
self.synthesizer.save_wav(wav=wav, path=file_path)
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import functools
|
import functools
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import fsspec
|
|
||||||
|
|
||||||
|
import fsspec
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
|
@ -2,8 +2,8 @@ import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tokenizers import Tokenizer
|
from tokenizers import Tokenizer
|
||||||
from TTS.tts.utils.text.cleaners import english_cleaners
|
|
||||||
|
|
||||||
|
from TTS.tts.utils.text.cleaners import english_cleaners
|
||||||
|
|
||||||
DEFAULT_VOCAB_FILE = os.path.join(
|
DEFAULT_VOCAB_FILE = os.path.join(
|
||||||
os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json"
|
os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json"
|
||||||
|
|
|
@ -5,20 +5,17 @@ from tqdm import tqdm
|
||||||
|
|
||||||
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
|
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
|
||||||
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
|
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
|
||||||
|
MODELS_DIR = "/data/speech_synth/models/"
|
||||||
MODELS = {
|
MODELS = {
|
||||||
"autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth",
|
"autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth",
|
||||||
"classifier.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth",
|
"classifier.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth",
|
||||||
"clvp2.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth",
|
"clvp2.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth",
|
||||||
"cvvp.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth",
|
|
||||||
"diffusion_decoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth",
|
"diffusion_decoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth",
|
||||||
"vocoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth",
|
"vocoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth",
|
||||||
"rlg_auto.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth",
|
"rlg_auto.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth",
|
||||||
"rlg_diffuser.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth",
|
"rlg_diffuser.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth",
|
||||||
}
|
}
|
||||||
|
|
||||||
pbar = None
|
|
||||||
|
|
||||||
|
|
||||||
def download_models(specific_models=None):
|
def download_models(specific_models=None):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -11,13 +11,7 @@ from coqpit import Coqpit
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
|
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
|
||||||
from TTS.tts.layers.tortoise.audio_utils import (
|
from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, load_voice, wav_to_univnet_mel
|
||||||
denormalize_tacotron_mel,
|
|
||||||
load_audio,
|
|
||||||
load_voice,
|
|
||||||
load_voices,
|
|
||||||
wav_to_univnet_mel,
|
|
||||||
)
|
|
||||||
from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
|
from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
|
||||||
from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
|
from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
|
||||||
from TTS.tts.layers.tortoise.clvp import CLVP
|
from TTS.tts.layers.tortoise.clvp import CLVP
|
||||||
|
@ -25,7 +19,7 @@ from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_sc
|
||||||
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
|
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
|
||||||
from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
|
from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
|
||||||
from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
|
from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
|
||||||
from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
|
from TTS.tts.layers.tortoise.utils import get_model_path
|
||||||
from TTS.tts.layers.tortoise.vocoder import VocConf
|
from TTS.tts.layers.tortoise.vocoder import VocConf
|
||||||
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
|
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
|
@ -498,16 +492,11 @@ class Tortoise(BaseTTS):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
|
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
|
||||||
|
|
||||||
def synthesis(
|
def synthesis(self, text, config, speaker_id="lj", **kwargs):
|
||||||
self,
|
|
||||||
text,
|
|
||||||
config,
|
|
||||||
speaker_id="lj",
|
|
||||||
):
|
|
||||||
voice_samples, conditioning_latents = load_voice(speaker_id)
|
voice_samples, conditioning_latents = load_voice(speaker_id)
|
||||||
|
|
||||||
outputs = self.inference_with_config(
|
outputs = self.inference_with_config(
|
||||||
text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents
|
text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
return_dict = {
|
return_dict = {
|
||||||
|
@ -533,10 +522,47 @@ class Tortoise(BaseTTS):
|
||||||
"top_p": config.top_p,
|
"top_p": config.top_p,
|
||||||
"cond_free_k": config.cond_free_k,
|
"cond_free_k": config.cond_free_k,
|
||||||
"diffusion_temperature": config.diffusion_temperature,
|
"diffusion_temperature": config.diffusion_temperature,
|
||||||
"num_autoregressive_samples": config.num_autoregressive_samples,
|
|
||||||
"diffusion_iterations": config.diffusion_iterations,
|
|
||||||
"sampler": config.sampler,
|
"sampler": config.sampler,
|
||||||
}
|
}
|
||||||
|
# Presets are defined here.
|
||||||
|
presets = {
|
||||||
|
"single_sample": {
|
||||||
|
"num_autoregressive_samples": 8,
|
||||||
|
"diffusion_iterations": 10,
|
||||||
|
"sampler": "ddim",
|
||||||
|
},
|
||||||
|
"ultra_fast": {
|
||||||
|
"num_autoregressive_samples": 16,
|
||||||
|
"diffusion_iterations": 10,
|
||||||
|
"sampler": "ddim",
|
||||||
|
},
|
||||||
|
"ultra_fast_old": {
|
||||||
|
"num_autoregressive_samples": 16,
|
||||||
|
"diffusion_iterations": 30,
|
||||||
|
"cond_free": False,
|
||||||
|
},
|
||||||
|
"very_fast": {
|
||||||
|
"num_autoregressive_samples": 32,
|
||||||
|
"diffusion_iterations": 30,
|
||||||
|
"sampler": "dpm++2m",
|
||||||
|
},
|
||||||
|
"fast": {
|
||||||
|
"num_autoregressive_samples": 5,
|
||||||
|
"diffusion_iterations": 50,
|
||||||
|
"sampler": "ddim",
|
||||||
|
},
|
||||||
|
"fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
|
||||||
|
"standard": {
|
||||||
|
"num_autoregressive_samples": 5,
|
||||||
|
"diffusion_iterations": 200,
|
||||||
|
},
|
||||||
|
"high_quality": {
|
||||||
|
"num_autoregressive_samples": 256,
|
||||||
|
"diffusion_iterations": 400,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
settings.update(presets[kwargs["preset"]])
|
||||||
|
kwargs.pop("preset")
|
||||||
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
||||||
return self.inference(text, **settings)
|
return self.inference(text, **settings)
|
||||||
|
|
||||||
|
@ -551,7 +577,7 @@ class Tortoise(BaseTTS):
|
||||||
return_deterministic_state=False,
|
return_deterministic_state=False,
|
||||||
latent_averaging_mode=0,
|
latent_averaging_mode=0,
|
||||||
# autoregressive generation parameters follow
|
# autoregressive generation parameters follow
|
||||||
num_autoregressive_samples=512,
|
num_autoregressive_samples=16,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
length_penalty=1,
|
length_penalty=1,
|
||||||
repetition_penalty=2.0,
|
repetition_penalty=2.0,
|
||||||
|
|
|
@ -228,6 +228,7 @@ class Synthesizer(object):
|
||||||
style_text=None,
|
style_text=None,
|
||||||
reference_wav=None,
|
reference_wav=None,
|
||||||
reference_speaker_name=None,
|
reference_speaker_name=None,
|
||||||
|
**kwargs,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""🐸 TTS magic. Run all the models and generate speech.
|
"""🐸 TTS magic. Run all the models and generate speech.
|
||||||
|
|
||||||
|
@ -328,7 +329,7 @@ class Synthesizer(object):
|
||||||
if not reference_wav:
|
if not reference_wav:
|
||||||
for sen in sens:
|
for sen in sens:
|
||||||
if self.tts_config.model == "tortoise":
|
if self.tts_config.model == "tortoise":
|
||||||
outputs = self.tts_model.synthesis(text=sen, config=self.tts_config)
|
outputs = self.tts_model.synthesis(text=sen, config=self.tts_config, **kwargs)
|
||||||
else:
|
else:
|
||||||
# synthesize voice
|
# synthesize voice
|
||||||
outputs = synthesis(
|
outputs = synthesis(
|
||||||
|
|
Loading…
Reference in New Issue