import os from contextlib import contextmanager from dataclasses import dataclass import torch import torch.nn.functional as F import torchaudio from coqpit import Coqpit from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.xtts.vocoder import UnivNetGenerator from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support from TTS.tts.models.base_tts import BaseTTS from TTS.utils.io import load_fsspec init_stream_support() def load_audio(audiopath, sr=22050): """ Load an audio file from disk and resample it to the specified sampling rate. Args: audiopath (str): Path to the audio file. sr (int): Target sampling rate. Returns: Tensor: Audio waveform tensor with shape (1, T), where T is the number of samples. """ audio, sampling_rate = torchaudio.load(audiopath) if len(audio.shape) > 1: if audio.shape[0] < 5: audio = audio[0] else: assert audio.shape[1] < 5 audio = audio[:, 0] if sampling_rate != sr: resampler = torchaudio.transforms.Resample(sampling_rate, sr) audio = resampler(audio) audio = audio.clamp_(-1, 1) return audio.unsqueeze(0) def wav_to_mel_cloning( wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu") ): """ Convert waveform to mel-spectrogram with hard-coded parameters for cloning. Args: wav (torch.Tensor): Input waveform tensor. mel_norms_file (str): Path to mel-spectrogram normalization file. mel_norms (torch.Tensor): Mel-spectrogram normalization tensor. device (torch.device): Device to use for computation. Returns: torch.Tensor: Mel-spectrogram tensor. """ mel_stft = torchaudio.transforms.MelSpectrogram( n_fft=4096, hop_length=1024, win_length=4096, power=2, normalized=False, sample_rate=22050, f_min=0, f_max=8000, n_mels=80, norm="slaney", ).to(device) wav = wav.to(device) mel = mel_stft(wav) mel = torch.log(torch.clamp(mel, min=1e-5)) if mel_norms is None: mel_norms = torch.load(mel_norms_file, map_location=device) mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1) return mel def pad_or_truncate(t, length): """ Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it. Args: t (torch.Tensor): The input tensor to be padded or truncated. length (int): The desired length of the tensor. Returns: torch.Tensor: The padded or truncated tensor. """ tp = t[..., :length] if t.shape[-1] == length: tp = t elif t.shape[-1] < length: tp = F.pad(t, (0, length - t.shape[-1])) return tp def load_discrete_vocoder_diffuser( trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1, sampler="ddim", ): """ Load a GaussianDiffusion instance configured for use as a decoder. Args: trained_diffusion_steps (int): The number of diffusion steps used during training. desired_diffusion_steps (int): The number of diffusion steps to use during inference. cond_free (bool): Whether to use a conditioning-free model. cond_free_k (int): The number of samples to use for conditioning-free models. sampler (str): The name of the sampler to use. Returns: A SpacedDiffusion instance configured with the given parameters. """ return SpacedDiffusion( use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type="epsilon", model_var_type="learned_range", loss_type="mse", betas=get_named_beta_schedule("linear", trained_diffusion_steps), conditioning_free=cond_free, conditioning_free_k=cond_free_k, sampler=sampler, ) def do_spectrogram_diffusion( diffusion_model, diffuser, latents, conditioning_latents, temperature=1, ): """ Generate a mel-spectrogram using a diffusion model and a diffuser. Args: diffusion_model (nn.Module): A diffusion model that converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. diffuser (Diffuser): A diffuser that generates a mel-spectrogram from noise. latents (torch.Tensor): A tensor of shape (batch_size, seq_len, code_size) containing the input spectrogram codes. conditioning_latents (torch.Tensor): A tensor of shape (batch_size, code_size) containing the conditioning codes. temperature (float, optional): The temperature of the noise used by the diffuser. Defaults to 1. Returns: torch.Tensor: A tensor of shape (batch_size, mel_channels, mel_seq_len) containing the generated mel-spectrogram. """ with torch.no_grad(): output_seq_len = ( latents.shape[1] * 4 * 24000 // 22050 ) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. output_shape = (latents.shape[0], 100, output_seq_len) precomputed_embeddings = diffusion_model.timestep_independent( latents, conditioning_latents, output_seq_len, False ) noise = torch.randn(output_shape, device=latents.device) * temperature mel = diffuser.sample_loop( diffusion_model, output_shape, noise=noise, model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings}, progress=False, ) return denormalize_tacotron_mel(mel)[:, :, :output_seq_len] @dataclass class XttsAudioConfig(Coqpit): """ Configuration class for audio-related parameters in the XTTS model. Args: sample_rate (int): The sample rate in which the GPT operates. diffusion_sample_rate (int): The sample rate of the diffusion audio waveform. output_sample_rate (int): The sample rate of the output audio waveform. """ sample_rate: int = 22050 diffusion_sample_rate: int = 24000 output_sample_rate: int = 24000 @dataclass class XttsArgs(Coqpit): """A dataclass to represent XTTS model arguments that define the model structure. Args: gpt_batch_size (int): The size of the auto-regressive batch. enable_redaction (bool, optional): Whether to enable redaction. Defaults to True. kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True. gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None. clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True. For GPT model: ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402. ar_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70. ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30. ar_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024. ar_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16. ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255. ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255. gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False. ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False. For DiffTTS model: diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024. diff_num_layers (int, optional): The number of layers for the DiffTTS model. Defaults to 10. diff_in_channels (int, optional): The input channels for the DiffTTS model. Defaults to 100. diff_out_channels (int, optional): The output channels for the DiffTTS model. Defaults to 200. diff_in_latent_channels (int, optional): The input latent channels for the DiffTTS model. Defaults to 1024. diff_in_tokens (int, optional): The input tokens for the DiffTTS model. Defaults to 8193. diff_dropout (int, optional): The dropout percentage for the DiffTTS model. Defaults to 0. diff_use_fp16 (bool, optional): Whether to use fp16 for the DiffTTS model. Defaults to False. diff_num_heads (int, optional): The number of heads for the DiffTTS model. Defaults to 16. diff_layer_drop (int, optional): The layer dropout percentage for the DiffTTS model. Defaults to 0. diff_unconditioned_percentage (int, optional): The percentage of unconditioned inputs for the DiffTTS model. Defaults to 0. """ gpt_batch_size: int = 1 enable_redaction: bool = False kv_cache: bool = True gpt_checkpoint: str = None clvp_checkpoint: str = None decoder_checkpoint: str = None num_chars: int = 255 use_hifigan: bool = True use_ne_hifigan: bool = False # XTTS GPT Encoder params tokenizer_file: str = "" gpt_max_audio_tokens: int = 605 gpt_max_text_tokens: int = 402 gpt_max_prompt_tokens: int = 70 gpt_layers: int = 30 gpt_n_model_channels: int = 1024 gpt_n_heads: int = 16 gpt_number_text_tokens: int = None gpt_start_text_token: int = None gpt_stop_text_token: int = None gpt_num_audio_tokens: int = 8194 gpt_start_audio_token: int = 8192 gpt_stop_audio_token: int = 8193 # Diffusion Decoder params diff_model_channels: int = 1024 diff_num_layers: int = 10 diff_in_channels: int = 100 diff_out_channels: int = 200 diff_in_latent_channels: int = 1024 diff_in_tokens: int = 8193 diff_dropout: int = 0 diff_use_fp16: bool = False diff_num_heads: int = 16 diff_layer_drop: int = 0 diff_unconditioned_percentage: int = 0 # HifiGAN Decoder params input_sample_rate: int = 22050 output_sample_rate: int = 24000 output_hop_length: int = 256 ar_mel_length_compression: int = 1024 decoder_input_dim: int = 1024 d_vector_dim: int = 512 cond_d_vector_in_each_upsampling_layer: bool = True # constants duration_const: int = 102400 class Xtts(BaseTTS): """ⓍTTS model implementation. ❗ Currently it only supports inference. Examples: >>> from TTS.tts.configs.xtts_config import XttsConfig >>> from TTS.tts.models.xtts import Xtts >>> config = XttsConfig() >>> model = Xtts.inif_from_config(config) >>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True) """ def __init__(self, config: Coqpit): super().__init__(config, ap=None, tokenizer=None) self.mel_stats_path = None self.config = config self.gpt_checkpoint = self.args.gpt_checkpoint self.decoder_checkpoint = self.args.decoder_checkpoint # TODO: check if this is even needed self.models_dir = config.model_dir self.gpt_batch_size = self.args.gpt_batch_size self.tokenizer = VoiceBpeTokenizer() self.gpt = None self.init_models() self.register_buffer("mel_stats", torch.ones(80)) def init_models(self): """Initialize the models. We do it here since we need to load the tokenizer first.""" if self.tokenizer.tokenizer is not None: self.args.gpt_number_text_tokens = self.tokenizer.get_number_tokens() self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]") self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]") if self.args.gpt_number_text_tokens: self.gpt = GPT( layers=self.args.gpt_layers, model_dim=self.args.gpt_n_model_channels, start_text_token=self.args.gpt_start_text_token, stop_text_token=self.args.gpt_stop_text_token, heads=self.args.gpt_n_heads, max_text_tokens=self.args.gpt_max_text_tokens, max_mel_tokens=self.args.gpt_max_audio_tokens, max_prompt_tokens=self.args.gpt_max_prompt_tokens, number_text_tokens=self.args.gpt_number_text_tokens, num_audio_tokens=self.args.gpt_num_audio_tokens, start_audio_token=self.args.gpt_start_audio_token, stop_audio_token=self.args.gpt_stop_audio_token, ) if self.args.use_hifigan: self.hifigan_decoder = HifiDecoder( input_sample_rate=self.args.input_sample_rate, output_sample_rate=self.args.output_sample_rate, output_hop_length=self.args.output_hop_length, ar_mel_length_compression=self.args.ar_mel_length_compression, decoder_input_dim=self.args.decoder_input_dim, d_vector_dim=self.args.d_vector_dim, cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, ) if self.args.use_ne_hifigan: self.ne_hifigan_decoder = HifiDecoder( input_sample_rate=self.args.input_sample_rate, output_sample_rate=self.args.output_sample_rate, output_hop_length=self.args.output_hop_length, ar_mel_length_compression=self.args.ar_mel_length_compression, decoder_input_dim=self.args.decoder_input_dim, d_vector_dim=self.args.d_vector_dim, cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, ) if not (self.args.use_hifigan or self.args.use_ne_hifigan): self.diffusion_decoder = DiffusionTts( model_channels=self.args.diff_model_channels, num_layers=self.args.diff_num_layers, in_channels=self.args.diff_in_channels, out_channels=self.args.diff_out_channels, in_latent_channels=self.args.diff_in_latent_channels, in_tokens=self.args.diff_in_tokens, dropout=self.args.diff_dropout, use_fp16=self.args.diff_use_fp16, num_heads=self.args.diff_num_heads, layer_drop=self.args.diff_layer_drop, unconditioned_percentage=self.args.diff_unconditioned_percentage, ) self.vocoder = UnivNetGenerator() @property def device(self): return next(self.parameters()).device @torch.inference_mode() def get_gpt_cond_latents(self, audio_path: str, length: int = 3): """Compute the conditioning latents for the GPT model from the given audio. Args: audio_path (str): Path to the audio file. length (int): Length of the audio in seconds. Defaults to 3. """ audio = load_audio(audio_path) audio = audio[:, : 22050 * length] mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu()) cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False) return cond_latent.transpose(1, 2) @torch.inference_mode() def get_diffusion_cond_latents( self, audio_path, ): from math import ceil diffusion_conds = [] CHUNK_SIZE = 102400 audio = load_audio(audio_path, 24000) for chunk in range(ceil(audio.shape[1] / CHUNK_SIZE)): current_sample = audio[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE] current_sample = pad_or_truncate(current_sample, CHUNK_SIZE) cond_mel = wav_to_univnet_mel( current_sample.to(self.device), do_normalization=False, device=self.device, ) diffusion_conds.append(cond_mel) diffusion_conds = torch.stack(diffusion_conds, dim=1) diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds) return diffusion_latent @torch.inference_mode() def get_speaker_embedding( self, audio_path ): audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"]) speaker_embedding = self.hifigan_decoder.speaker_encoder.forward( audio.to(self.device), l2_norm=True ).unsqueeze(-1).to(self.device) return speaker_embedding def get_conditioning_latents( self, audio_path, gpt_cond_len=3, ): speaker_embedding = None diffusion_cond_latents = None if self.args.use_hifigan: speaker_embedding = self.get_speaker_embedding(audio_path) else: diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path) gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T] return gpt_cond_latents, diffusion_cond_latents, speaker_embedding def synthesize(self, text, config, speaker_wav, language, **kwargs): """Synthesize speech with the given input text. Args: text (str): Input text. config (XttsConfig): Config with inference parameters. speaker_wav (str): Path to the speaker audio file for cloning. language (str): Language ID of the speaker. **kwargs: Inference settings. See `inference()`. Returns: A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference, `text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents` as latents used at inference. """ # Make the synthesizer happy 🥳 if isinstance(speaker_wav, list): speaker_wav = speaker_wav[0] return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs) def inference_with_config(self, text, config, ref_audio_path, language, **kwargs): """ inference with config """ assert ( language in self.config.languages ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}" # Use generally found best tuning knobs for generation. settings = { "temperature": config.temperature, "length_penalty": config.length_penalty, "repetition_penalty": config.repetition_penalty, "top_k": config.top_k, "top_p": config.top_p, "cond_free_k": config.cond_free_k, "diffusion_temperature": config.diffusion_temperature, "decoder_iterations": config.decoder_iterations, "decoder_sampler": config.decoder_sampler, } settings.update(kwargs) # allow overriding of preset settings with kwargs return self.full_inference(text, ref_audio_path, language, **settings) @torch.inference_mode() def full_inference( self, text, ref_audio_path, language, # GPT inference temperature=0.65, length_penalty=1, repetition_penalty=2.0, top_k=50, top_p=0.85, gpt_cond_len=4, do_sample=True, # Decoder inference decoder_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0, decoder_sampler="ddim", decoder="hifigan", **hf_generate_kwargs, ): """ This function produces an audio clip of the given text being spoken with the given reference voice. Args: text: (str) Text to be spoken. ref_audio_path: (str) Path to a reference audio file to be used for cloning. This audio file should be >3 seconds long. language: (str) Language of the voice to be generated. temperature: (float) The softmax temperature of the autoregressive model. Defaults to 0.65. length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. Defaults to 1.0. repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence of long silences or "uhhhhhhs", etc. Defaults to 2.0. top_k: (int) K value used in top-k sampling. [0,inf]. Lower values mean the decoder produces more "likely" (aka boring) outputs. Defaults to 50. top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs. Defaults to 0.8. gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used else the first `gpt_cond_len` secs is used. Defaults to 3 seconds. decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, however. Defaults to 100. cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output of the two is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and dramatically improves realism. Defaults to True. cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. As cond_free_k increases, the output becomes dominated by the conditioning-free signal. Defaults to 2.0. diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 re the "mean" prediction of the diffusion network and will sound bland and smeared. Defaults to 1.0. decoder: (str) Selects the decoder to use between ("hifigan", "ne_hifigan" and "diffusion") Defaults to hifigan hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation here: https://huggingface.co/docs/transformers/internal/generation_utils Returns: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. Sample rate is 24kHz. """ ( gpt_cond_latent, diffusion_conditioning, speaker_embedding ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len) return self.inference( text, language, gpt_cond_latent, speaker_embedding, diffusion_conditioning, temperature=temperature, length_penalty=length_penalty, repetition_penalty=repetition_penalty, top_k=top_k, top_p=top_p, do_sample=do_sample, decoder_iterations=decoder_iterations, cond_free=cond_free, cond_free_k=cond_free_k, diffusion_temperature=diffusion_temperature, decoder_sampler=decoder_sampler, decoder=decoder, **hf_generate_kwargs, ) @torch.inference_mode() def inference( self, text, language, gpt_cond_latent, speaker_embedding, diffusion_conditioning, # GPT inference temperature=0.65, length_penalty=1, repetition_penalty=2.0, top_k=50, top_p=0.85, do_sample=True, # Decoder inference decoder_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0, decoder_sampler="ddim", decoder="hifigan", **hf_generate_kwargs, ): text = f"[{language}]{text.strip().lower()}" text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) assert ( text_tokens.shape[-1] < self.args.gpt_max_text_tokens ), " ❗ XTTS can only generate text with a maximum of 400 tokens." if not self.args.use_hifigan: diffuser = load_discrete_vocoder_diffuser( desired_diffusion_steps=decoder_iterations, cond_free=cond_free, cond_free_k=cond_free_k, sampler=decoder_sampler, ) with torch.no_grad(): gpt_codes = self.gpt.generate( cond_latents=gpt_cond_latent, text_inputs=text_tokens, input_tokens=None, do_sample=do_sample, top_p=top_p, top_k=top_k, temperature=temperature, num_return_sequences=self.gpt_batch_size, length_penalty=length_penalty, repetition_penalty=repetition_penalty, output_attentions=False, **hf_generate_kwargs, ) expected_output_len = torch.tensor( [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device ) text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) gpt_latents = self.gpt( text_tokens, text_len, gpt_codes, expected_output_len, cond_latents=gpt_cond_latent, return_attentions=False, return_latent=True, ) silence_token = 83 ctokens = 0 for k in range(gpt_codes.shape[-1]): if gpt_codes[0, k] == silence_token: ctokens += 1 else: ctokens = 0 if ctokens > 8: gpt_latents = gpt_latents[:, :k] break if decoder == "hifigan": assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) elif decoder == "ne_hifigan": assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`" wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding) else: assert hasattr(self, "diffusion_decoder"), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`" mel = do_spectrogram_diffusion( self.diffusion_decoder, diffuser, gpt_latents, diffusion_conditioning, temperature=diffusion_temperature, ) wav = self.vocoder.inference(mel) return {"wav": wav.cpu().numpy().squeeze()} def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): """Handle chunk formatting in streaming mode""" wav_chunk = wav_gen[:-overlap_len] if wav_gen_prev is not None: wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len] if wav_overlap is not None: crossfade_wav = wav_chunk[:overlap_len] crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) wav_chunk[:overlap_len] += crossfade_wav wav_overlap = wav_gen[-overlap_len:] wav_gen_prev = wav_gen return wav_chunk, wav_gen_prev, wav_overlap @torch.inference_mode() def inference_stream( self, text, language, gpt_cond_latent, speaker_embedding, # Streaming stream_chunk_size=20, overlap_wav_len=1024, # GPT inference temperature=0.65, length_penalty=1, repetition_penalty=2.0, top_k=50, top_p=0.85, do_sample=True, # Decoder inference decoder="hifigan", **hf_generate_kwargs, ): assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream." text = f"[{language}]{text.strip().lower()}" text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) fake_inputs = self.gpt.compute_embeddings( gpt_cond_latent.to(self.device), text_tokens, ) gpt_generator = self.gpt.get_generator( fake_inputs=fake_inputs, top_k=top_k, top_p=top_p, temperature=temperature, do_sample=do_sample, num_beams=1, num_return_sequences=1, length_penalty=float(length_penalty), repetition_penalty=float(repetition_penalty), output_attentions=False, output_hidden_states=True, **hf_generate_kwargs, ) last_tokens = [] all_latents = [] wav_gen_prev = None wav_overlap = None is_end = False while not is_end: try: x, latent = next(gpt_generator) last_tokens += [x] all_latents += [latent] except StopIteration: is_end = True if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): gpt_latents = torch.cat(all_latents, dim=0)[None, :] if decoder == "hifigan": assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) elif decoder == "ne_hifigan": assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`" wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) else: raise NotImplementedError("Diffusion for streaming inference not implemented.") wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len ) last_tokens = [] yield wav_chunk def forward(self): raise NotImplementedError("XTTS Training is not implemented") def eval_step(self): raise NotImplementedError("XTTS Training is not implemented") @staticmethod def init_from_config(config: "XttsConfig", **kwargs): # pylint: disable=unused-argument return Xtts(config) def eval(self): # pylint: disable=redefined-builtin """Sets the model to evaluation mode. Overrides the default eval() method to also set the GPT model to eval mode.""" self.gpt.init_gpt_for_inference() super().eval() def load_checkpoint( self, config, checkpoint_dir=None, checkpoint_path=None, vocab_path=None, eval=True, strict=True, use_deepspeed=False, ): """ Loads a checkpoint from disk and initializes the model's state and tokenizer. Args: config (dict): The configuration dictionary for the model. checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None. checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None. vocab_path (str, optional): The path to the vocabulary file. Defaults to None. eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True. strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True. Returns: None """ model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json") if os.path.exists(vocab_path): self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path) self.init_models() if eval: self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"] ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else [] ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"] ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"] for key in list(checkpoint.keys()): if key.split(".")[0] in ignore_keys: del checkpoint[key] self.load_state_dict(checkpoint, strict=strict) if eval: if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval() if hasattr(self, "ne_hifigan_decoder"): self.hifigan_decoder.eval() if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval() if hasattr(self, "vocoder"): self.vocoder.eval() self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed) self.gpt.eval() def train_step(self): raise NotImplementedError("XTTS Training is not implemented")