mirror of https://github.com/coqui-ai/TTS.git
added tortoise config and updated config and args, refactoring the code
This commit is contained in:
parent e1838617d4
commit 1f77cc6cac
TTS/tts/configs/tortoise_config.py
@@ -0,0 +1,46 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from TTS.tts.configs.shared_configs import BaseTTSConfig
+from TTS.tts.models.tortoise import TortoiseArgs, TortoiseAudioConfig
+
+
+@dataclass
+class TortoiseConfig(BaseTTSConfig):
+    """Defines parameters for Tortoise TTS model.
+
+    Args:
+        model (str):
+            Model name. Do not change unless you know what you are doing.
+
+        model_args (TortoiseArgs):
+            Model architecture arguments. Defaults to `TortoiseArgs()`.
+
+        audio (TortoiseAudioConfig):
+            Audio processing configuration. Defaults to `TortoiseAudioConfig()`.
+    Note:
+        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
+
+    Example:
+
+        >>> from TTS.tts.configs.tortoise_config import TortoiseConfig
+        >>> config = TortoiseConfig()
+    """
+
+    model: str = "tortoise"
+    # model specific params
+    model_args: TortoiseArgs = field(default_factory=TortoiseArgs)
+    audio: TortoiseAudioConfig = field(default_factory=TortoiseAudioConfig)
+
+    # settings
+    temperature: float = 0.2
+    length_penalty: float = 1.0
+    repetition_penalty: float = 2.0
+    top_p: float = 0.8
+    cond_free_k: float = 2.0
+    diffusion_temperature: float = 1.0
+
+    # inference params
+    num_autoregressive_samples: int = 16
+    diffusion_iterations: int = 30
+    sampler: str = "ddim"
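
For orientation, a minimal usage sketch (illustrative only, not part of the commit; it assumes the import path shown in the docstring above): build the default config, then override the generation settings that the refactored `inference_with_config()` below reads off the config object.

    from TTS.tts.configs.tortoise_config import TortoiseConfig

    config = TortoiseConfig()
    # These fields map 1:1 onto the "settings" dict assembled in
    # Tortoise.inference_with_config() further down in this commit.
    config.num_autoregressive_samples = 32  # larger candidate pool, slower
    config.diffusion_iterations = 50        # more diffusion steps, higher fidelity
    config.sampler = "dpm++2m"              # alternative sampler named in the old presets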

TTS/tts/models/tortoise.py
@@ -1,17 +1,23 @@
 # AGPL: a notification must be added stating that changes have been made to that file.

 import os
 import random
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from time import time

 import torch
 import torch.nn.functional as F
 import torchaudio
 from coqpit import Coqpit
 from tqdm import tqdm

 from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
-from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
+from TTS.tts.layers.tortoise.audio_utils import (
+    denormalize_tacotron_mel,
+    load_audio,
+    load_voice,
+    load_voices,
+    wav_to_univnet_mel,
+)
 from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
 from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
 from TTS.tts.layers.tortoise.clvp import CLVP
@@ -22,6 +28,7 @@ from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
 from TTS.tts.layers.tortoise.vocoder import VocConf
 from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
+from TTS.tts.models.base_tts import BaseTTS


 def pad_or_truncate(t, length):
@@ -191,21 +198,74 @@ def pick_best_batch_size_for_gpu():
     return batch_size


-class TextToSpeech:
+@dataclass
+class TortoiseAudioConfig(Coqpit):
+    sample_rate: int = 22050
+    diffusion_sample_rate: int = 24000
+
+
+@dataclass
+class TortoiseArgs(Coqpit):
+    autoregressive_batch_size: int = 1
+    enable_redaction: bool = True
+    high_vram: bool = False
+    kv_cache: bool = True
+    ar_checkpoint: str = None
+    clvp_checkpoint: str = None
+    diff_checkpoint: str = None
+    num_chars: int = 255
+
+    # UnifiedVoice params
+    ar_max_mel_tokens: int = 604
+    ar_max_text_tokens: int = 402
+    ar_max_conditioning_inputs: int = 2
+    ar_layers: int = 30
+    ar_model_dim: int = 1024
+    ar_heads: int = 16
+    ar_number_text_tokens: int = 255
+    ar_start_text_token: int = 255
+    ar_checkpointing: bool = False
+    ar_train_solo_embeddings: bool = False
+
+    # DiffTTS params
+    diff_model_channels: int = 1024
+    diff_num_layers: int = 10
+    diff_in_channels: int = 100
+    diff_out_channels: int = 200
+    diff_in_latent_channels: int = 1024
+    diff_in_tokens: int = 8193
+    diff_dropout: int = 0
+    diff_use_fp16: bool = False
+    diff_num_heads: int = 16
+    diff_layer_drop: int = 0
+    diff_unconditioned_percentage: int = 0
+
+    # clvp params
+    clvp_dim_text: int = 768
+    clvp_dim_speech: int = 768
+    clvp_dim_latent: int = 768
+    clvp_num_text_tokens: int = 256
+    clvp_text_enc_depth: int = 20
+    clvp_text_seq_len: int = 350
+    clvp_text_heads: int = 12
+    clvp_num_speech_tokens: int = 8192
+    clvp_speech_enc_depth: int = 20
+    clvp_speech_heads: int = 12
+    clvp_speech_seq_len: int = 430
+    clvp_use_xformers: bool = True
+    # constants
+    duration_const: int = 102400
+
+
+class Tortoise(BaseTTS):
     """
     Main entry point into Tortoise.
     """

     def __init__(
         self,
-        autoregressive_batch_size=None,
+        config: Coqpit,
         models_dir=MODELS_DIR,
-        enable_redaction=True,
-        high_vram=False,
-        kv_cache=True,
-        ar_checkpoint=None,
-        clvp_checkpoint=None,
-        diff_checkpoint=None,
         vocoder=VocConf.Univnet,
     ):
         """
@@ -224,82 +284,82 @@ class TextToSpeech:
         :param clvp_checkpoint: Path to a checkpoint file for the CLVP model. If omitted, uses default
         :param diff_checkpoint: Path to a checkpoint file for the diffusion model. If omitted, uses default
         """
-        self.ar_checkpoint = ar_checkpoint
-        self.diff_checkpoint = diff_checkpoint  # TODO: check if this is even needed
+        super().__init__(config, ap=None, tokenizer=None)
+        self.config = config
+        self.ar_checkpoint = self.args.ar_checkpoint
+        self.diff_checkpoint = self.args.diff_checkpoint  # TODO: check if this is even needed
         self.models_dir = models_dir
         self.autoregressive_batch_size = (
-            pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
+            pick_best_batch_size_for_gpu()
+            if self.args.autoregressive_batch_size is None
+            else self.args.autoregressive_batch_size
         )
-        self.enable_redaction = enable_redaction
+        self.enable_redaction = self.args.enable_redaction
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         if self.enable_redaction:
             self.aligner = Wav2VecAlignment()

         self.tokenizer = VoiceBpeTokenizer()

         if os.path.exists(f"{models_dir}/autoregressive.ptt"):
             # Assume this is a traced directory.
             self.autoregressive = torch.jit.load(f"{models_dir}/autoregressive.ptt")
             self.diffusion = torch.jit.load(f"{models_dir}/diffusion_decoder.ptt")
         else:
-            self.autoregressive = (
-                UnifiedVoice(
-                    max_mel_tokens=604,
-                    max_text_tokens=402,
-                    max_conditioning_inputs=2,
-                    layers=30,
-                    model_dim=1024,
-                    heads=16,
-                    number_text_tokens=255,
-                    start_text_token=255,
-                    checkpointing=False,
-                    train_solo_embeddings=False,
-                )
-                .cpu()
-                .eval()
-            )
-            ar_path = ar_checkpoint or get_model_path("autoregressive.pth", models_dir)
-            self.autoregressive.load_state_dict(torch.load(ar_path))
-            self.autoregressive.post_init_gpt2_config(kv_cache)
-
-            diff_path = diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir)
-            self.diffusion = (
-                DiffusionTts(
-                    model_channels=1024,
-                    num_layers=10,
-                    in_channels=100,
-                    out_channels=200,
-                    in_latent_channels=1024,
-                    in_tokens=8193,
-                    dropout=0,
-                    use_fp16=False,
-                    num_heads=16,
-                    layer_drop=0,
-                    unconditioned_percentage=0,
-                )
-                .cpu()
-                .eval()
-            )
-            self.diffusion.load_state_dict(torch.load(diff_path))
-            self.clvp = (
-                CLVP(
-                    dim_text=768,
-                    dim_speech=768,
-                    dim_latent=768,
-                    num_text_tokens=256,
-                    text_enc_depth=20,
-                    text_seq_len=350,
-                    text_heads=12,
-                    num_speech_tokens=8192,
-                    speech_enc_depth=20,
-                    speech_heads=12,
-                    speech_seq_len=430,
-                    use_xformers=True,
-                )
-                .cpu()
-                .eval()
-            )
-            clvp_path = clvp_checkpoint or get_model_path("clvp2.pth", models_dir)
+            self.autoregressive = (
+                UnifiedVoice(
+                    max_mel_tokens=self.args.ar_max_mel_tokens,
+                    max_text_tokens=self.args.ar_max_text_tokens,
+                    max_conditioning_inputs=self.args.ar_max_conditioning_inputs,
+                    layers=self.args.ar_layers,
+                    model_dim=self.args.ar_model_dim,
+                    heads=self.args.ar_heads,
+                    number_text_tokens=self.args.ar_number_text_tokens,
+                    start_text_token=self.args.ar_start_text_token,
+                    checkpointing=self.args.ar_checkpointing,
+                    train_solo_embeddings=self.args.ar_train_solo_embeddings,
+                )
+                .cpu()
+                .eval()
+            )
+            ar_path = self.args.ar_checkpoint or get_model_path("autoregressive.pth", models_dir)
+            self.autoregressive.load_state_dict(torch.load(ar_path))
+            self.autoregressive.post_init_gpt2_config(self.args.kv_cache)
+
+            diff_path = self.args.diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir)
+            self.diffusion = (
+                DiffusionTts(
+                    model_channels=self.args.diff_model_channels,
+                    num_layers=self.args.diff_num_layers,
+                    in_channels=self.args.diff_in_channels,
+                    out_channels=self.args.diff_out_channels,
+                    in_latent_channels=self.args.diff_in_latent_channels,
+                    in_tokens=self.args.diff_in_tokens,
+                    dropout=self.args.diff_dropout,
+                    use_fp16=self.args.diff_use_fp16,
+                    num_heads=self.args.diff_num_heads,
+                    layer_drop=self.args.diff_layer_drop,
+                    unconditioned_percentage=self.args.diff_unconditioned_percentage,
+                )
+                .cpu()
+                .eval()
+            )
+            self.diffusion.load_state_dict(torch.load(diff_path))
+
+            self.clvp = (
+                CLVP(
+                    dim_text=self.args.clvp_dim_text,
+                    dim_speech=self.args.clvp_dim_speech,
+                    dim_latent=self.args.clvp_dim_latent,
+                    num_text_tokens=self.args.clvp_num_text_tokens,
+                    text_enc_depth=self.args.clvp_text_enc_depth,
+                    text_seq_len=self.args.clvp_text_seq_len,
+                    text_heads=self.args.clvp_text_heads,
+                    num_speech_tokens=self.args.clvp_num_speech_tokens,
+                    speech_enc_depth=self.args.clvp_speech_enc_depth,
+                    speech_heads=self.args.clvp_speech_heads,
+                    speech_seq_len=self.args.clvp_speech_seq_len,
+                    use_xformers=self.args.clvp_use_xformers,
+                )
+                .cpu()
+                .eval()
+            )
+            clvp_path = self.args.clvp_checkpoint or get_model_path("clvp2.pth", models_dir)
         self.clvp.load_state_dict(torch.load(clvp_path))

         self.vocoder = vocoder.value.constructor().cpu()
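
To see what the refactor buys at the call site, a hedged construction sketch (hypothetical, not part of the commit; it assumes the default checkpoints resolve through `get_model_path`): the keyword arguments of the old `TextToSpeech.__init__` now travel inside `config.model_args`.

    from TTS.tts.configs.tortoise_config import TortoiseConfig
    from TTS.tts.models.tortoise import Tortoise, TortoiseArgs

    # Old: TextToSpeech(kv_cache=True, high_vram=False, ar_checkpoint=None, ...)
    # New: a single config object carries the same knobs via TortoiseArgs.
    config = TortoiseConfig(model_args=TortoiseArgs(kv_cache=True, high_vram=False))
    model = Tortoise(config)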
@@ -317,12 +377,12 @@ class TextToSpeech:
         self.rlg_auto = None
         self.rlg_diffusion = None

-        if high_vram:
+        if self.args.high_vram:
             self.autoregressive = self.autoregressive.to(self.device)
             self.diffusion = self.diffusion.to(self.device)
             self.clvp = self.clvp.to(self.device)
             self.vocoder = self.vocoder.to(self.device)
-        self.high_vram = high_vram
+        self.high_vram = self.args.high_vram

     @contextmanager
     def temporary_cuda(self, model):
@@ -355,7 +415,7 @@ class TextToSpeech:
             1,
             2,
         ], "latent_averaging mode has to be one of (0, 1, 2)"
-        print("mode", latent_averaging_mode)
+
         with torch.no_grad():
             voice_samples = [[v.to(self.device) for v in ls] for ls in voice_samples]
@@ -368,7 +428,7 @@ class TextToSpeech:

         diffusion_conds = []

-        DURS_CONST = 102400
+        DURS_CONST = self.args.duration_const
         for ls in voice_samples:
             # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
             sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1]
@@ -428,67 +488,51 @@ class TextToSpeech:
         with torch.no_grad():
             return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))

-    def tts_with_preset(self, text, preset="fast", **kwargs):
+    def synthesis(
+        self,
+        model,
+        text,
+        config,
+        preset,
+        speaker_id="lj",
+    ):
+        voice_samples, conditioning_latents = load_voice(speaker_id)
+
+        outputs = model.inference_with_config(
+            text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents
+        )
+
+        return_dict = {
+            "wav": outputs["wav"],
+            "deterministic_seed": outputs["deterministic_seed"],
+            "text_inputs": outputs["text"],
+            "voice_samples": outputs["voice_samples"],
+            "conditioning_latents": outputs["conditioning_latents"],
+        }
+
+        return return_dict
+
+    def inference_with_config(self, text, config, **kwargs):
         """
-        Calls TTS with one of a set of preset generation parameters. Options:
-            'single_sample': Produces speech even faster, but only produces 1 sample.
-            'ultra_fast': Produces speech much faster than the original tortoise repo.
-            'ultra_fast_old': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest).
-            'fast': Decent quality speech at a decent inference rate. A good choice for mass inference.
-            'standard': Very good quality. This is generally about as good as you are going to get.
-            'high_quality': Use if you want the absolute best. This is not really worth the compute, though.
+        inference with config
+        # TODO: describe in detail
         """
         # Use generally found best tuning knobs for generation.
         settings = {
-            "temperature": 0.2,
-            "length_penalty": 1.0,
-            "repetition_penalty": 2.0,
-            "top_p": 0.8,
-            "cond_free_k": 2.0,
-            "diffusion_temperature": 1.0,
+            "temperature": config.temperature,
+            "length_penalty": config.length_penalty,
+            "repetition_penalty": config.repetition_penalty,
+            "top_p": config.top_p,
+            "cond_free_k": config.cond_free_k,
+            "diffusion_temperature": config.diffusion_temperature,
+            "num_autoregressive_samples": config.num_autoregressive_samples,
+            "diffusion_iterations": config.diffusion_iterations,
+            "sampler": config.sampler,
         }
-        # Presets are defined here.
-        presets = {
-            "single_sample": {
-                "num_autoregressive_samples": 8,
-                "diffusion_iterations": 10,
-                "sampler": "ddim",
-            },
-            "ultra_fast": {
-                "num_autoregressive_samples": 16,
-                "diffusion_iterations": 10,
-                "sampler": "ddim",
-            },
-            "ultra_fast_old": {
-                "num_autoregressive_samples": 16,
-                "diffusion_iterations": 30,
-                "cond_free": False,
-            },
-            "very_fast": {
-                "num_autoregressive_samples": 32,
-                "diffusion_iterations": 30,
-                "sampler": "dpm++2m",
-            },
-            "fast": {
-                "num_autoregressive_samples": 16,
-                "diffusion_iterations": 50,
-                "sampler": "ddim",
-            },
-            "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
-            "standard": {
-                "num_autoregressive_samples": 256,
-                "diffusion_iterations": 200,
-            },
-            "high_quality": {
-                "num_autoregressive_samples": 256,
-                "diffusion_iterations": 400,
-            },
-        }
-        settings.update(presets[preset])
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
-        return self.tts(text, **settings)
+        return self.inference(text, **settings)
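
With the preset table gone, callers pass a `TortoiseConfig` for the defaults and use kwargs where a preset name was used before. A short sketch (hypothetical, reusing the `model` and `config` objects from the sketches above):

    # kwargs play the role of the old presets; these are the former "fast" values.
    out = model.inference_with_config(
        "Hello world.",
        config,
        voice_samples=voice_samples,  # e.g. from load_voice(speaker_id)
        conditioning_latents=conditioning_latents,
        num_autoregressive_samples=16,
        diffusion_iterations=50,
    )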

-    def tts(
+    def inference(
         self,
         text,
         voice_samples=None,
@@ -709,11 +753,34 @@ class TextToSpeech:
         else:
             res = wav_candidates[0]

-        if return_deterministic_state:
-            return res, (
-                deterministic_seed,
-                text,
-                voice_samples,
-                conditioning_latents,
-            )
-        return res
+        return_dict = {
+            "wav": res,
+            "deterministic_seed": None,
+            "text": None,
+            "voice_samples": None,
+            "conditioning_latents": None,
+        }
+        if return_deterministic_state:
+            return_dict = {
+                "wav": res,
+                "deterministic_seed": deterministic_seed,
+                "text": text,
+                "voice_samples": voice_samples,
+                "conditioning_latents": conditioning_latents,
+            }
+        return return_dict
+
+    def forward(self):
+        raise NotImplementedError("Tortoise Training is not implemented")
+
+    def eval_step(self):
+        raise NotImplementedError("Tortoise Training is not implemented")
+
+    def init_from_config(self):
+        raise NotImplementedError("Tortoise Training is not implemented")
+
+    def load_checkpoint(self):
+        raise NotImplementedError("Tortoise Training is not implemented")
+
+    def train_step(self):
+        raise NotImplementedError("Tortoise Training is not implemented")
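
The return contract changes too: `inference()` (formerly `tts()`) now always returns a dict rather than a bare tensor or a tuple, with unused keys set to None. A hedged end-to-end sketch (hypothetical, assuming the objects from the earlier sketches):

    from TTS.tts.layers.tortoise.audio_utils import load_voice

    voice_samples, conditioning_latents = load_voice("lj")
    out = model.inference_with_config(
        "The quick brown fox.",
        config,
        voice_samples=voice_samples,
        conditioning_latents=conditioning_latents,
        return_deterministic_state=True,
    )
    # Same dict shape on every call; the seed is only populated when requested.
    wav, seed = out["wav"], out["deterministic_seed"]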