Add support for ne_hifigan

commit f3a773991e (parent 747f688dc3)
Author: WeberJulian
Date: 2023-10-19 17:19:13 -03:00

1 changed file with 39 additions and 5 deletions


@@ -239,6 +239,7 @@ class XttsArgs(Coqpit):
     decoder_checkpoint: str = None
     num_chars: int = 255
     use_hifigan: bool = True
+    use_ne_hifigan: bool = False

     # XTTS GPT Encoder params
     tokenizer_file: str = ""
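
The new flag defaults to False, so existing configs keep building the standard HiFi-GAN head. A minimal sketch of opting in, assuming the args are constructed directly (the import path is assumed, not shown in the diff):

```python
# Hedged sketch: opting into the new decoder via model args.
# The import path is assumed; the field names come from the hunk above.
from TTS.tts.models.xtts import XttsArgs

args = XttsArgs(
    use_hifigan=False,    # skip the default HiFi-GAN head
    use_ne_hifigan=True,  # build ne_hifigan_decoder instead
)
```
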
@@ -311,7 +312,7 @@ class Xtts(BaseTTS):
     def init_models(self):
         """Initialize the models. We do it here since we need to load the tokenizer first."""
         if self.tokenizer.tokenizer is not None:
-            self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
+            self.args.gpt_number_text_tokens = max(self.tokenizer.tokenizer.get_vocab().values()) + 1
             self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
             self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
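
The vocab-size change guards against tokenizers whose token ids are not contiguous: `get_vocab_size()` counts entries, while the GPT embedding table must be large enough to index the highest id. A standalone illustration with the HuggingFace `tokenizers` API (the file path is hypothetical):

```python
# Why max(id) + 1 rather than get_vocab_size(): if there are gaps in the
# id space, the entry count is smaller than the highest id + 1, and an
# embedding sized by get_vocab_size() could be indexed out of range.
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical path
vocab = tokenizer.get_vocab()

print(tokenizer.get_vocab_size())  # number of tokens in the vocab
print(max(vocab.values()) + 1)     # smallest embedding size covering every id
```
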
@@ -343,7 +344,18 @@ class Xtts(BaseTTS):
                 cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
             )

-        else:
+        if self.args.use_ne_hifigan:
+            self.ne_hifigan_decoder = HifiDecoder(
+                input_sample_rate=self.args.input_sample_rate,
+                output_sample_rate=self.args.output_sample_rate,
+                output_hop_length=self.args.output_hop_length,
+                ar_mel_length_compression=self.args.ar_mel_length_compression,
+                decoder_input_dim=self.args.decoder_input_dim,
+                d_vector_dim=self.args.d_vector_dim,
+                cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
+            )
+
+        if not (self.args.use_hifigan or self.args.use_ne_hifigan):
             self.diffusion_decoder = DiffusionTts(
                 model_channels=self.args.diff_model_channels,
                 num_layers=self.args.diff_num_layers,
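
Replacing the old `else:` with independent checks changes the semantics: both HiFi-GAN variants can now be built side by side, and the diffusion decoder is constructed only when neither is enabled. A standalone reproduction of the resulting control flow:

```python
# Standalone reproduction of the decoder-init control flow above.
use_hifigan, use_ne_hifigan = True, True  # hypothetical config

built = []
if use_hifigan:
    built.append("hifigan_decoder")
if use_ne_hifigan:
    built.append("ne_hifigan_decoder")
if not (use_hifigan or use_ne_hifigan):
    built.append("diffusion_decoder")

print(built)  # ['hifigan_decoder', 'ne_hifigan_decoder']
```
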
@@ -491,6 +503,7 @@ class Xtts(BaseTTS):
         cond_free_k=2,
         diffusion_temperature=1.0,
         decoder_sampler="ddim",
+        decoder="hifigan",
         **hf_generate_kwargs,
     ):
         """
@@ -539,6 +552,9 @@ class Xtts(BaseTTS):
                 Values at 0 are the "mean" prediction of the diffusion network and will sound bland and smeared.
                 Defaults to 1.0.

+            decoder: (str) Selects the decoder to use between ("hifigan", "ne_hifigan" and "diffusion").
+                Defaults to "hifigan".
+
             hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
                 transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
                 here: https://huggingface.co/docs/transformers/internal/generation_utils
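
Callers now pick the vocoder per call rather than per checkpoint. A hedged usage sketch (the method name and arguments are assumed from the signature and docstring above, not confirmed against the full file):

```python
# Hedged usage sketch for the new `decoder` argument.
# `model` is an initialized Xtts instance; other arguments are elided.
wav = model.full_inference(  # method name assumed, not shown in the hunk
    text="Hello world.",
    ref_audio_path="reference.wav",  # hypothetical reference clip
    language="en",
    decoder="ne_hifigan",  # "hifigan" (default) | "ne_hifigan" | "diffusion"
)
```
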
@@ -569,6 +585,7 @@ class Xtts(BaseTTS):
             cond_free_k=cond_free_k,
             diffusion_temperature=diffusion_temperature,
             decoder_sampler=decoder_sampler,
+            decoder=decoder,
             **hf_generate_kwargs,
         )
@@ -593,6 +610,7 @@ class Xtts(BaseTTS):
         cond_free_k=2,
         diffusion_temperature=1.0,
         decoder_sampler="ddim",
+        decoder="hifigan",
         **hf_generate_kwargs,
     ):
         text = f"[{language}]{text.strip().lower()}"
@@ -649,9 +667,14 @@ class Xtts(BaseTTS):
                     gpt_latents = gpt_latents[:, :k]
                     break

-        if self.args.use_hifigan:
+        if decoder == "hifigan":
+            assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
             wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
+        elif decoder == "ne_hifigan":
+            assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
+            wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
         else:
+            assert hasattr(self, "diffusion_decoder"), "You must disable hifigan decoders to use diffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
             mel = do_spectrogram_diffusion(
                 self.diffusion_decoder,
                 diffuser,
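
The asserts turn a mismatched configuration into an immediate, readable failure: requesting a decoder whose flag was off at init time raises before any synthesis happens, instead of surfacing as an `AttributeError` mid-forward. A sketch of the failure mode (hypothetical setup):

```python
# Hedged sketch: the model was built with use_ne_hifigan=False, so
# ne_hifigan_decoder was never created and the assert fires.
try:
    wav = model.inference(..., decoder="ne_hifigan")  # conditioning args elided
except AssertionError as err:
    print(err)  # You must enable ne_hifigan decoder ... `use_ne_hifigan: true`
```
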
@@ -695,6 +718,7 @@ class Xtts(BaseTTS):
         top_p=0.85,
         do_sample=True,
         # Decoder inference
+        decoder="hifigan",
         **hf_generate_kwargs,
     ):
         assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
@@ -736,7 +760,14 @@ class Xtts(BaseTTS):
                 if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                     gpt_latents = torch.cat(all_latents, dim=0)[None, :]
-                    wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                    if decoder == "hifigan":
+                        assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
+                        wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                    elif decoder == "ne_hifigan":
+                        assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
+                        wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                    else:
+                        raise NotImplementedError("Diffusion for streaming inference not implemented.")
                     wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
                         wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
                     )
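
Streaming reuses the same selector but has no diffusion fallback; passing `decoder="diffusion"` raises `NotImplementedError` because diffusion is too slow to stream. A hedged sketch of streaming with the new head (argument names assumed):

```python
# Hedged streaming sketch: only the two HiFi-GAN heads are valid here.
chunks = model.inference_stream(
    "Hello world.", "en", gpt_cond_latent, speaker_embedding,  # assumed args
    decoder="ne_hifigan",
)
for chunk in chunks:
    handle_audio(chunk)  # hypothetical sink, e.g. append to a playback buffer
```
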
@@ -794,7 +825,9 @@ class Xtts(BaseTTS):
         self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
         checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
-        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
+        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
+        ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
+        ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
         for key in list(checkpoint.keys()):
             if key.split(".")[0] in ignore_keys:
                 del checkpoint[key]
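
The rewritten `ignore_keys` logic strips checkpoint weights for every decoder that is disabled, so the state-dict load only sees tensors for modules that were actually constructed. A standalone worked example of those three lines:

```python
# Standalone reproduction of the ignore_keys logic above.
use_hifigan, use_ne_hifigan = True, False  # hypothetical config

ignore_keys = ["diffusion_decoder", "vocoder"] if use_hifigan or use_ne_hifigan else []
ignore_keys += [] if use_hifigan else ["hifigan_decoder"]
ignore_keys += [] if use_ne_hifigan else ["ne_hifigan_decoder"]

print(ignore_keys)  # ['diffusion_decoder', 'vocoder', 'ne_hifigan_decoder']
```
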
@@ -802,6 +835,7 @@ class Xtts(BaseTTS):
         if eval:
             if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
+            if hasattr(self, "ne_hifigan_decoder"): self.ne_hifigan_decoder.eval()
             if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
             if hasattr(self, "vocoder"): self.vocoder.eval()
         self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)