From 97b29a280e14914eebd21a4d9b65948ace10c760 Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Mon, 6 Nov 2023 19:19:05 +0100
Subject: [PATCH] Drop diffusion deps in code

---
 TTS/tts/layers/xtts/vocoder.py               | 385 ------------------
 TTS/tts/models/xtts.py                       |  86 +----
 recipes/ljspeech/xtts_v1/train_gpt_xtts.py   |   1 -
 recipes/ljspeech/xtts_v2/train_gpt_xtts.py   |   1 -
 tests/xtts_tests/test_xtts_gpt_train.py      |   1 -
 tests/xtts_tests/test_xtts_v2-0_gpt_train.py |   1 -
 6 files changed, 16 insertions(+), 459 deletions(-)
 delete mode 100644 TTS/tts/layers/xtts/vocoder.py

diff --git a/TTS/tts/layers/xtts/vocoder.py b/TTS/tts/layers/xtts/vocoder.py
deleted file mode 100644
index 0f4991b8..00000000
--- a/TTS/tts/layers/xtts/vocoder.py
+++ /dev/null
@@ -1,385 +0,0 @@
-import json
-from dataclasses import dataclass
-from enum import Enum
-from typing import Callable, Optional
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-MAX_WAV_VALUE = 32768.0
-
-
-class KernelPredictor(torch.nn.Module):
-    """Kernel predictor for the location-variable convolutions"""
-
-    def __init__(
-        self,
-        cond_channels,
-        conv_in_channels,
-        conv_out_channels,
-        conv_layers,
-        conv_kernel_size=3,
-        kpnet_hidden_channels=64,
-        kpnet_conv_size=3,
-        kpnet_dropout=0.0,
-        kpnet_nonlinear_activation="LeakyReLU",
-        kpnet_nonlinear_activation_params={"negative_slope": 0.1},
-    ):
-        """
-        Args:
-            cond_channels (int): number of channel for the conditioning sequence,
-            conv_in_channels (int): number of channel for the input sequence,
-            conv_out_channels (int): number of channel for the output sequence,
-            conv_layers (int): number of layers
-        """
-        super().__init__()
-
-        self.conv_in_channels = conv_in_channels
-        self.conv_out_channels = conv_out_channels
-        self.conv_kernel_size = conv_kernel_size
-        self.conv_layers = conv_layers
-
-        kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers  # l_w
-        kpnet_bias_channels = conv_out_channels * conv_layers  # l_b
-
-        self.input_conv = nn.Sequential(
-            nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
-            getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-        )
-
-        self.residual_convs = nn.ModuleList()
-        padding = (kpnet_conv_size - 1) // 2
-        for _ in range(3):
-            self.residual_convs.append(
-                nn.Sequential(
-                    nn.Dropout(kpnet_dropout),
-                    nn.utils.weight_norm(
-                        nn.Conv1d(
-                            kpnet_hidden_channels,
-                            kpnet_hidden_channels,
-                            kpnet_conv_size,
-                            padding=padding,
-                            bias=True,
-                        )
-                    ),
-                    getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-                    nn.utils.weight_norm(
-                        nn.Conv1d(
-                            kpnet_hidden_channels,
-                            kpnet_hidden_channels,
-                            kpnet_conv_size,
-                            padding=padding,
-                            bias=True,
-                        )
-                    ),
-                    getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-                )
-            )
-        self.kernel_conv = nn.utils.weight_norm(
-            nn.Conv1d(
-                kpnet_hidden_channels,
-                kpnet_kernel_channels,
-                kpnet_conv_size,
-                padding=padding,
-                bias=True,
-            )
-        )
-        self.bias_conv = nn.utils.weight_norm(
-            nn.Conv1d(
-                kpnet_hidden_channels,
-                kpnet_bias_channels,
-                kpnet_conv_size,
-                padding=padding,
-                bias=True,
-            )
-        )
-
-    def forward(self, c):
-        """
-        Args:
-            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
-        """
-        batch, _, cond_length = c.shape
-        c = self.input_conv(c)
-        for residual_conv in self.residual_convs:
-            residual_conv.to(c.device)
-            c = c + residual_conv(c)
-        k = self.kernel_conv(c)
-        b = self.bias_conv(c)
-        kernels = k.contiguous().view(
-            batch,
-            self.conv_layers,
-            self.conv_in_channels,
-            self.conv_out_channels,
-            self.conv_kernel_size,
-            cond_length,
-        )
-        bias = b.contiguous().view(
-            batch,
-            self.conv_layers,
-            self.conv_out_channels,
-            cond_length,
-        )
-
-        return kernels, bias
-
-    def remove_weight_norm(self):
-        nn.utils.remove_weight_norm(self.input_conv[0])
-        nn.utils.remove_weight_norm(self.kernel_conv)
-        nn.utils.remove_weight_norm(self.bias_conv)
-        for block in self.residual_convs:
-            nn.utils.remove_weight_norm(block[1])
-            nn.utils.remove_weight_norm(block[3])
-
-
-class LVCBlock(torch.nn.Module):
-    """the location-variable convolutions"""
-
-    def __init__(
-        self,
-        in_channels,
-        cond_channels,
-        stride,
-        dilations=[1, 3, 9, 27],
-        lReLU_slope=0.2,
-        conv_kernel_size=3,
-        cond_hop_length=256,
-        kpnet_hidden_channels=64,
-        kpnet_conv_size=3,
-        kpnet_dropout=0.0,
-    ):
-        super().__init__()
-
-        self.cond_hop_length = cond_hop_length
-        self.conv_layers = len(dilations)
-        self.conv_kernel_size = conv_kernel_size
-
-        self.kernel_predictor = KernelPredictor(
-            cond_channels=cond_channels,
-            conv_in_channels=in_channels,
-            conv_out_channels=2 * in_channels,
-            conv_layers=len(dilations),
-            conv_kernel_size=conv_kernel_size,
-            kpnet_hidden_channels=kpnet_hidden_channels,
-            kpnet_conv_size=kpnet_conv_size,
-            kpnet_dropout=kpnet_dropout,
-            kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
-        )
-
-        self.convt_pre = nn.Sequential(
-            nn.LeakyReLU(lReLU_slope),
-            nn.utils.weight_norm(
-                nn.ConvTranspose1d(
-                    in_channels,
-                    in_channels,
-                    2 * stride,
-                    stride=stride,
-                    padding=stride // 2 + stride % 2,
-                    output_padding=stride % 2,
-                )
-            ),
-        )
-
-        self.conv_blocks = nn.ModuleList()
-        for dilation in dilations:
-            self.conv_blocks.append(
-                nn.Sequential(
-                    nn.LeakyReLU(lReLU_slope),
-                    nn.utils.weight_norm(
-                        nn.Conv1d(
-                            in_channels,
-                            in_channels,
-                            conv_kernel_size,
-                            padding=dilation * (conv_kernel_size - 1) // 2,
-                            dilation=dilation,
-                        )
-                    ),
-                    nn.LeakyReLU(lReLU_slope),
-                )
-            )
-
-    def forward(self, x, c):
-        """forward propagation of the location-variable convolutions.
-        Args:
-            x (Tensor): the input sequence (batch, in_channels, in_length)
-            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
-
-        Returns:
-            Tensor: the output sequence (batch, in_channels, in_length)
-        """
-        _, in_channels, _ = x.shape  # (B, c_g, L')
-
-        x = self.convt_pre(x)  # (B, c_g, stride * L')
-        kernels, bias = self.kernel_predictor(c)
-
-        for i, conv in enumerate(self.conv_blocks):
-            output = conv(x)  # (B, c_g, stride * L')
-
-            k = kernels[:, i, :, :, :, :]  # (B, 2 * c_g, c_g, kernel_size, cond_length)
-            b = bias[:, i, :, :]  # (B, 2 * c_g, cond_length)
-
-            output = self.location_variable_convolution(
-                output, k, b, hop_size=self.cond_hop_length
-            )  # (B, 2 * c_g, stride * L'): LVC
-            x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
-                output[:, in_channels:, :]
-            )  # (B, c_g, stride * L'): GAU
-
-        return x
-
-    def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
-        """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
-        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
-        Args:
-            x (Tensor): the input sequence (batch, in_channels, in_length).
-            kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
-            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
-            dilation (int): the dilation of convolution.
-            hop_size (int): the hop_size of the conditioning sequence.
-        Returns:
-            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
-        """
-        batch, _, in_length = x.shape
-        batch, _, out_channels, kernel_size, kernel_length = kernel.shape
-        assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
-
-        padding = dilation * int((kernel_size - 1) / 2)
-        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2*padding)
-        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)
-
-        if hop_size < dilation:
-            x = F.pad(x, (0, dilation), "constant", 0)
-        x = x.unfold(
-            3, dilation, dilation
-        )  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
-        x = x[:, :, :, :, :hop_size]
-        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
-        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)
-
-        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
-        o = o.to(memory_format=torch.channels_last_3d)
-        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
-        o = o + bias
-        o = o.contiguous().view(batch, out_channels, -1)
-
-        return o
-
-    def remove_weight_norm(self):
-        self.kernel_predictor.remove_weight_norm()
-        nn.utils.remove_weight_norm(self.convt_pre[1])
-        for block in self.conv_blocks:
-            nn.utils.remove_weight_norm(block[1])
-
-
-class UnivNetGenerator(nn.Module):
-    """
-    UnivNet Generator
-
-    Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py.
-    """
-
-    def __init__(
-        self,
-        noise_dim=64,
-        channel_size=32,
-        dilations=[1, 3, 9, 27],
-        strides=[8, 8, 4],
-        lReLU_slope=0.2,
-        kpnet_conv_size=3,
-        # Below are MEL configurations options that this generator requires.
-        hop_length=256,
-        n_mel_channels=100,
-    ):
-        super(UnivNetGenerator, self).__init__()
-        self.mel_channel = n_mel_channels
-        self.noise_dim = noise_dim
-        self.hop_length = hop_length
-        channel_size = channel_size
-        kpnet_conv_size = kpnet_conv_size
-
-        self.res_stack = nn.ModuleList()
-        hop_length = 1
-        for stride in strides:
-            hop_length = stride * hop_length
-            self.res_stack.append(
-                LVCBlock(
-                    channel_size,
-                    n_mel_channels,
-                    stride=stride,
-                    dilations=dilations,
-                    lReLU_slope=lReLU_slope,
-                    cond_hop_length=hop_length,
-                    kpnet_conv_size=kpnet_conv_size,
-                )
-            )
-
-        self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
-
-        self.conv_post = nn.Sequential(
-            nn.LeakyReLU(lReLU_slope),
-            nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
-            nn.Tanh(),
-        )
-
-    def forward(self, c, z):
-        """
-        Args:
-            c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
-            z (Tensor): the noise sequence (batch, noise_dim, in_length)
-
-        """
-        z = self.conv_pre(z)  # (B, c_g, L)
-
-        for res_block in self.res_stack:
-            res_block.to(z.device)
-            z = res_block(z, c)  # (B, c_g, L * s_0 * ... * s_i)
-
-        z = self.conv_post(z)  # (B, 1, L * 256)
-
-        return z
-
-    def eval(self, inference=False):
-        super(UnivNetGenerator, self).eval()
-        # don't remove weight norm while validation in training loop
-        if inference:
-            self.remove_weight_norm()
-
-    def remove_weight_norm(self):
-        nn.utils.remove_weight_norm(self.conv_pre)
-
-        for layer in self.conv_post:
-            if len(layer.state_dict()) != 0:
-                nn.utils.remove_weight_norm(layer)
-
-        for res_block in self.res_stack:
-            res_block.remove_weight_norm()
-
-    def inference(self, c, z=None):
-        # pad input mel with zeros to cut artifact
-        # see https://github.com/seungwonpark/melgan/issues/8
-        zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
-        mel = torch.cat((c, zero), dim=2)
-
-        if z is None:
-            z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
-
-        audio = self.forward(mel, z)
-        audio = audio[:, :, : -(self.hop_length * 10)]
-        audio = audio.clamp(min=-1, max=1)
-        return audio
-
-
-if __name__ == "__main__":
-    model = UnivNetGenerator()
-
-    c = torch.randn(3, 100, 10)
-    z = torch.randn(3, 64, 10)
-    print(c.shape)
-
-    y = model(c, z)
-    print(y.shape)
-    assert y.shape == torch.Size([3, 1, 2560])
-
-    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    print(pytorch_total_params)
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 6b8a73e8..af94675b 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -13,7 +13,6 @@
 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
-from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec
 
@@ -185,7 +184,6 @@ class XttsArgs(Coqpit):
         clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
         decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
         num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
-        use_hifigan (bool, optional): Whether to use hifigan with implicit enhancement or diffusion + univnet as a decoder. Defaults to True.
 
     For GPT model:
         gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@@ -223,7 +221,6 @@ class XttsArgs(Coqpit):
     clvp_checkpoint: str = None
     decoder_checkpoint: str = None
     num_chars: int = 255
-    use_hifigan: bool = True
 
     # XTTS GPT Encoder params
     tokenizer_file: str = ""
@@ -320,32 +317,15 @@ class Xtts(BaseTTS):
             code_stride_len=self.args.gpt_code_stride_len,
         )
 
-        if self.args.use_hifigan:
-            self.hifigan_decoder = HifiDecoder(
-                input_sample_rate=self.args.input_sample_rate,
-                output_sample_rate=self.args.output_sample_rate,
-                output_hop_length=self.args.output_hop_length,
-                ar_mel_length_compression=self.args.gpt_code_stride_len,
-                decoder_input_dim=self.args.decoder_input_dim,
-                d_vector_dim=self.args.d_vector_dim,
-                cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
-            )
-
-        if not self.args.use_hifigan:
-            self.diffusion_decoder = DiffusionTts(
-                model_channels=self.args.diff_model_channels,
-                num_layers=self.args.diff_num_layers,
-                in_channels=self.args.diff_in_channels,
-                out_channels=self.args.diff_out_channels,
-                in_latent_channels=self.args.diff_in_latent_channels,
-                in_tokens=self.args.diff_in_tokens,
-                dropout=self.args.diff_dropout,
-                use_fp16=self.args.diff_use_fp16,
-                num_heads=self.args.diff_num_heads,
-                layer_drop=self.args.diff_layer_drop,
-                unconditioned_percentage=self.args.diff_unconditioned_percentage,
-            )
-            self.vocoder = UnivNetGenerator()
+        self.hifigan_decoder = HifiDecoder(
+            input_sample_rate=self.args.input_sample_rate,
+            output_sample_rate=self.args.output_sample_rate,
+            output_hop_length=self.args.output_hop_length,
+            ar_mel_length_compression=self.args.gpt_code_stride_len,
+            decoder_input_dim=self.args.decoder_input_dim,
+            d_vector_dim=self.args.d_vector_dim,
+            cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
+        )
 
     @property
     def device(self):
@@ -426,7 +406,6 @@ class Xtts(BaseTTS):
         sound_norm_refs=False,
     ):
         speaker_embedding = None
-        diffusion_cond_latents = None
         audio, sr = torchaudio.load(audio_path)
         audio = audio[:, : sr * max_ref_length].to(self.device)
 
@@ -437,12 +416,9 @@ class Xtts(BaseTTS):
         if librosa_trim_db is not None:
             audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
 
-        if self.args.use_hifigan or self.args.use_hifigan:
-            speaker_embedding = self.get_speaker_embedding(audio, sr)
-        else:
-            diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
+        speaker_embedding = self.get_speaker_embedding(audio, sr)
         gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len)  # [1, 1024, T]
-        return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
+        return gpt_cond_latents, speaker_embedding
 
     def synthesize(self, text, config, speaker_wav, language, **kwargs):
         """Synthesize speech with the given input text.
@@ -575,7 +551,7 @@ class Xtts(BaseTTS):
             Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
             Sample rate is 24kHz.
""" - (gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents( + (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents( audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len, max_ref_length=max_ref_len, @@ -587,7 +563,6 @@ class Xtts(BaseTTS): language, gpt_cond_latent, speaker_embedding, - diffusion_conditioning, temperature=temperature, length_penalty=length_penalty, repetition_penalty=repetition_penalty, @@ -610,7 +585,6 @@ class Xtts(BaseTTS): language, gpt_cond_latent, speaker_embedding, - diffusion_conditioning, # GPT inference temperature=0.65, length_penalty=1, @@ -639,14 +613,6 @@ class Xtts(BaseTTS): text_tokens.shape[-1] < self.args.gpt_max_text_tokens ), " ❗ XTTS can only generate text with a maximum of 400 tokens." - if not self.args.use_hifigan: - diffuser = load_discrete_vocoder_diffuser( - desired_diffusion_steps=decoder_iterations, - cond_free=cond_free, - cond_free_k=cond_free_k, - sampler=decoder_sampler, - ) - with torch.no_grad(): gpt_codes = self.gpt.generate( cond_latents=gpt_cond_latent, @@ -688,11 +654,7 @@ class Xtts(BaseTTS): gpt_latents = gpt_latents[:, :k] break - if decoder == "hifigan": - assert ( - hasattr(self, "hifigan_decoder") and self.hifigan_decoder is not None - ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" - wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) + wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) return { "wav": wav.cpu().numpy().squeeze(), @@ -735,9 +697,6 @@ class Xtts(BaseTTS): decoder="hifigan", **hf_generate_kwargs, ): - assert hasattr( - self, "hifigan_decoder" - ), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream." 
         text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
 
@@ -776,13 +735,7 @@
 
                 if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                     gpt_latents = torch.cat(all_latents, dim=0)[None, :]
-                    if decoder == "hifigan":
-                        assert hasattr(
-                            self, "hifigan_decoder"
-                        ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
-                        wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
-                    else:
-                        raise NotImplementedError("Diffusion for streaming inference not implemented.")
+                    wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                     wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
                         wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
                     )
@@ -810,10 +763,8 @@
 
     def get_compatible_checkpoint_state_dict(self, model_path):
         checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
-        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else []
-        ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
         # remove xtts gpt trainer extra keys
-        ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
+        ignore_keys = ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
         for key in list(checkpoint.keys()):
             # check if it is from the coqui Trainer if so convert it
             if key.startswith("xtts."):
@@ -872,12 +823,7 @@
         self.load_state_dict(checkpoint, strict=strict)
 
         if eval:
-            if hasattr(self, "hifigan_decoder"):
-                self.hifigan_decoder.eval()
-            if hasattr(self, "diffusion_decoder"):
-                self.diffusion_decoder.eval()
-            if hasattr(self, "vocoder"):
-                self.vocoder.eval()
+            self.hifigan_decoder.eval()
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
             self.gpt.eval()
 
diff --git a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
index 02e35dfd..268a0335 100644
--- a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
+++ b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
@@ -94,7 +94,6 @@ def main():
         gpt_num_audio_tokens=8194,
         gpt_start_audio_token=8192,
         gpt_stop_audio_token=8193,
-        use_ne_hifigan=True,  # if it is true it will keep the non-enhanced keys on the output checkpoint
     )
     # define audio config
     audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
diff --git a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py
index 4d06fed1..d94204ca 100644
--- a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py
+++ b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py
@@ -93,7 +93,6 @@ def main():
         gpt_num_audio_tokens=8194,
         gpt_start_audio_token=8192,
         gpt_stop_audio_token=8193,
-        use_ne_hifigan=True,  # if it is true it will keep the non-enhanced keys on the output checkpoint
         gpt_use_masking_gt_prompt_approach=True,
         gpt_use_perceiver_resampler=True,
     )
diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py
index 83cf537f..03514daa 100644
--- a/tests/xtts_tests/test_xtts_gpt_train.py
+++ b/tests/xtts_tests/test_xtts_gpt_train.py
@@ -86,7 +86,6 @@ model_args = GPTArgs(
     gpt_num_audio_tokens=8194,
     gpt_start_audio_token=8192,
     gpt_stop_audio_token=8193,
-    use_ne_hifigan=True,
 )
 audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
 config = GPTTrainerConfig(
diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py
index b9f6438e..80995038 100644
--- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py
+++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py
@@ -86,7 +86,6 @@ model_args = GPTArgs(
     gpt_stop_audio_token=8193,
     gpt_use_masking_gt_prompt_approach=True,
     gpt_use_perceiver_resampler=True,
-    use_ne_hifigan=True,
 )
 audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
 config = GPTTrainerConfig(
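
Usage note: after this patch, get_conditioning_latents() returns a two-tuple
instead of a three-tuple, and inference()/inference_stream() no longer accept
a diffusion_conditioning argument; HiFi-GAN is always the decoder. Below is a
minimal sketch of the updated call sites. The checkpoint, config, and
reference-audio paths are placeholders, not part of this patch.

    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

    # Load an XTTS checkpoint in eval mode (placeholder paths).
    config = XttsConfig()
    config.load_json("/path/to/xtts/config.json")
    model = Xtts.init_from_config(config)
    model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)

    # get_conditioning_latents() now returns (gpt_cond_latents, speaker_embedding);
    # the diffusion conditioning latents are gone.
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path="reference.wav"  # placeholder reference clip
    )

    # inference() likewise drops the diffusion_conditioning positional argument
    # and decodes the GPT latents with the HiFi-GAN decoder unconditionally.
    out = model.inference("Hello world.", "en", gpt_cond_latent, speaker_embedding)
    wav = out["wav"]  # NumPy float waveform at 24 kHz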