mirror of https://github.com/coqui-ai/TTS.git
Drop use_ne_hifigan
This commit is contained in:
parent
9d54bd7655
commit
9bbf6eb8dd
|
@ -190,7 +190,6 @@ class XttsArgs(Coqpit):
|
|||
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
|
||||
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
|
||||
use_hifigan (bool, optional): Whether to use hifigan with implicit enhancement or diffusion + univnet as a decoder. Defaults to True.
|
||||
use_ne_hifigan (bool, optional): Whether to use regular hifigan or diffusion + univnet as a decoder. Defaults to False.
|
||||
|
||||
For GPT model:
|
||||
gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
|
||||
|
@ -229,7 +228,6 @@ class XttsArgs(Coqpit):
|
|||
decoder_checkpoint: str = None
|
||||
num_chars: int = 255
|
||||
use_hifigan: bool = True
|
||||
use_ne_hifigan: bool = False
|
||||
|
||||
# XTTS GPT Encoder params
|
||||
tokenizer_file: str = ""
|
||||
|
@ -337,18 +335,7 @@ class Xtts(BaseTTS):
|
|||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
||||
)
|
||||
|
||||
if self.args.use_ne_hifigan:
|
||||
self.ne_hifigan_decoder = HifiDecoder(
|
||||
input_sample_rate=self.args.input_sample_rate,
|
||||
output_sample_rate=self.args.output_sample_rate,
|
||||
output_hop_length=self.args.output_hop_length,
|
||||
ar_mel_length_compression=self.args.gpt_code_stride_len,
|
||||
decoder_input_dim=self.args.decoder_input_dim,
|
||||
d_vector_dim=self.args.d_vector_dim,
|
||||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
||||
)
|
||||
|
||||
if not (self.args.use_hifigan or self.args.use_ne_hifigan):
|
||||
if not self.args.use_hifigan:
|
||||
self.diffusion_decoder = DiffusionTts(
|
||||
model_channels=self.args.diff_model_channels,
|
||||
num_layers=self.args.diff_num_layers,
|
||||
|
@ -454,7 +441,7 @@ class Xtts(BaseTTS):
|
|||
if librosa_trim_db is not None:
|
||||
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
|
||||
|
||||
if self.args.use_hifigan or self.args.use_ne_hifigan:
|
||||
if self.args.use_hifigan:
|
||||
speaker_embedding = self.get_speaker_embedding(audio, sr)
|
||||
else:
|
||||
diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
|
||||
|
@ -706,19 +693,14 @@ class Xtts(BaseTTS):
|
|||
break
|
||||
|
||||
if decoder == "hifigan":
|
||||
assert hasattr(
|
||||
self, "hifigan_decoder"
|
||||
assert (
|
||||
hasattr(self, "hifigan_decoder") and self.hifigan_decoder is not None
|
||||
), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
||||
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
|
||||
elif decoder == "ne_hifigan":
|
||||
assert hasattr(
|
||||
self, "ne_hifigan_decoder"
|
||||
), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
||||
wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
|
||||
else:
|
||||
assert hasattr(
|
||||
self, "diffusion_decoder"
|
||||
), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
|
||||
), "You must disable the hifigan decoder to use diffusion by setting `use_hifigan: false`"
|
||||
mel = do_spectrogram_diffusion(
|
||||
self.diffusion_decoder,
|
||||
diffuser,
|
||||
|
@ -816,11 +798,6 @@ class Xtts(BaseTTS):
|
|||
self, "hifigan_decoder"
|
||||
), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
||||
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||
elif decoder == "ne_hifigan":
|
||||
assert hasattr(
|
||||
self, "ne_hifigan_decoder"
|
||||
), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
||||
wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||
else:
|
||||
raise NotImplementedError("Diffusion for streaming inference not implemented.")
|
||||
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
||||
|
@ -850,9 +827,8 @@ class Xtts(BaseTTS):
|
|||
|
||||
def get_compatible_checkpoint_state_dict(self, model_path):
|
||||
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
|
||||
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
|
||||
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else []
|
||||
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
|
||||
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
|
||||
# remove xtts gpt trainer extra keys
|
||||
ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
|
||||
for key in list(checkpoint.keys()):
|
||||
|
@ -915,8 +891,6 @@ class Xtts(BaseTTS):
|
|||
if eval:
|
||||
if hasattr(self, "hifigan_decoder"):
|
||||
self.hifigan_decoder.eval()
|
||||
if hasattr(self, "ne_hifigan_decoder"):
|
||||
self.ne_hifigan_decoder.eval()
|
||||
if hasattr(self, "diffusion_decoder"):
|
||||
self.diffusion_decoder.eval()
|
||||
if hasattr(self, "vocoder"):
|
||||
|
|
Loading…
Reference in New Issue