Add support for ne_hifigan

This commit is contained in:
WeberJulian 2023-10-19 17:19:13 -03:00
parent 747f688dc3
commit f3a773991e
1 changed files with 39 additions and 5 deletions

View File

@ -239,6 +239,7 @@ class XttsArgs(Coqpit):
decoder_checkpoint: str = None
num_chars: int = 255
use_hifigan: bool = True
use_ne_hifigan: bool = False
# XTTS GPT Encoder params
tokenizer_file: str = ""
@ -311,7 +312,7 @@ class Xtts(BaseTTS):
def init_models(self):
"""Initialize the models. We do it here since we need to load the tokenizer first."""
if self.tokenizer.tokenizer is not None:
self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
self.args.gpt_number_text_tokens = max(self.tokenizer.tokenizer.get_vocab().values()) + 1
self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
@ -343,7 +344,18 @@ class Xtts(BaseTTS):
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
)
else:
if self.args.use_ne_hifigan:
self.ne_hifigan_decoder = HifiDecoder(
input_sample_rate=self.args.input_sample_rate,
output_sample_rate=self.args.output_sample_rate,
output_hop_length=self.args.output_hop_length,
ar_mel_length_compression=self.args.ar_mel_length_compression,
decoder_input_dim=self.args.decoder_input_dim,
d_vector_dim=self.args.d_vector_dim,
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
)
if not (self.args.use_hifigan or self.args.use_ne_hifigan):
self.diffusion_decoder = DiffusionTts(
model_channels=self.args.diff_model_channels,
num_layers=self.args.diff_num_layers,
@ -491,6 +503,7 @@ class Xtts(BaseTTS):
cond_free_k=2,
diffusion_temperature=1.0,
decoder_sampler="ddim",
decoder="hifigan",
**hf_generate_kwargs,
):
"""
@ -539,6 +552,9 @@ class Xtts(BaseTTS):
Values at 0 are the "mean" prediction of the diffusion network and will sound bland and smeared.
Defaults to 1.0.
decoder: (str) Selects which decoder to use: "hifigan", "ne_hifigan" or "diffusion".
Defaults to "hifigan".
hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
here: https://huggingface.co/docs/transformers/internal/generation_utils
@ -569,6 +585,7 @@ class Xtts(BaseTTS):
cond_free_k=cond_free_k,
diffusion_temperature=diffusion_temperature,
decoder_sampler=decoder_sampler,
decoder=decoder,
**hf_generate_kwargs,
)
@ -593,6 +610,7 @@ class Xtts(BaseTTS):
cond_free_k=2,
diffusion_temperature=1.0,
decoder_sampler="ddim",
decoder="hifigan",
**hf_generate_kwargs,
):
text = f"[{language}]{text.strip().lower()}"
@ -649,9 +667,14 @@ class Xtts(BaseTTS):
gpt_latents = gpt_latents[:, :k]
break
if self.args.use_hifigan:
if decoder == "hifigan":
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
elif decoder == "ne_hifigan":
assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
else:
assert hasattr(self, "diffusion_decoder"), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
mel = do_spectrogram_diffusion(
self.diffusion_decoder,
diffuser,
@ -695,6 +718,7 @@ class Xtts(BaseTTS):
top_p=0.85,
do_sample=True,
# Decoder inference
decoder="hifigan",
**hf_generate_kwargs,
):
assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
@ -736,7 +760,14 @@ class Xtts(BaseTTS):
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
if decoder == "hifigan":
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
elif decoder == "ne_hifigan":
assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
else:
raise NotImplementedError("Diffusion for streaming inference not implemented.")
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
)
@ -794,7 +825,9 @@ class Xtts(BaseTTS):
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
for key in list(checkpoint.keys()):
if key.split(".")[0] in ignore_keys:
del checkpoint[key]
@ -802,6 +835,7 @@ class Xtts(BaseTTS):
if eval:
if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
if hasattr(self, "ne_hifigan_decoder"): self.hifigan_decoder.eval()
if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
if hasattr(self, "vocoder"): self.vocoder.eval()
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)