Drop use_ne_hifigan

Eren Gölge 2023-11-06 18:43:38 +01:00
parent 9d54bd7655
commit 9bbf6eb8dd
1 changed file with 6 additions and 32 deletions

@@ -190,7 +190,6 @@ class XttsArgs(Coqpit):
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
use_hifigan (bool, optional): Whether to use hifigan with implicit enhancement or diffusion + univnet as a decoder. Defaults to True.
use_ne_hifigan (bool, optional): Whether to use regular hifigan or diffusion + univnet as a decoder. Defaults to False.
For GPT model:
gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@@ -229,7 +228,6 @@ class XttsArgs(Coqpit):
decoder_checkpoint: str = None
num_chars: int = 255
use_hifigan: bool = True
use_ne_hifigan: bool = False
# XTTS GPT Encoder params
tokenizer_file: str = ""
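To make the post-commit configuration surface concrete, here is a minimal standalone sketch (not the library's actual XttsArgs class) that mirrors only the decoder-related fields shown in the two hunks above; the class name is invented, every other field of the real class is omitted, and with use_ne_hifigan gone, use_hifigan is the single switch between the hifigan decoder and the diffusion + univnet path.

    # Standalone sketch: field names and defaults are copied from the hunks above.
    from dataclasses import dataclass

    @dataclass
    class DecoderArgsSketch:
        decoder_checkpoint: str = None  # checkpoint for the DiffTTS model
        num_chars: int = 255            # maximum number of characters to generate
        use_hifigan: bool = True        # True: hifigan decoder; False: diffusion + univnet
        tokenizer_file: str = ""

    args = DecoderArgsSketch()
    assert args.use_hifigan  # the only decoder switch left after this commit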
@@ -337,18 +335,7 @@ class Xtts(BaseTTS):
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
)
if self.args.use_ne_hifigan:
self.ne_hifigan_decoder = HifiDecoder(
input_sample_rate=self.args.input_sample_rate,
output_sample_rate=self.args.output_sample_rate,
output_hop_length=self.args.output_hop_length,
ar_mel_length_compression=self.args.gpt_code_stride_len,
decoder_input_dim=self.args.decoder_input_dim,
d_vector_dim=self.args.d_vector_dim,
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
)
if not (self.args.use_hifigan or self.args.use_ne_hifigan):
if not self.args.use_hifigan:
self.diffusion_decoder = DiffusionTts(
model_channels=self.args.diff_model_channels,
num_layers=self.args.diff_num_layers,
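The constructor branching collapses to a single two-way choice. A standalone sketch of that selection, using invented stub classes in place of the real HifiDecoder and DiffusionTts:

    # Sketch only: exactly one decoder is built, keyed off use_hifigan alone.
    class HifiDecoderStub: ...
    class DiffusionTtsStub: ...

    def build_decoder(use_hifigan: bool):
        if use_hifigan:
            return HifiDecoderStub()   # previously also gated a ne_hifigan_decoder
        return DiffusionTtsStub()      # previously: not (use_hifigan or use_ne_hifigan)

    print(type(build_decoder(True)).__name__)   # HifiDecoderStub
    print(type(build_decoder(False)).__name__)  # DiffusionTtsStub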
@@ -454,7 +441,7 @@ class Xtts(BaseTTS):
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
if self.args.use_hifigan or self.args.use_ne_hifigan:
if self.args.use_hifigan:
speaker_embedding = self.get_speaker_embedding(audio, sr)
else:
diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
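The conditioning-latent computation follows the same single switch. A toy sketch in which the two helper calls are replaced by string tags, since the real audio helpers are not reproduced here:

    # Sketch only: which kind of conditioning is computed for a reference clip.
    def conditioning_kind(use_hifigan: bool) -> str:
        if use_hifigan:                  # previously: use_hifigan or use_ne_hifigan
            return "speaker_embedding"   # stands in for get_speaker_embedding(audio, sr)
        return "diffusion_cond_latents"  # stands in for get_diffusion_cond_latents(audio, sr)

    print(conditioning_kind(True))   # speaker_embedding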
@@ -706,19 +693,14 @@ class Xtts(BaseTTS):
break
if decoder == "hifigan":
assert hasattr(
self, "hifigan_decoder"
assert (
hasattr(self, "hifigan_decoder") and self.hifigan_decoder is not None
), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
elif decoder == "ne_hifigan":
assert hasattr(
self, "ne_hifigan_decoder"
), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
else:
assert hasattr(
self, "diffusion_decoder"
), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
), "You must disable hifigan decoders to use difffusion by setting `use_hifigan: false`"
mel = do_spectrogram_diffusion(
self.diffusion_decoder,
diffuser,
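For full (non-streaming) inference the dispatch now only distinguishes the hifigan decoder from the diffusion fallback. A standalone sketch of that dispatch; pick_decoder_path and ToyModel are invented names, and only the assertion messages echo the diff:

    # Sketch only: mirrors the post-commit decoder selection, nothing else.
    def pick_decoder_path(model, decoder: str) -> str:
        if decoder == "hifigan":
            assert getattr(model, "hifigan_decoder", None) is not None, (
                "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
            )
            return "hifigan"
        assert hasattr(model, "diffusion_decoder"), (
            "You must disable the hifigan decoder to use diffusion by setting `use_hifigan: false`"
        )
        return "diffusion"

    class ToyModel:  # toy stand-in for an Xtts instance
        hifigan_decoder = object()

    print(pick_decoder_path(ToyModel(), "hifigan"))  # hifigan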
@@ -816,11 +798,6 @@ class Xtts(BaseTTS):
self, "hifigan_decoder"
), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
elif decoder == "ne_hifigan":
assert hasattr(
self, "ne_hifigan_decoder"
), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
else:
raise NotImplementedError("Diffusion for streaming inference not implemented.")
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
@@ -850,9 +827,8 @@ class Xtts(BaseTTS):
def get_compatible_checkpoint_state_dict(self, model_path):
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else []
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
# remove xtts gpt trainer extra keys
ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
for key in list(checkpoint.keys()):
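The checkpoint filtering can be read as follows. This standalone sketch approximates the key-matching loop with a simple prefix test; the exact matching rule of the real method is not reproduced here:

    # Sketch only: which checkpoint entries survive get_compatible_checkpoint_state_dict.
    def compatible_keys(checkpoint: dict, use_hifigan: bool) -> dict:
        ignore_keys = ["diffusion_decoder", "vocoder"] if use_hifigan else []
        ignore_keys += [] if use_hifigan else ["hifigan_decoder"]
        # xtts gpt trainer extra keys are always dropped
        ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
        return {k: v for k, v in checkpoint.items() if not any(k.startswith(p) for p in ignore_keys)}

    ckpt = {"hifigan_decoder.conv.weight": 1, "diffusion_decoder.layer.weight": 2, "gpt.h.0.weight": 3}
    print(sorted(compatible_keys(ckpt, use_hifigan=True)))  # ['gpt.h.0.weight', 'hifigan_decoder.conv.weight']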
@@ -915,8 +891,6 @@ class Xtts(BaseTTS):
if eval:
if hasattr(self, "hifigan_decoder"):
self.hifigan_decoder.eval()
if hasattr(self, "ne_hifigan_decoder"):
self.ne_hifigan_decoder.eval()
if hasattr(self, "diffusion_decoder"):
self.diffusion_decoder.eval()
if hasattr(self, "vocoder"):