mirror of https://github.com/coqui-ai/TTS.git
Drop use_ne_hifigan
This commit is contained in:
parent
9d54bd7655
commit
9bbf6eb8dd
|
@ -190,7 +190,6 @@ class XttsArgs(Coqpit):
|
||||||
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
|
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
|
||||||
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
|
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
|
||||||
use_hifigan (bool, optional): Whether to use hifigan with implicit enhancement or diffusion + univnet as a decoder. Defaults to True.
|
use_hifigan (bool, optional): Whether to use hifigan with implicit enhancement or diffusion + univnet as a decoder. Defaults to True.
|
||||||
use_ne_hifigan (bool, optional): Whether to use regular hifigan or diffusion + univnet as a decoder. Defaults to False.
|
|
||||||
|
|
||||||
For GPT model:
|
For GPT model:
|
||||||
gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
|
gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
|
||||||
|
@ -229,7 +228,6 @@ class XttsArgs(Coqpit):
|
||||||
decoder_checkpoint: str = None
|
decoder_checkpoint: str = None
|
||||||
num_chars: int = 255
|
num_chars: int = 255
|
||||||
use_hifigan: bool = True
|
use_hifigan: bool = True
|
||||||
use_ne_hifigan: bool = False
|
|
||||||
|
|
||||||
# XTTS GPT Encoder params
|
# XTTS GPT Encoder params
|
||||||
tokenizer_file: str = ""
|
tokenizer_file: str = ""
|
||||||
|
@ -337,18 +335,7 @@ class Xtts(BaseTTS):
|
||||||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.use_ne_hifigan:
|
if not self.args.use_hifigan:
|
||||||
self.ne_hifigan_decoder = HifiDecoder(
|
|
||||||
input_sample_rate=self.args.input_sample_rate,
|
|
||||||
output_sample_rate=self.args.output_sample_rate,
|
|
||||||
output_hop_length=self.args.output_hop_length,
|
|
||||||
ar_mel_length_compression=self.args.gpt_code_stride_len,
|
|
||||||
decoder_input_dim=self.args.decoder_input_dim,
|
|
||||||
d_vector_dim=self.args.d_vector_dim,
|
|
||||||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not (self.args.use_hifigan or self.args.use_ne_hifigan):
|
|
||||||
self.diffusion_decoder = DiffusionTts(
|
self.diffusion_decoder = DiffusionTts(
|
||||||
model_channels=self.args.diff_model_channels,
|
model_channels=self.args.diff_model_channels,
|
||||||
num_layers=self.args.diff_num_layers,
|
num_layers=self.args.diff_num_layers,
|
||||||
|
@ -454,7 +441,7 @@ class Xtts(BaseTTS):
|
||||||
if librosa_trim_db is not None:
|
if librosa_trim_db is not None:
|
||||||
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
|
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
|
||||||
|
|
||||||
if self.args.use_hifigan or self.args.use_ne_hifigan:
|
if self.args.use_hifigan:
|
||||||
speaker_embedding = self.get_speaker_embedding(audio, sr)
|
speaker_embedding = self.get_speaker_embedding(audio, sr)
|
||||||
else:
|
else:
|
||||||
diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
|
diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
|
||||||
|
@ -706,19 +693,14 @@ class Xtts(BaseTTS):
|
||||||
break
|
break
|
||||||
|
|
||||||
if decoder == "hifigan":
|
if decoder == "hifigan":
|
||||||
assert hasattr(
|
assert (
|
||||||
self, "hifigan_decoder"
|
hasattr(self, "hifigan_decoder") and self.hifigan_decoder is not None
|
||||||
), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
||||||
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
|
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
|
||||||
elif decoder == "ne_hifigan":
|
|
||||||
assert hasattr(
|
|
||||||
self, "ne_hifigan_decoder"
|
|
||||||
), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
|
||||||
wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
|
|
||||||
else:
|
else:
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
self, "diffusion_decoder"
|
self, "diffusion_decoder"
|
||||||
), "You must disable hifigan decoders to use diffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
|
), "You must disable hifigan decoders to use diffusion by setting `use_hifigan: false`"
|
||||||
mel = do_spectrogram_diffusion(
|
mel = do_spectrogram_diffusion(
|
||||||
self.diffusion_decoder,
|
self.diffusion_decoder,
|
||||||
diffuser,
|
diffuser,
|
||||||
|
@ -816,11 +798,6 @@ class Xtts(BaseTTS):
|
||||||
self, "hifigan_decoder"
|
self, "hifigan_decoder"
|
||||||
), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
||||||
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||||
elif decoder == "ne_hifigan":
|
|
||||||
assert hasattr(
|
|
||||||
self, "ne_hifigan_decoder"
|
|
||||||
), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
|
||||||
wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Diffusion for streaming inference not implemented.")
|
raise NotImplementedError("Diffusion for streaming inference not implemented.")
|
||||||
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
||||||
|
@ -850,9 +827,8 @@ class Xtts(BaseTTS):
|
||||||
|
|
||||||
def get_compatible_checkpoint_state_dict(self, model_path):
|
def get_compatible_checkpoint_state_dict(self, model_path):
|
||||||
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
|
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
|
||||||
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
|
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else []
|
||||||
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
|
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
|
||||||
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
|
|
||||||
# remove xtts gpt trainer extra keys
|
# remove xtts gpt trainer extra keys
|
||||||
ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
|
ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
|
||||||
for key in list(checkpoint.keys()):
|
for key in list(checkpoint.keys()):
|
||||||
|
@ -915,8 +891,6 @@ class Xtts(BaseTTS):
|
||||||
if eval:
|
if eval:
|
||||||
if hasattr(self, "hifigan_decoder"):
|
if hasattr(self, "hifigan_decoder"):
|
||||||
self.hifigan_decoder.eval()
|
self.hifigan_decoder.eval()
|
||||||
if hasattr(self, "ne_hifigan_decoder"):
|
|
||||||
self.ne_hifigan_decoder.eval()
|
|
||||||
if hasattr(self, "diffusion_decoder"):
|
if hasattr(self, "diffusion_decoder"):
|
||||||
self.diffusion_decoder.eval()
|
self.diffusion_decoder.eval()
|
||||||
if hasattr(self, "vocoder"):
|
if hasattr(self, "vocoder"):
|
||||||
|
|
Loading…
Reference in New Issue