Drop diffusion from XTTS (#3150)

* Drop diffusion for XTTS

* Make style

* Drop diffusion deps in code

* Restore thrashed
This commit is contained in:
Eren Gölge 2023-11-06 20:15:49 +01:00 committed by GitHub
parent 5d418bb84a
commit f0cb19ecca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 33 additions and 1810 deletions

View File

@ -1548,4 +1548,4 @@ def expand_dims(v, dims):
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
"""
return v[(...,) + (None,) * (dims - 1)]
return v[(...,) + (None,) * (dims - 1)]

File diff suppressed because it is too large Load Diff

View File

@ -1,385 +0,0 @@
import json
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
MAX_WAV_VALUE = 32768.0
class KernelPredictor(torch.nn.Module):
"""Kernel predictor for the location-variable convolutions"""
def __init__(
self,
cond_channels,
conv_in_channels,
conv_out_channels,
conv_layers,
conv_kernel_size=3,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
kpnet_nonlinear_activation="LeakyReLU",
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
):
"""
Args:
cond_channels (int): number of channel for the conditioning sequence,
conv_in_channels (int): number of channel for the input sequence,
conv_out_channels (int): number of channel for the output sequence,
conv_layers (int): number of layers
"""
super().__init__()
self.conv_in_channels = conv_in_channels
self.conv_out_channels = conv_out_channels
self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.residual_convs = nn.ModuleList()
padding = (kpnet_conv_size - 1) // 2
for _ in range(3):
self.residual_convs.append(
nn.Sequential(
nn.Dropout(kpnet_dropout),
nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
self.kernel_conv = nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_kernel_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
)
self.bias_conv = nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_bias_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
)
def forward(self, c):
"""
Args:
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
"""
batch, _, cond_length = c.shape
c = self.input_conv(c)
for residual_conv in self.residual_convs:
residual_conv.to(c.device)
c = c + residual_conv(c)
k = self.kernel_conv(c)
b = self.bias_conv(c)
kernels = k.contiguous().view(
batch,
self.conv_layers,
self.conv_in_channels,
self.conv_out_channels,
self.conv_kernel_size,
cond_length,
)
bias = b.contiguous().view(
batch,
self.conv_layers,
self.conv_out_channels,
cond_length,
)
return kernels, bias
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input_conv[0])
nn.utils.remove_weight_norm(self.kernel_conv)
nn.utils.remove_weight_norm(self.bias_conv)
for block in self.residual_convs:
nn.utils.remove_weight_norm(block[1])
nn.utils.remove_weight_norm(block[3])
class LVCBlock(torch.nn.Module):
"""the location-variable convolutions"""
def __init__(
self,
in_channels,
cond_channels,
stride,
dilations=[1, 3, 9, 27],
lReLU_slope=0.2,
conv_kernel_size=3,
cond_hop_length=256,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
):
super().__init__()
self.cond_hop_length = cond_hop_length
self.conv_layers = len(dilations)
self.conv_kernel_size = conv_kernel_size
self.kernel_predictor = KernelPredictor(
cond_channels=cond_channels,
conv_in_channels=in_channels,
conv_out_channels=2 * in_channels,
conv_layers=len(dilations),
conv_kernel_size=conv_kernel_size,
kpnet_hidden_channels=kpnet_hidden_channels,
kpnet_conv_size=kpnet_conv_size,
kpnet_dropout=kpnet_dropout,
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
)
self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(
nn.ConvTranspose1d(
in_channels,
in_channels,
2 * stride,
stride=stride,
padding=stride // 2 + stride % 2,
output_padding=stride % 2,
)
),
)
self.conv_blocks = nn.ModuleList()
for dilation in dilations:
self.conv_blocks.append(
nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(
nn.Conv1d(
in_channels,
in_channels,
conv_kernel_size,
padding=dilation * (conv_kernel_size - 1) // 2,
dilation=dilation,
)
),
nn.LeakyReLU(lReLU_slope),
)
)
def forward(self, x, c):
"""forward propagation of the location-variable convolutions.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length)
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
Returns:
Tensor: the output sequence (batch, in_channels, in_length)
"""
_, in_channels, _ = x.shape # (B, c_g, L')
x = self.convt_pre(x) # (B, c_g, stride * L')
kernels, bias = self.kernel_predictor(c)
for i, conv in enumerate(self.conv_blocks):
output = conv(x) # (B, c_g, stride * L')
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
output = self.location_variable_convolution(
output, k, b, hop_size=self.cond_hop_length
) # (B, 2 * c_g, stride * L'): LVC
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
output[:, in_channels:, :]
) # (B, c_g, stride * L'): GAU
return x
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length).
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
dilation (int): the dilation of convolution.
hop_size (int): the hop_size of the conditioning sequence.
Returns:
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
"""
batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation:
x = F.pad(x, (0, dilation), "constant", 0)
x = x.unfold(
3, dilation, dilation
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size]
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
o = o.to(memory_format=torch.channels_last_3d)
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
o = o + bias
o = o.contiguous().view(batch, out_channels, -1)
return o
def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm()
nn.utils.remove_weight_norm(self.convt_pre[1])
for block in self.conv_blocks:
nn.utils.remove_weight_norm(block[1])
class UnivNetGenerator(nn.Module):
"""
UnivNet Generator
Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py.
"""
def __init__(
self,
noise_dim=64,
channel_size=32,
dilations=[1, 3, 9, 27],
strides=[8, 8, 4],
lReLU_slope=0.2,
kpnet_conv_size=3,
# Below are MEL configurations options that this generator requires.
hop_length=256,
n_mel_channels=100,
):
super(UnivNetGenerator, self).__init__()
self.mel_channel = n_mel_channels
self.noise_dim = noise_dim
self.hop_length = hop_length
channel_size = channel_size
kpnet_conv_size = kpnet_conv_size
self.res_stack = nn.ModuleList()
hop_length = 1
for stride in strides:
hop_length = stride * hop_length
self.res_stack.append(
LVCBlock(
channel_size,
n_mel_channels,
stride=stride,
dilations=dilations,
lReLU_slope=lReLU_slope,
cond_hop_length=hop_length,
kpnet_conv_size=kpnet_conv_size,
)
)
self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
self.conv_post = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
nn.Tanh(),
)
def forward(self, c, z):
"""
Args:
c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
z (Tensor): the noise sequence (batch, noise_dim, in_length)
"""
z = self.conv_pre(z) # (B, c_g, L)
for res_block in self.res_stack:
res_block.to(z.device)
z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
z = self.conv_post(z) # (B, 1, L * 256)
return z
def eval(self, inference=False):
super(UnivNetGenerator, self).eval()
# don't remove weight norm while validation in training loop
if inference:
self.remove_weight_norm()
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv_pre)
for layer in self.conv_post:
if len(layer.state_dict()) != 0:
nn.utils.remove_weight_norm(layer)
for res_block in self.res_stack:
res_block.remove_weight_norm()
def inference(self, c, z=None):
# pad input mel with zeros to cut artifact
# see https://github.com/seungwonpark/melgan/issues/8
zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
mel = torch.cat((c, zero), dim=2)
if z is None:
z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
audio = self.forward(mel, z)
audio = audio[:, :, : -(self.hop_length * 10)]
audio = audio.clamp(min=-1, max=1)
return audio
if __name__ == "__main__":
model = UnivNetGenerator()
c = torch.randn(3, 100, 10)
z = torch.randn(3, 64, 10)
print(c.shape)
y = model(c, z)
print(y.shape)
assert y.shape == torch.Size([3, 1, 2560])
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)

