diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 2b480744..784ba1be 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -239,6 +239,7 @@ class XttsArgs(Coqpit):
     decoder_checkpoint: str = None
     num_chars: int = 255
     use_hifigan: bool = True
+    use_ne_hifigan: bool = False
 
     # XTTS GPT Encoder params
     tokenizer_file: str = ""
@@ -311,7 +312,7 @@ class Xtts(BaseTTS):
     def init_models(self):
         """Initialize the models. We do it here since we need to load the tokenizer first."""
         if self.tokenizer.tokenizer is not None:
-            self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
+            self.args.gpt_number_text_tokens = max(self.tokenizer.tokenizer.get_vocab().values()) + 1
             self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
             self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
 
@@ -343,7 +344,18 @@ class Xtts(BaseTTS):
                 cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
             )
 
-        else:
+        if self.args.use_ne_hifigan:
+            self.ne_hifigan_decoder = HifiDecoder(
+                input_sample_rate=self.args.input_sample_rate,
+                output_sample_rate=self.args.output_sample_rate,
+                output_hop_length=self.args.output_hop_length,
+                ar_mel_length_compression=self.args.ar_mel_length_compression,
+                decoder_input_dim=self.args.decoder_input_dim,
+                d_vector_dim=self.args.d_vector_dim,
+                cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
+            )
+
+        if not (self.args.use_hifigan or self.args.use_ne_hifigan):
             self.diffusion_decoder = DiffusionTts(
                 model_channels=self.args.diff_model_channels,
                 num_layers=self.args.diff_num_layers,
@@ -491,6 +503,7 @@ class Xtts(BaseTTS):
         cond_free_k=2,
         diffusion_temperature=1.0,
         decoder_sampler="ddim",
+        decoder="hifigan",
         **hf_generate_kwargs,
     ):
         """
@@ -539,6 +552,9 @@ class Xtts(BaseTTS):
                 Values at 0 re the "mean" prediction of the diffusion network and will sound bland and smeared.
                 Defaults to 1.0.
 
+            decoder: (str) Selects the decoder to use: one of "hifigan", "ne_hifigan" or "diffusion".
+                Defaults to "hifigan".
+
             hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
                 transformer. Extra keyword args fed to this function get forwarded directly to that API.
                 Documentation here: https://huggingface.co/docs/transformers/internal/generation_utils
@@ -569,6 +585,7 @@ class Xtts(BaseTTS):
             cond_free_k=cond_free_k,
             diffusion_temperature=diffusion_temperature,
             decoder_sampler=decoder_sampler,
+            decoder=decoder,
             **hf_generate_kwargs,
         )
 
@@ -593,6 +610,7 @@ class Xtts(BaseTTS):
         cond_free_k=2,
         diffusion_temperature=1.0,
         decoder_sampler="ddim",
+        decoder="hifigan",
         **hf_generate_kwargs,
     ):
         text = f"[{language}]{text.strip().lower()}"
@@ -649,9 +667,14 @@ class Xtts(BaseTTS):
                         gpt_latents = gpt_latents[:, :k]
                     break
 
-            if self.args.use_hifigan:
+            if decoder == "hifigan":
+                assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
                 wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
+            elif decoder == "ne_hifigan":
+                assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
+                wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
             else:
+                assert hasattr(self, "diffusion_decoder"), "You must disable hifigan decoders to use diffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
                 mel = do_spectrogram_diffusion(
                     self.diffusion_decoder,
                     diffuser,
@@ -695,6 +718,7 @@ class Xtts(BaseTTS):
         top_p=0.85,
         do_sample=True,
         # Decoder inference
+        decoder="hifigan",
         **hf_generate_kwargs,
     ):
         assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
@@ -736,7 +760,14 @@ class Xtts(BaseTTS):
 
             if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                 gpt_latents = torch.cat(all_latents, dim=0)[None, :]
-                wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                if decoder == "hifigan":
+                    assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
+                    wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                elif decoder == "ne_hifigan":
+                    assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
+                    wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                else:
+                    raise NotImplementedError("Diffusion for streaming inference not implemented.")
                 wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
                     wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
                 )
@@ -794,7 +825,9 @@ class Xtts(BaseTTS):
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
 
         checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
-        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
+        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
+        ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
+        ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
         for key in list(checkpoint.keys()):
            if key.split(".")[0] in ignore_keys:
                del checkpoint[key]
@@ -802,6 +835,7 @@ class Xtts(BaseTTS):
 
         if eval:
             if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
+            if hasattr(self, "ne_hifigan_decoder"): self.ne_hifigan_decoder.eval()
             if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
             if hasattr(self, "vocoder"): self.vocoder.eval()
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
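
Usage sketch (not part of the patch): the snippet below illustrates how the new `decoder` argument is expected to be used once this change lands. It assumes the stock XTTS loading path (`XttsConfig`, `Xtts.init_from_config`, `load_checkpoint`) and the `full_inference` wrapper that the hunks above thread `decoder=` through; the paths and the reference clip are placeholders.

```python
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Placeholder paths -- point these at your own XTTS config and checkpoint directory.
config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
# New flag from this patch: build the ne_hifigan decoder and keep its weights on load.
config.model_args.use_ne_hifigan = True

model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)

# `decoder` selects the vocoder branch added above:
# "hifigan" (default), "ne_hifigan", or "diffusion" (only when both hifigan flags are False).
outputs = model.full_inference(
    text="Hello, this is a test of the new decoder switch.",
    ref_audio_path="reference.wav",  # placeholder speaker reference clip
    language="en",
    decoder="ne_hifigan",
)
```

Streaming (`inference_stream`) accepts the same argument but only supports the two HiFi-GAN decoders; the diffusion branch raises `NotImplementedError` there.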