Added Tortoise config, updated config and args, and refactored the code

manmay-nakhashi 2023-04-24 01:10:52 +05:30
parent e1838617d4
commit 1f77cc6cac
3 changed files with 253 additions and 140 deletions

TTS/tts/configs/tortoise_config.py

@@ -0,0 +1,46 @@
from dataclasses import dataclass, field
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.tortoise import TortoiseArgs, TortoiseAudioConfig
@dataclass
class TortoiseConfig(BaseTTSConfig):
"""Defines parameters for Tortoise TTS model.
Args:
model (str):
Model name. Do not change unless you know what you are doing.
model_args (TortoiseArgs):
Model architecture arguments. Defaults to `TortoiseArgs()`.
audio (TortoiseAudioConfig):
Audio processing configuration. Defaults to `TortoiseAudioConfig()`.
Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
Example:
>>> from TTS.tts.configs.tortoise_config import TortoiseConfig
>>> config = TortoiseConfig()
"""
model: str = "tortoise"
# model specific params
model_args: TortoiseArgs = field(default_factory=TortoiseArgs)
audio: TortoiseAudioConfig = field(default_factory=TortoiseAudioConfig)
# settings
temperature: float = 0.2
length_penalty: float = 1.0
repetition_penalty: float = 2.0
top_p: float = 0.8
cond_free_k: float = 2.0
diffusion_temperature: float = 1.0
# inference params
num_autoregressive_samples: int = 16
diffusion_iterations: int = 30
sampler: str = "ddim"

TTS/tts/models/tortoise.py

@@ -1,17 +1,23 @@
# ## AGPL: a notification must be added stating that changes have been made to that file.
import os
import random
from contextlib import contextmanager
from dataclasses import dataclass, field
from time import time
import torch
import torch.nn.functional as F
import torchaudio
from coqpit import Coqpit
from tqdm import tqdm
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
-from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
+from TTS.tts.layers.tortoise.audio_utils import (
+denormalize_tacotron_mel,
+load_audio,
+load_voice,
+load_voices,
+wav_to_univnet_mel,
+)
from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
from TTS.tts.layers.tortoise.clvp import CLVP
@@ -22,6 +28,7 @@ from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
from TTS.tts.layers.tortoise.vocoder import VocConf
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
+from TTS.tts.models.base_tts import BaseTTS
def pad_or_truncate(t, length):
@@ -191,21 +198,74 @@ def pick_best_batch_size_for_gpu():
return batch_size
-class TextToSpeech:
+@dataclass
+class TortoiseAudioConfig(Coqpit):
+sample_rate: int = 22050
+diffusion_sample_rate: int = 24000
+@dataclass
+class TortoiseArgs(Coqpit):
+autoregressive_batch_size: int = 1
+enable_redaction: bool = True
+high_vram: bool = False
+kv_cache: bool = True
+ar_checkpoint: str = None
+clvp_checkpoint: str = None
+diff_checkpoint: str = None
+num_chars: int = 255
+# UnifiedVoice params
+ar_max_mel_tokens: int = 604
+ar_max_text_tokens: int = 402
+ar_max_conditioning_inputs: int = 2
+ar_layers: int = 30
+ar_model_dim: int = 1024
+ar_heads: int = 16
+ar_number_text_tokens: int = 255
+ar_start_text_token: int = 255
+ar_checkpointing: bool = False
+ar_train_solo_embeddings: bool = False
+# DiffTTS params
+diff_model_channels: int = 1024
+diff_num_layers: int = 10
+diff_in_channels: int = 100
+diff_out_channels: int = 200
+diff_in_latent_channels: int = 1024
+diff_in_tokens: int = 8193
+diff_dropout: float = 0.0
+diff_use_fp16: bool = False
+diff_num_heads: int = 16
+diff_layer_drop: float = 0.0
+diff_unconditioned_percentage: float = 0.0
+# clvp params
+clvp_dim_text: int = 768
+clvp_dim_speech: int = 768
+clvp_dim_latent: int = 768
+clvp_num_text_tokens: int = 256
+clvp_text_enc_depth: int = 20
+clvp_text_seq_len: int = 350
+clvp_text_heads: int = 12
+clvp_num_speech_tokens: int = 8192
+clvp_speech_enc_depth: int = 20
+clvp_speech_heads: int = 12
+clvp_speech_seq_len: int = 430
+clvp_use_xformers: bool = True
+# constants
+duration_const: int = 102400
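
These dataclasses replace the old constructor kwargs. A sketch of overriding them through the config, assuming `model_args` is wired to `self.args` by `BaseTTS`:

from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.models.tortoise import TortoiseArgs

args = TortoiseArgs(kv_cache=True, high_vram=True)  # keep all models resident on the GPU
config = TortoiseConfig(model_args=args)

Note that `autoregressive_batch_size` now defaults to 1 rather than None, so the `pick_best_batch_size_for_gpu()` fallback in `__init__` below only runs when it is explicitly set to None.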
+class Tortoise(BaseTTS):
"""
Main entry point into Tortoise.
"""
def __init__(
self,
-autoregressive_batch_size=None,
+config: Coqpit,
models_dir=MODELS_DIR,
-enable_redaction=True,
-high_vram=False,
-kv_cache=True,
-ar_checkpoint=None,
-clvp_checkpoint=None,
-diff_checkpoint=None,
vocoder=VocConf.Univnet,
):
"""
@@ -224,82 +284,82 @@ class TextToSpeech:
:param clvp_checkpoint: Path to a checkpoint file for the CLVP model. If omitted, uses default
:param diff_checkpoint: Path to a checkpoint file for the diffusion model. If omitted, uses default
"""
-self.ar_checkpoint = ar_checkpoint
-self.diff_checkpoint = diff_checkpoint # TODO: check if this is even needed
+super().__init__(config, ap=None, tokenizer=None)
+self.config = config
+self.ar_checkpoint = self.args.ar_checkpoint
+self.diff_checkpoint = self.args.diff_checkpoint # TODO: check if this is even needed
self.models_dir = models_dir
self.autoregressive_batch_size = (
-pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
+pick_best_batch_size_for_gpu()
+if self.args.autoregressive_batch_size is None
+else self.args.autoregressive_batch_size
)
-self.enable_redaction = enable_redaction
+self.enable_redaction = self.args.enable_redaction
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if self.enable_redaction:
self.aligner = Wav2VecAlignment()
self.tokenizer = VoiceBpeTokenizer()
-if os.path.exists(f"{models_dir}/autoregressive.ptt"):
-# Assume this is a traced directory.
-self.autoregressive = torch.jit.load(f"{models_dir}/autoregressive.ptt")
-self.diffusion = torch.jit.load(f"{models_dir}/diffusion_decoder.ptt")
-else:
-self.autoregressive = (
-UnifiedVoice(
-max_mel_tokens=604,
-max_text_tokens=402,
-max_conditioning_inputs=2,
-layers=30,
-model_dim=1024,
-heads=16,
-number_text_tokens=255,
-start_text_token=255,
-checkpointing=False,
-train_solo_embeddings=False,
-)
-.cpu()
-.eval()
-)
-ar_path = ar_checkpoint or get_model_path("autoregressive.pth", models_dir)
-self.autoregressive.load_state_dict(torch.load(ar_path))
-self.autoregressive.post_init_gpt2_config(kv_cache)
-diff_path = diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir)
-self.diffusion = (
-DiffusionTts(
-model_channels=1024,
-num_layers=10,
-in_channels=100,
-out_channels=200,
-in_latent_channels=1024,
-in_tokens=8193,
-dropout=0,
-use_fp16=False,
-num_heads=16,
-layer_drop=0,
-unconditioned_percentage=0,
-)
-.cpu()
-.eval()
-)
-self.diffusion.load_state_dict(torch.load(diff_path))
-self.clvp = (
-CLVP(
-dim_text=768,
-dim_speech=768,
-dim_latent=768,
-num_text_tokens=256,
-text_enc_depth=20,
-text_seq_len=350,
-text_heads=12,
-num_speech_tokens=8192,
-speech_enc_depth=20,
-speech_heads=12,
-speech_seq_len=430,
-use_xformers=True,
-)
-.cpu()
-.eval()
-)
-clvp_path = clvp_checkpoint or get_model_path("clvp2.pth", models_dir)
+self.autoregressive = (
+UnifiedVoice(
+max_mel_tokens=self.args.ar_max_mel_tokens,
+max_text_tokens=self.args.ar_max_text_tokens,
+max_conditioning_inputs=self.args.ar_max_conditioning_inputs,
+layers=self.args.ar_layers,
+model_dim=self.args.ar_model_dim,
+heads=self.args.ar_heads,
+number_text_tokens=self.args.ar_number_text_tokens,
+start_text_token=self.args.ar_start_text_token,
+checkpointing=self.args.ar_checkpointing,
+train_solo_embeddings=self.args.ar_train_solo_embeddings,
+)
+.cpu()
+.eval()
+)
+ar_path = self.args.ar_checkpoint or get_model_path("autoregressive.pth", models_dir)
+self.autoregressive.load_state_dict(torch.load(ar_path))
+self.autoregressive.post_init_gpt2_config(self.args.kv_cache)
+diff_path = self.args.diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir)
+self.diffusion = (
+DiffusionTts(
+model_channels=self.args.diff_model_channels,
+num_layers=self.args.diff_num_layers,
+in_channels=self.args.diff_in_channels,
+out_channels=self.args.diff_out_channels,
+in_latent_channels=self.args.diff_in_latent_channels,
+in_tokens=self.args.diff_in_tokens,
+dropout=self.args.diff_dropout,
+use_fp16=self.args.diff_use_fp16,
+num_heads=self.args.diff_num_heads,
+layer_drop=self.args.diff_layer_drop,
+unconditioned_percentage=self.args.diff_unconditioned_percentage,
+)
+.cpu()
+.eval()
+)
+self.diffusion.load_state_dict(torch.load(diff_path))
+self.clvp = (
+CLVP(
+dim_text=self.args.clvp_dim_text,
+dim_speech=self.args.clvp_dim_speech,
+dim_latent=self.args.clvp_dim_latent,
+num_text_tokens=self.args.clvp_num_text_tokens,
+text_enc_depth=self.args.clvp_text_enc_depth,
+text_seq_len=self.args.clvp_text_seq_len,
+text_heads=self.args.clvp_text_heads,
+num_speech_tokens=self.args.clvp_num_speech_tokens,
+speech_enc_depth=self.args.clvp_speech_enc_depth,
+speech_heads=self.args.clvp_speech_heads,
+speech_seq_len=self.args.clvp_speech_seq_len,
+use_xformers=self.args.clvp_use_xformers,
+)
+.cpu()
+.eval()
+)
+clvp_path = self.args.clvp_checkpoint or get_model_path("clvp2.pth", models_dir)
self.clvp.load_state_dict(torch.load(clvp_path))
self.vocoder = vocoder.value.constructor().cpu()
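
With this refactor, custom checkpoint paths come from `TortoiseArgs` rather than constructor kwargs; when a path is None, `get_model_path()` falls back to the bundled model. A sketch (the checkpoint path is hypothetical):

config = TortoiseConfig()
config.model_args.ar_checkpoint = "/checkpoints/autoregressive_finetuned.pth"  # hypothetical path
model = Tortoise(config)  # load_state_dict() then loads the override instead of the default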
@@ -317,12 +377,12 @@ class TextToSpeech:
self.rlg_auto = None
self.rlg_diffusion = None
-if high_vram:
+if self.args.high_vram:
self.autoregressive = self.autoregressive.to(self.device)
self.diffusion = self.diffusion.to(self.device)
self.clvp = self.clvp.to(self.device)
self.vocoder = self.vocoder.to(self.device)
-self.high_vram = high_vram
+self.high_vram = self.args.high_vram
@contextmanager
def temporary_cuda(self, model):
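
The body of `temporary_cuda` is elided by the diff. A minimal sketch of the usual low-VRAM pattern it implements (models live on the CPU and visit the GPU one at a time unless `high_vram` keeps them resident); this is an illustration, not necessarily the committed implementation:

from contextlib import contextmanager
import torch

class Example:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @contextmanager
    def temporary_cuda(self, model):
        # Move the model onto the GPU only while the block runs...
        model = model.to(self.device)
        try:
            yield model
        finally:
            # ...then send it back so the next model can use the VRAM.
            model = model.cpu()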
@@ -355,7 +415,7 @@ class TextToSpeech:
1,
2,
], "latent_averaging mode has to be one of (0, 1, 2)"
print("mode", latent_averaging_mode)
with torch.no_grad():
voice_samples = [[v.to(self.device) for v in ls] for ls in voice_samples]
@@ -368,7 +428,7 @@ class TextToSpeech:
diffusion_conds = []
-DURS_CONST = 102400
+DURS_CONST = self.args.duration_const
for ls in voice_samples:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1]
@@ -428,67 +488,51 @@ class TextToSpeech:
with torch.no_grad():
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
-def tts_with_preset(self, text, preset="fast", **kwargs):
+def synthesis(
+self,
+model,
+text,
+config,
+preset,
+speaker_id="lj",
+):
+voice_samples, conditioning_latents = load_voice(speaker_id)
+outputs = model.inference_with_config(
+text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents
+)
+return_dict = {
+"wav": outputs["wav"],
+"deterministic_seed": outputs["deterministic_seed"],
+"text_inputs": outputs["text"],
+"voice_samples": outputs["voice_samples"],
+"conditioning_latents": outputs["conditioning_latents"],
+}
+return return_dict
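
A sketch of driving the new `synthesis()` helper; note that it currently accepts a `preset` argument but does not forward it, since `inference_with_config` below draws every setting from the config:

config = TortoiseConfig()
model = Tortoise(config)
outputs = model.synthesis(model, "Hello world.", config, preset="fast", speaker_id="lj")
wav = outputs["wav"]  # waveform tensor; the diffusion decoder runs at 24 kHz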
+def inference_with_config(self, text, config, **kwargs):
"""
-Calls TTS with one of a set of preset generation parameters. Options:
-'single_sample': Produces speech even faster, but only produces 1 sample.
-'ultra_fast': Produces speech much faster than the original tortoise repo.
-'ultra_fast_old': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest.)
-'fast': Decent quality speech at a decent inference rate. A good choice for mass inference.
-'standard': Very good quality. This is generally about as good as you are going to get.
-'high_quality': Use if you want the absolute best. This is not really worth the compute, though.
+Run inference using the generation settings stored in a TortoiseConfig.
+#TODO describe in detail
"""
# Use generally found best tuning knobs for generation.
settings = {
"temperature": 0.2,
"length_penalty": 1.0,
"repetition_penalty": 2.0,
"top_p": 0.8,
"cond_free_k": 2.0,
"diffusion_temperature": 1.0,
"temperature": config.temperature,
"length_penalty": config.length_penalty,
"repetition_penalty": config.repetition_penalty,
"top_p": config.top_p,
"cond_free_k": config.cond_free_k,
"diffusion_temperature": config.diffusion_temperature,
"num_autoregressive_samples": config.num_autoregressive_samples,
"diffusion_iterations": config.diffusion_iterations,
"sampler": config.sampler,
}
-# Presets are defined here.
-presets = {
-"single_sample": {
-"num_autoregressive_samples": 8,
-"diffusion_iterations": 10,
-"sampler": "ddim",
-},
-"ultra_fast": {
-"num_autoregressive_samples": 16,
-"diffusion_iterations": 10,
-"sampler": "ddim",
-},
-"ultra_fast_old": {
-"num_autoregressive_samples": 16,
-"diffusion_iterations": 30,
-"cond_free": False,
-},
-"very_fast": {
-"num_autoregressive_samples": 32,
-"diffusion_iterations": 30,
-"sampler": "dpm++2m",
-},
-"fast": {
-"num_autoregressive_samples": 16,
-"diffusion_iterations": 50,
-"sampler": "ddim",
-},
-"fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
-"standard": {
-"num_autoregressive_samples": 256,
-"diffusion_iterations": 200,
-},
-"high_quality": {
-"num_autoregressive_samples": 256,
-"diffusion_iterations": 400,
-},
-}
-settings.update(presets[preset])
settings.update(kwargs) # allow overriding of preset settings with kwargs
-return self.tts(text, **settings)
+return self.inference(text, **settings)
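
Per-call overrides still win over the config values via `**kwargs`, e.g. (continuing the sketch above; `voice_samples` as returned by `load_voice()`):

outputs = model.inference_with_config(
    "A quick test.",
    config,
    voice_samples=voice_samples,
    diffusion_iterations=100,  # overrides config.diffusion_iterations for this call
)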
-def tts(
+def inference(
self,
text,
voice_samples=None,
@@ -709,11 +753,34 @@ class TextToSpeech:
else:
res = wav_candidates[0]
-if return_deterministic_state:
-return res, (
-deterministic_seed,
-text,
-voice_samples,
-conditioning_latents,
-)
-return res
+return_dict = {
+"wav": res,
+"deterministic_seed": None,
+"text": None,
+"voice_samples": None,
+"conditioning_latents": None,
+}
+if return_deterministic_state:
+return_dict = {
+"wav": res,
+"deterministic_seed": deterministic_seed,
+"text": text,
+"voice_samples": voice_samples,
+"conditioning_latents": conditioning_latents,
+}
+return return_dict
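
The tuple return is replaced by a dictionary, so callers see one shape regardless of `return_deterministic_state`. A sketch of the new contract:

result = model.inference("Hello.", voice_samples=voice_samples, return_deterministic_state=True)
wav = result["wav"]
seed = result["deterministic_seed"]  # None unless return_deterministic_state=True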
+def forward(self):
+raise NotImplementedError("Tortoise Training is not implemented")
+def eval_step(self):
+raise NotImplementedError("Tortoise Training is not implemented")
+def init_from_config(self):
+raise NotImplementedError("Tortoise Training is not implemented")
+def load_checkpoint(self):
+raise NotImplementedError("Tortoise Training is not implemented")
+def train_step(self):
+raise NotImplementedError("Tortoise Training is not implemented")