diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 3e413114..de38f7c8 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -190,7 +190,6 @@ class XttsArgs(Coqpit): decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. use_hifigan (bool, optional): Whether to use hifigan with implicit enhancement or diffusion + univnet as a decoder. Defaults to True. - use_ne_hifigan (bool, optional): Whether to use regular hifigan or diffusion + univnet as a decoder. Defaults to False. For GPT model: gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. @@ -229,7 +228,6 @@ class XttsArgs(Coqpit): decoder_checkpoint: str = None num_chars: int = 255 use_hifigan: bool = True - use_ne_hifigan: bool = False # XTTS GPT Encoder params tokenizer_file: str = "" @@ -337,18 +335,7 @@ class Xtts(BaseTTS): cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, ) - if self.args.use_ne_hifigan: - self.ne_hifigan_decoder = HifiDecoder( - input_sample_rate=self.args.input_sample_rate, - output_sample_rate=self.args.output_sample_rate, - output_hop_length=self.args.output_hop_length, - ar_mel_length_compression=self.args.gpt_code_stride_len, - decoder_input_dim=self.args.decoder_input_dim, - d_vector_dim=self.args.d_vector_dim, - cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, - ) - - if not (self.args.use_hifigan or self.args.use_ne_hifigan): + if not self.args.use_hifigan: self.diffusion_decoder = DiffusionTts( model_channels=self.args.diff_model_channels, num_layers=self.args.diff_num_layers, @@ -454,7 +441,7 @@ class Xtts(BaseTTS): if librosa_trim_db is not None: audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0] - if self.args.use_hifigan or self.args.use_ne_hifigan: + if 
self.args.use_hifigan: speaker_embedding = self.get_speaker_embedding(audio, sr) else: diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr) @@ -706,19 +693,14 @@ class Xtts(BaseTTS): break if decoder == "hifigan": - assert hasattr( - self, "hifigan_decoder" + assert ( + hasattr(self, "hifigan_decoder") and self.hifigan_decoder is not None ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) - elif decoder == "ne_hifigan": - assert hasattr( - self, "ne_hifigan_decoder" - ), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`" - wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding) else: assert hasattr( self, "diffusion_decoder" - ), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`" + ), "You must disable hifigan decoders to use difffusion by setting `use_hifigan: false`" mel = do_spectrogram_diffusion( self.diffusion_decoder, diffuser, @@ -816,11 +798,6 @@ class Xtts(BaseTTS): self, "hifigan_decoder" ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) - elif decoder == "ne_hifigan": - assert hasattr( - self, "ne_hifigan_decoder" - ), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`" - wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) else: raise NotImplementedError("Diffusion for streaming inference not implemented.") wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( @@ -850,9 +827,8 @@ class Xtts(BaseTTS): def get_compatible_checkpoint_state_dict(self, model_path): checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"] - ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else [] + 
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else [] ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"] - ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"] # remove xtts gpt trainer extra keys ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"] for key in list(checkpoint.keys()): @@ -915,8 +891,6 @@ class Xtts(BaseTTS): if eval: if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval() - if hasattr(self, "ne_hifigan_decoder"): - self.ne_hifigan_decoder.eval() if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval() if hasattr(self, "vocoder"):