mirror of https://github.com/coqui-ai/TTS.git
Add support for ne_hifigan
This commit is contained in:
parent
747f688dc3
commit
f3a773991e
|
@ -239,6 +239,7 @@ class XttsArgs(Coqpit):
|
|||
decoder_checkpoint: str = None
|
||||
num_chars: int = 255
|
||||
use_hifigan: bool = True
|
||||
use_ne_hifigan: bool = False
|
||||
|
||||
# XTTS GPT Encoder params
|
||||
tokenizer_file: str = ""
|
||||
|
@ -311,7 +312,7 @@ class Xtts(BaseTTS):
|
|||
def init_models(self):
|
||||
"""Initialize the models. We do it here since we need to load the tokenizer first."""
|
||||
if self.tokenizer.tokenizer is not None:
|
||||
self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
|
||||
self.args.gpt_number_text_tokens = max(self.tokenizer.tokenizer.get_vocab().values()) + 1
|
||||
self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
|
||||
self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
|
||||
|
||||
|
@ -343,7 +344,18 @@ class Xtts(BaseTTS):
|
|||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
||||
)
|
||||
|
||||
else:
|
||||
if self.args.use_ne_hifigan:
|
||||
self.ne_hifigan_decoder = HifiDecoder(
|
||||
input_sample_rate=self.args.input_sample_rate,
|
||||
output_sample_rate=self.args.output_sample_rate,
|
||||
output_hop_length=self.args.output_hop_length,
|
||||
ar_mel_length_compression=self.args.ar_mel_length_compression,
|
||||
decoder_input_dim=self.args.decoder_input_dim,
|
||||
d_vector_dim=self.args.d_vector_dim,
|
||||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
||||
)
|
||||
|
||||
if not (self.args.use_hifigan or self.args.use_ne_hifigan):
|
||||
self.diffusion_decoder = DiffusionTts(
|
||||
model_channels=self.args.diff_model_channels,
|
||||
num_layers=self.args.diff_num_layers,
|
||||
|
@ -491,6 +503,7 @@ class Xtts(BaseTTS):
|
|||
cond_free_k=2,
|
||||
diffusion_temperature=1.0,
|
||||
decoder_sampler="ddim",
|
||||
decoder="hifigan",
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -539,6 +552,9 @@ class Xtts(BaseTTS):
|
|||
Values at 0 are the "mean" prediction of the diffusion network and will sound bland and smeared.
|
||||
Defaults to 1.0.
|
||||
|
||||
decoder: (str) Selects the decoder to use between ("hifigan", "ne_hifigan" and "diffusion")
|
||||
Defaults to hifigan
|
||||
|
||||
hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
|
||||
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
|
||||
here: https://huggingface.co/docs/transformers/internal/generation_utils
|
||||
|
@ -569,6 +585,7 @@ class Xtts(BaseTTS):
|
|||
cond_free_k=cond_free_k,
|
||||
diffusion_temperature=diffusion_temperature,
|
||||
decoder_sampler=decoder_sampler,
|
||||
decoder=decoder,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
|
@ -593,6 +610,7 @@ class Xtts(BaseTTS):
|
|||
cond_free_k=2,
|
||||
diffusion_temperature=1.0,
|
||||
decoder_sampler="ddim",
|
||||
decoder="hifigan",
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
text = f"[{language}]{text.strip().lower()}"
|
||||
|
@ -649,9 +667,14 @@ class Xtts(BaseTTS):
|
|||
gpt_latents = gpt_latents[:, :k]
|
||||
break
|
||||
|
||||
if self.args.use_hifigan:
|
||||
if decoder == "hifigan":
|
||||
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
||||
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
|
||||
elif decoder == "ne_hifigan":
|
||||
assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
||||
wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
|
||||
else:
|
||||
assert hasattr(self, "diffusion_decoder"), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
|
||||
mel = do_spectrogram_diffusion(
|
||||
self.diffusion_decoder,
|
||||
diffuser,
|
||||
|
@ -695,6 +718,7 @@ class Xtts(BaseTTS):
|
|||
top_p=0.85,
|
||||
do_sample=True,
|
||||
# Decoder inference
|
||||
decoder="hifigan",
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
|
||||
|
@ -736,7 +760,14 @@ class Xtts(BaseTTS):
|
|||
|
||||
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
|
||||
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
|
||||
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||
if decoder == "hifigan":
|
||||
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
||||
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||
elif decoder == "ne_hifigan":
|
||||
assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
|
||||
wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||
else:
|
||||
raise NotImplementedError("Diffusion for streaming inference not implemented.")
|
||||
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
||||
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
|
||||
)
|
||||
|
@ -794,7 +825,9 @@ class Xtts(BaseTTS):
|
|||
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
|
||||
|
||||
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
|
||||
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
|
||||
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
|
||||
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
|
||||
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
|
||||
for key in list(checkpoint.keys()):
|
||||
if key.split(".")[0] in ignore_keys:
|
||||
del checkpoint[key]
|
||||
|
@ -802,6 +835,7 @@ class Xtts(BaseTTS):
|
|||
|
||||
if eval:
|
||||
if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
|
||||
# Switch the ne_hifigan decoder itself to eval mode; the guard and the target must name the same attribute, otherwise this fails when only `use_ne_hifigan` is enabled.
if hasattr(self, "ne_hifigan_decoder"): self.ne_hifigan_decoder.eval()
|
||||
if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
|
||||
if hasattr(self, "vocoder"): self.vocoder.eval()
|
||||
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
|
||||
|
|
Loading…
Reference in New Issue