mirror of https://github.com/coqui-ai/TTS.git
added tortoise config and updated config and args, refactoring the code
This commit is contained in:
parent e1838617d4
commit 1f77cc6cac
TTS/tts/configs/tortoise_config.py (new file)
@@ -0,0 +1,46 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from TTS.tts.configs.shared_configs import BaseTTSConfig
+from TTS.tts.models.tortoise import TortoiseArgs, TortoiseAudioConfig
+
+
+@dataclass
+class TortoiseConfig(BaseTTSConfig):
+    """Defines parameters for the Tortoise TTS model.
+
+    Args:
+        model (str):
+            Model name. Do not change unless you know what you are doing.
+
+        model_args (TortoiseArgs):
+            Model architecture arguments. Defaults to `TortoiseArgs()`.
+
+        audio (TortoiseAudioConfig):
+            Audio processing configuration. Defaults to `TortoiseAudioConfig()`.
+    Note:
+        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
+
+    Example:
+
+        >>> from TTS.tts.configs.tortoise_config import TortoiseConfig
+        >>> config = TortoiseConfig()
+    """
+
+    model: str = "tortoise"
+    # model specific params
+    model_args: TortoiseArgs = field(default_factory=TortoiseArgs)
+    audio: TortoiseAudioConfig = field(default_factory=TortoiseAudioConfig)
+
+    # settings
+    temperature: float = 0.2
+    length_penalty: float = 1.0
+    repetition_penalty: float = 2.0
+    top_p: float = 0.8
+    cond_free_k: float = 2.0
+    diffusion_temperature: float = 1.0
+
+    # inference params
+    num_autoregressive_samples: int = 16
+    diffusion_iterations: int = 30
+    sampler: str = "ddim"
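For orientation, a minimal usage sketch (not part of the commit), assuming the fields defined above and the `TortoiseArgs` dataclass added in `TTS/tts/models/tortoise.py` below:

# Sketch only: instantiating the new config and overriding a few knobs.
from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.models.tortoise import TortoiseArgs

config = TortoiseConfig(
    temperature=0.4,                # sampling temperature for the autoregressive model
    num_autoregressive_samples=32,  # more candidates for CLVP re-ranking, slower
    model_args=TortoiseArgs(kv_cache=True, high_vram=False),
)
print(config.sampler)  # -> "ddim"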
TTS/tts/models/tortoise.py
@@ -1,17 +1,23 @@
-# AGPL: a notification must be added stating that changes have been made to that file.
-
 import os
 import random
 from contextlib import contextmanager
+from dataclasses import dataclass, field
 from time import time
 
 import torch
 import torch.nn.functional as F
 import torchaudio
+from coqpit import Coqpit
 from tqdm import tqdm
 
 from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
-from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
+from TTS.tts.layers.tortoise.audio_utils import (
+    denormalize_tacotron_mel,
+    load_audio,
+    load_voice,
+    load_voices,
+    wav_to_univnet_mel,
+)
 from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
 from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
 from TTS.tts.layers.tortoise.clvp import CLVP
@@ -22,6 +28,7 @@ from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
 from TTS.tts.layers.tortoise.vocoder import VocConf
 from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
+from TTS.tts.models.base_tts import BaseTTS
 
 
 def pad_or_truncate(t, length):
@@ -191,21 +198,74 @@ def pick_best_batch_size_for_gpu():
     return batch_size
 
 
-class TextToSpeech:
+@dataclass
+class TortoiseAudioConfig(Coqpit):
+    sample_rate: int = 22050
+    diffusion_sample_rate: int = 24000
+
+
+@dataclass
+class TortoiseArgs(Coqpit):
+    autoregressive_batch_size: int = 1
+    enable_redaction: bool = True
+    high_vram: bool = False
+    kv_cache: bool = True
+    ar_checkpoint: str = None
+    clvp_checkpoint: str = None
+    diff_checkpoint: str = None
+    num_chars: int = 255
+
+    # UnifiedVoice params
+    ar_max_mel_tokens: int = 604
+    ar_max_text_tokens: int = 402
+    ar_max_conditioning_inputs: int = 2
+    ar_layers: int = 30
+    ar_model_dim: int = 1024
+    ar_heads: int = 16
+    ar_number_text_tokens: int = 255
+    ar_start_text_token: int = 255
+    ar_checkpointing: bool = False
+    ar_train_solo_embeddings: bool = False
+
+    # DiffTTS params
+    diff_model_channels: int = 1024
+    diff_num_layers: int = 10
+    diff_in_channels: int = 100
+    diff_out_channels: int = 200
+    diff_in_latent_channels: int = 1024
+    diff_in_tokens: int = 8193
+    diff_dropout: int = 0
+    diff_use_fp16: bool = False
+    diff_num_heads: int = 16
+    diff_layer_drop: int = 0
+    diff_unconditioned_percentage: int = 0
+
+    # clvp params
+    clvp_dim_text: int = 768
+    clvp_dim_speech: int = 768
+    clvp_dim_latent: int = 768
+    clvp_num_text_tokens: int = 256
+    clvp_text_enc_depth: int = 20
+    clvp_text_seq_len: int = 350
+    clvp_text_heads: int = 12
+    clvp_num_speech_tokens: int = 8192
+    clvp_speech_enc_depth: int = 20
+    clvp_speech_heads: int = 12
+    clvp_speech_seq_len: int = 430
+    clvp_use_xformers: bool = True
+    # constants
+    duration_const: int = 102400
+
+
+class Tortoise(BaseTTS):
     """
     Main entry point into Tortoise.
     """
 
     def __init__(
         self,
-        autoregressive_batch_size=None,
+        config: Coqpit,
         models_dir=MODELS_DIR,
-        enable_redaction=True,
-        high_vram=False,
-        kv_cache=True,
-        ar_checkpoint=None,
-        clvp_checkpoint=None,
-        diff_checkpoint=None,
         vocoder=VocConf.Univnet,
     ):
         """
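The constructor now takes a single Coqpit config instead of loose keyword arguments. A before/after sketch (not part of the commit; the old call reflects the removed `TextToSpeech` signature above):

# Old API, removed by this commit:
#   tts = TextToSpeech(high_vram=True, kv_cache=True)
# New API: everything rides in on one config object.
from TTS.tts.configs.tortoise_config import TortoiseConfig

config = TortoiseConfig()
config.model_args.high_vram = True  # was a constructor kwarg before
model = Tortoise(config)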
@@ -224,82 +284,82 @@ class TextToSpeech:
         :param clvp_checkpoint: Path to a checkpoint file for the CLVP model. If omitted, uses default
         :param diff_checkpoint: Path to a checkpoint file for the diffusion model. If omitted, uses default
         """
-        self.ar_checkpoint = ar_checkpoint
-        self.diff_checkpoint = diff_checkpoint  # TODO: check if this is even needed
+        super().__init__(config, ap=None, tokenizer=None)
+        self.config = config
+        self.ar_checkpoint = self.args.ar_checkpoint
+        self.diff_checkpoint = self.args.diff_checkpoint  # TODO: check if this is even needed
         self.models_dir = models_dir
         self.autoregressive_batch_size = (
-            pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
+            pick_best_batch_size_for_gpu()
+            if self.args.autoregressive_batch_size is None
+            else self.args.autoregressive_batch_size
         )
-        self.enable_redaction = enable_redaction
+        self.enable_redaction = self.args.enable_redaction
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         if self.enable_redaction:
             self.aligner = Wav2VecAlignment()
 
         self.tokenizer = VoiceBpeTokenizer()
 
-        if os.path.exists(f"{models_dir}/autoregressive.ptt"):
-            # Assume this is a traced directory.
-            self.autoregressive = torch.jit.load(f"{models_dir}/autoregressive.ptt")
-            self.diffusion = torch.jit.load(f"{models_dir}/diffusion_decoder.ptt")
-        else:
         self.autoregressive = (
             UnifiedVoice(
-                max_mel_tokens=604,
-                max_text_tokens=402,
-                max_conditioning_inputs=2,
-                layers=30,
-                model_dim=1024,
-                heads=16,
-                number_text_tokens=255,
-                start_text_token=255,
-                checkpointing=False,
-                train_solo_embeddings=False,
+                max_mel_tokens=self.args.ar_max_mel_tokens,
+                max_text_tokens=self.args.ar_max_text_tokens,
+                max_conditioning_inputs=self.args.ar_max_conditioning_inputs,
+                layers=self.args.ar_layers,
+                model_dim=self.args.ar_model_dim,
+                heads=self.args.ar_heads,
+                number_text_tokens=self.args.ar_number_text_tokens,
+                start_text_token=self.args.ar_start_text_token,
+                checkpointing=self.args.ar_checkpointing,
+                train_solo_embeddings=self.args.ar_train_solo_embeddings,
             )
             .cpu()
             .eval()
         )
-        ar_path = ar_checkpoint or get_model_path("autoregressive.pth", models_dir)
+        ar_path = self.args.ar_checkpoint or get_model_path("autoregressive.pth", models_dir)
         self.autoregressive.load_state_dict(torch.load(ar_path))
-        self.autoregressive.post_init_gpt2_config(kv_cache)
+        self.autoregressive.post_init_gpt2_config(self.args.kv_cache)
 
-        diff_path = diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir)
+        diff_path = self.args.diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir)
         self.diffusion = (
             DiffusionTts(
-                model_channels=1024,
-                num_layers=10,
-                in_channels=100,
-                out_channels=200,
-                in_latent_channels=1024,
-                in_tokens=8193,
-                dropout=0,
-                use_fp16=False,
-                num_heads=16,
-                layer_drop=0,
-                unconditioned_percentage=0,
+                model_channels=self.args.diff_model_channels,
+                num_layers=self.args.diff_num_layers,
+                in_channels=self.args.diff_in_channels,
+                out_channels=self.args.diff_out_channels,
+                in_latent_channels=self.args.diff_in_latent_channels,
+                in_tokens=self.args.diff_in_tokens,
+                dropout=self.args.diff_dropout,
+                use_fp16=self.args.diff_use_fp16,
+                num_heads=self.args.diff_num_heads,
+                layer_drop=self.args.diff_layer_drop,
+                unconditioned_percentage=self.args.diff_unconditioned_percentage,
            )
             .cpu()
             .eval()
         )
         self.diffusion.load_state_dict(torch.load(diff_path))
 
         self.clvp = (
             CLVP(
-                dim_text=768,
-                dim_speech=768,
-                dim_latent=768,
-                num_text_tokens=256,
-                text_enc_depth=20,
-                text_seq_len=350,
-                text_heads=12,
-                num_speech_tokens=8192,
-                speech_enc_depth=20,
-                speech_heads=12,
-                speech_seq_len=430,
-                use_xformers=True,
+                dim_text=self.args.clvp_dim_text,
+                dim_speech=self.args.clvp_dim_speech,
+                dim_latent=self.args.clvp_dim_latent,
+                num_text_tokens=self.args.clvp_num_text_tokens,
+                text_enc_depth=self.args.clvp_text_enc_depth,
+                text_seq_len=self.args.clvp_text_seq_len,
+                text_heads=self.args.clvp_text_heads,
+                num_speech_tokens=self.args.clvp_num_speech_tokens,
+                speech_enc_depth=self.args.clvp_speech_enc_depth,
+                speech_heads=self.args.clvp_speech_heads,
+                speech_seq_len=self.args.clvp_speech_seq_len,
+                use_xformers=self.args.clvp_use_xformers,
             )
             .cpu()
             .eval()
         )
-        clvp_path = clvp_checkpoint or get_model_path("clvp2.pth", models_dir)
+        clvp_path = self.args.clvp_checkpoint or get_model_path("clvp2.pth", models_dir)
         self.clvp.load_state_dict(torch.load(clvp_path))
 
         self.vocoder = vocoder.value.constructor().cpu()
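With the hard-coded hyperparameters lifted into `TortoiseArgs`, a field override now propagates into the submodule constructors shown above. A sketch under that assumption (not part of the commit):

# Sketch only: an architecture override flowing from TortoiseArgs to UnifiedVoice.
config = TortoiseConfig(model_args=TortoiseArgs(ar_layers=16, ar_heads=8))
model = Tortoise(config)  # internally builds UnifiedVoice(layers=16, heads=8, ...)
# Caveat: the bundled autoregressive.pth checkpoint was trained with the default
# sizes, so load_state_dict() will fail for a mismatched architecture.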
@@ -317,12 +377,12 @@ class TextToSpeech:
         self.rlg_auto = None
         self.rlg_diffusion = None
 
-        if high_vram:
+        if self.args.high_vram:
             self.autoregressive = self.autoregressive.to(self.device)
             self.diffusion = self.diffusion.to(self.device)
             self.clvp = self.clvp.to(self.device)
             self.vocoder = self.vocoder.to(self.device)
-        self.high_vram = high_vram
+        self.high_vram = self.args.high_vram
 
     @contextmanager
     def temporary_cuda(self, model):
@@ -355,7 +415,7 @@ class TextToSpeech:
             1,
             2,
         ], "latent_averaging mode has to be one of (0, 1, 2)"
-        print("mode", latent_averaging_mode)
+
         with torch.no_grad():
             voice_samples = [[v.to(self.device) for v in ls] for ls in voice_samples]
 
@@ -368,7 +428,7 @@ class TextToSpeech:
 
         diffusion_conds = []
 
-        DURS_CONST = 102400
+        DURS_CONST = self.args.duration_const
         for ls in voice_samples:
             # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
             sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1]
@@ -428,67 +488,51 @@
         with torch.no_grad():
             return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
 
-    def tts_with_preset(self, text, preset="fast", **kwargs):
+    def synthesis(
+        self,
+        model,
+        text,
+        config,
+        preset,
+        speaker_id="lj",
+    ):
+        voice_samples, conditioning_latents = load_voice(speaker_id)
+
+        outputs = model.inference_with_config(
+            text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents
+        )
+
+        return_dict = {
+            "wav": outputs["wav"],
+            "deterministic_seed": outputs["deterministic_seed"],
+            "text_inputs": outputs["text"],
+            "voice_samples": outputs["voice_samples"],
+            "conditioning_latents": outputs["conditioning_latents"],
+        }
+
+        return return_dict
+
+    def inference_with_config(self, text, config, **kwargs):
         """
-        Calls TTS with one of a set of preset generation parameters. Options:
-            'single_sample': Produces speech even faster, but only produces 1 sample.
-            'ultra_fast': Produces speech much faster than the original tortoise repo.
-            'ultra_fast_old': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest).
-            'fast': Decent quality speech at a decent inference rate. A good choice for mass inference.
-            'standard': Very good quality. This is generally about as good as you are going to get.
-            'high_quality': Use if you want the absolute best. This is not really worth the compute, though.
+        Run inference with the generation parameters taken from `config`.
+        # TODO: describe in detail
         """
         # Use generally found best tuning knobs for generation.
         settings = {
-            "temperature": 0.2,
-            "length_penalty": 1.0,
-            "repetition_penalty": 2.0,
-            "top_p": 0.8,
-            "cond_free_k": 2.0,
-            "diffusion_temperature": 1.0,
+            "temperature": config.temperature,
+            "length_penalty": config.length_penalty,
+            "repetition_penalty": config.repetition_penalty,
+            "top_p": config.top_p,
+            "cond_free_k": config.cond_free_k,
+            "diffusion_temperature": config.diffusion_temperature,
+            "num_autoregressive_samples": config.num_autoregressive_samples,
+            "diffusion_iterations": config.diffusion_iterations,
+            "sampler": config.sampler,
         }
-        # Presets are defined here.
-        presets = {
-            "single_sample": {
-                "num_autoregressive_samples": 8,
-                "diffusion_iterations": 10,
-                "sampler": "ddim",
-            },
-            "ultra_fast": {
-                "num_autoregressive_samples": 16,
-                "diffusion_iterations": 10,
-                "sampler": "ddim",
-            },
-            "ultra_fast_old": {
-                "num_autoregressive_samples": 16,
-                "diffusion_iterations": 30,
-                "cond_free": False,
-            },
-            "very_fast": {
-                "num_autoregressive_samples": 32,
-                "diffusion_iterations": 30,
-                "sampler": "dpm++2m",
-            },
-            "fast": {
-                "num_autoregressive_samples": 16,
-                "diffusion_iterations": 50,
-                "sampler": "ddim",
-            },
-            "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
-            "standard": {
-                "num_autoregressive_samples": 256,
-                "diffusion_iterations": 200,
-            },
-            "high_quality": {
-                "num_autoregressive_samples": 256,
-                "diffusion_iterations": 400,
-            },
-        }
-        settings.update(presets[preset])
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
-        return self.tts(text, **settings)
+        return self.inference(text, **settings)
 
-    def tts(
+    def inference(
         self,
         text,
         voice_samples=None,
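Since the presets table is gone, quality/speed trade-offs are now expressed through config fields, with kwargs as per-call overrides. A usage sketch (not part of the commit; `model` is assumed to be a constructed `Tortoise`):

# Sketch only: config-driven settings with a kwarg override.
config = TortoiseConfig(num_autoregressive_samples=16, diffusion_iterations=30)
outputs = model.inference_with_config(
    "Hello world.",
    config,
    diffusion_iterations=10,  # per-call override, analogous to the old "ultra_fast" preset
)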
@@ -709,11 +753,34 @@
         else:
             res = wav_candidates[0]
 
+        return_dict = {
+            "wav": res,
+            "deterministic_seed": None,
+            "text": None,
+            "voice_samples": None,
+            "conditioning_latents": None,
+        }
         if return_deterministic_state:
-            return res, (
-                deterministic_seed,
-                text,
-                voice_samples,
-                conditioning_latents,
-            )
-        return res
+            return_dict = {
+                "wav": res,
+                "deterministic_seed": deterministic_seed,
+                "text": text,
+                "voice_samples": voice_samples,
+                "conditioning_latents": conditioning_latents,
+            }
+        return return_dict
+
+    def forward(self):
+        raise NotImplementedError("Tortoise Training is not implemented")
+
+    def eval_step(self):
+        raise NotImplementedError("Tortoise Training is not implemented")
+
+    def init_from_config(self):
+        raise NotImplementedError("Tortoise Training is not implemented")
+
+    def load_checkpoint(self):
+        raise NotImplementedError("Tortoise Training is not implemented")
+
+    def train_step(self):
+        raise NotImplementedError("Tortoise Training is not implemented")
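Callers now receive a dict instead of a bare tensor (or tuple). A consumption sketch (not part of the commit; `model` and `voice_samples` assumed from the examples above):

# Sketch only: consuming the new dict-style return value of inference().
import torchaudio

outputs = model.inference("Hello world.", voice_samples=voice_samples)
wav = outputs["wav"]  # the old tts() returned this tensor (or a tuple) directly
# Assumed: Tortoise renders audio at 24 kHz; adjust the squeeze to the actual shape.
torchaudio.save("tortoise_out.wav", wav.squeeze(0).cpu(), 24000)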