View File

@ -252,7 +252,12 @@ class BaseTacotron(BaseTTS):
def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
"""Capacitron Variational Autoencoder"""
(VAE_outputs, posterior_distribution, prior_distribution, capacitron_beta,) = self.capacitron_vae_layer(
(
VAE_outputs,
posterior_distribution,
prior_distribution,
capacitron_beta,
) = self.capacitron_vae_layer(
reference_mel_info,
text_info,
speaker_embedding, # pylint: disable=not-callable

View File

@ -676,7 +676,12 @@ class Tortoise(BaseTTS):
), "Too much text provided. Break the text up into separate segments and re-try inference."
if voice_samples is not None:
(auto_conditioning, diffusion_conditioning, _, _,) = self.get_conditioning_latents(
(
auto_conditioning,
diffusion_conditioning,
_,
_,
) = self.get_conditioning_latents(
voice_samples,
return_mels=True,
latent_averaging_mode=latent_averaging_mode,

View File

@ -9,13 +9,10 @@ import torchaudio
from coqpit import Coqpit
from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec
@ -168,12 +165,10 @@ class XttsAudioConfig(Coqpit):
Args:
sample_rate (int): The sample rate in which the GPT operates.
diffusion_sample_rate (int): The sample rate of the diffusion audio waveform.
output_sample_rate (int): The sample rate of the output audio waveform.
"""
sample_rate: int = 22050
diffusion_sample_rate: int = 24000
output_sample_rate: int = 24000
@ -189,7 +184,6 @@ class XttsArgs(Coqpit):
clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
use_hifigan (bool, optional): Whether to use hifigan with implicit enhancement or diffusion + univnet as a decoder. Defaults to True.
For GPT model:
gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@ -227,7 +221,6 @@ class XttsArgs(Coqpit):
clvp_checkpoint: str = None
decoder_checkpoint: str = None
num_chars: int = 255
use_hifigan: bool = True
# XTTS GPT Encoder params
tokenizer_file: str = ""
@ -324,32 +317,15 @@ class Xtts(BaseTTS):
code_stride_len=self.args.gpt_code_stride_len,
)
if self.args.use_hifigan:
self.hifigan_decoder = HifiDecoder(
input_sample_rate=self.args.input_sample_rate,
output_sample_rate=self.args.output_sample_rate,
output_hop_length=self.args.output_hop_length,
ar_mel_length_compression=self.args.gpt_code_stride_len,
decoder_input_dim=self.args.decoder_input_dim,
d_vector_dim=self.args.d_vector_dim,
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
)
if not self.args.use_hifigan:
self.diffusion_decoder = DiffusionTts(
model_channels=self.args.diff_model_channels,
num_layers=self.args.diff_num_layers,
in_channels=self.args.diff_in_channels,
out_channels=self.args.diff_out_channels,
in_latent_channels=self.args.diff_in_latent_channels,
in_tokens=self.args.diff_in_tokens,
dropout=self.args.diff_dropout,
use_fp16=self.args.diff_use_fp16,
num_heads=self.args.diff_num_heads,
layer_drop=self.args.diff_layer_drop,
unconditioned_percentage=self.args.diff_unconditioned_percentage,
)
self.vocoder = UnivNetGenerator()
self.hifigan_decoder = HifiDecoder(
input_sample_rate=self.args.input_sample_rate,
output_sample_rate=self.args.output_sample_rate,
output_hop_length=self.args.output_hop_length,
ar_mel_length_compression=self.args.gpt_code_stride_len,
decoder_input_dim=self.args.decoder_input_dim,
d_vector_dim=self.args.d_vector_dim,
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
)
@property
def device(self):
@ -430,7 +406,6 @@ class Xtts(BaseTTS):
sound_norm_refs=False,
):
speaker_embedding = None
diffusion_cond_latents = None
audio, sr = torchaudio.load(audio_path)
audio = audio[:, : sr * max_ref_length].to(self.device)
@ -441,12 +416,9 @@ class Xtts(BaseTTS):
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
if self.args.use_hifigan or self.args.use_hifigan:
speaker_embedding = self.get_speaker_embedding(audio, sr)
else:
diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
speaker_embedding = self.get_speaker_embedding(audio, sr)
gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len) # [1, 1024, T]
return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
return gpt_cond_latents, speaker_embedding
def synthesize(self, text, config, speaker_wav, language, **kwargs):
"""Synthesize speech with the given input text.
@ -579,7 +551,7 @@ class Xtts(BaseTTS):
Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
Sample rate is 24kHz.
"""
(gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
(gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
audio_path=ref_audio_path,
gpt_cond_len=gpt_cond_len,
max_ref_length=max_ref_len,
@ -591,7 +563,6 @@ class Xtts(BaseTTS):
language,
gpt_cond_latent,
speaker_embedding,
diffusion_conditioning,
temperature=temperature,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
@ -614,7 +585,6 @@ class Xtts(BaseTTS):
language,
gpt_cond_latent,
speaker_embedding,
diffusion_conditioning,
# GPT inference
temperature=0.65,
length_penalty=1,
@ -643,14 +613,6 @@ class Xtts(BaseTTS):
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
if not self.args.use_hifigan:
diffuser = load_discrete_vocoder_diffuser(
desired_diffusion_steps=decoder_iterations,
cond_free=cond_free,
cond_free_k=cond_free_k,
sampler=decoder_sampler,
)
with torch.no_grad():
gpt_codes = self.gpt.generate(
cond_latents=gpt_cond_latent,
@ -692,29 +654,12 @@ class Xtts(BaseTTS):
gpt_latents = gpt_latents[:, :k]
break
if decoder == "hifigan":
assert (
hasattr(self, "hifigan_decoder") and self.hifigan_decoder is not None
), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
else:
assert hasattr(
self, "diffusion_decoder"
), "You must disable hifigan decoders to use difffusion by setting `use_hifigan: false`"
mel = do_spectrogram_diffusion(
self.diffusion_decoder,
diffuser,
gpt_latents,
diffusion_conditioning,
temperature=diffusion_temperature,
)
wav = self.vocoder.inference(mel)
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
return {
"wav": wav.cpu().numpy().squeeze(),
"gpt_latents": gpt_latents,
"speaker_embedding": speaker_embedding,
"diffusion_conditioning": diffusion_conditioning,
}
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
@ -752,9 +697,6 @@ class Xtts(BaseTTS):
decoder="hifigan",
**hf_generate_kwargs,
):
assert hasattr(
self, "hifigan_decoder"
), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
text = text.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
@ -793,13 +735,7 @@ class Xtts(BaseTTS):
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
if decoder == "hifigan":
assert hasattr(
self, "hifigan_decoder"
), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
else:
raise NotImplementedError("Diffusion for streaming inference not implemented.")
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
)
@ -827,10 +763,8 @@ class Xtts(BaseTTS):
def get_compatible_checkpoint_state_dict(self, model_path):
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else []
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
# remove xtts gpt trainer extra keys
ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
ignore_keys = ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
for key in list(checkpoint.keys()):
# check if it is from the coqui Trainer if so convert it
if key.startswith("xtts."):
@ -889,12 +823,7 @@ class Xtts(BaseTTS):
self.load_state_dict(checkpoint, strict=strict)
if eval:
if hasattr(self, "hifigan_decoder"):
self.hifigan_decoder.eval()
if hasattr(self, "diffusion_decoder"):
self.diffusion_decoder.eval()
if hasattr(self, "vocoder"):
self.vocoder.eval()
self.hifigan_decoder.eval()
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
self.gpt.eval()

View File

@ -94,12 +94,9 @@ def main():
gpt_num_audio_tokens=8194,
gpt_start_audio_token=8192,
gpt_stop_audio_token=8193,
use_ne_hifigan=True, # if it is true it will keep the non-enhanced keys on the output checkpoint
)
# define audio config
audio_config = XttsAudioConfig(
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
)
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
# training parameters config
config = GPTTrainerConfig(
output_path=OUT_PATH,

View File

@ -93,14 +93,11 @@ def main():
gpt_num_audio_tokens=8194,
gpt_start_audio_token=8192,
gpt_stop_audio_token=8193,
use_ne_hifigan=True, # if it is true it will keep the non-enhanced keys on the output checkpoint
gpt_use_masking_gt_prompt_approach=True,
gpt_use_perceiver_resampler=True,
)
# define audio config
audio_config = XttsAudioConfig(
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
)
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
# training parameters config
config = GPTTrainerConfig(
output_path=OUT_PATH,

View File

@ -86,11 +86,8 @@ model_args = GPTArgs(
gpt_num_audio_tokens=8194,
gpt_start_audio_token=8192,
gpt_stop_audio_token=8193,
use_ne_hifigan=True,
)
audio_config = XttsAudioConfig(
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
)
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
config = GPTTrainerConfig(
epochs=1,
output_path=OUT_PATH,

View File

@ -86,11 +86,8 @@ model_args = GPTArgs(
gpt_stop_audio_token=8193,
gpt_use_masking_gt_prompt_approach=True,
gpt_use_perceiver_resampler=True,
use_ne_hifigan=True,
)
audio_config = XttsAudioConfig(
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
)
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
config = GPTTrainerConfig(
epochs=1,
output_path=OUT_PATH,