diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 8e9d6bd3..7585e4a4 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -691,6 +691,180 @@ class Xtts(BaseTTS):
                     last_tokens = []
                     yield wav_chunk
 
+    def init_for_audio_edit(self, dvae_checkpoint, mel_norm_file):
+        from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
+        from TTS.tts.layers.xtts.dvae import DiscreteVAE
+
+        dvae_sample_rate: int = 22050
+
+        # Load the DVAE that quantizes mel spectrograms into discrete codes
+        self.dvae = DiscreteVAE(
+            channels=80,
+            normalization=None,
+            positional_dims=1,
+            num_tokens=self.args.gpt_num_audio_tokens - 2,
+            codebook_dim=512,
+            hidden_dim=512,
+            num_resnet_blocks=3,
+            kernel_size=3,
+            num_layers=2,
+            use_transposed_convs=False,
+        )
+
+        self.dvae.eval()
+        dvae_dict = torch.load(dvae_checkpoint, map_location=torch.device("cpu"))
+        self.dvae.load_state_dict(dvae_dict, strict=False)
+        print(">> DVAE weights restored from:", dvae_checkpoint)
+
+        # Mel spectrogram extractor for the DVAE
+        self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(
+            mel_norm_file=mel_norm_file, sampling_rate=dvae_sample_rate
+        )
+
+        self.dvae = self.dvae.to(self.device)
+        self.torch_mel_spectrogram_dvae = self.torch_mel_spectrogram_dvae.to(self.device)
+
+    @torch.inference_mode()
+    def audio_edit(
+        self,
+        text_left,
+        text_edit,
+        text_right,
+        audio_left_path,
+        audio_right_path,
+        language,
+        gpt_cond_latent,
+        speaker_embedding,
+        generate_times=5,
+        # GPT inference
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
+        top_k=50,
+        top_p=0.85,
+        do_sample=True,
+        num_beams=1,
+        speed=1.0,
+        enable_text_splitting=False,
+        **hf_generate_kwargs,
+    ):
+        """Regenerate the `text_edit` span between two existing audio segments.
+
+        Samples GPT codes for the edit `generate_times` times, keeps the sample
+        whose latents best predict the known right-context codes, and decodes the
+        left-context latents followed by the best edit latents with HiFi-GAN.
+        """
+        load_sr: int = 22050
+        audio_left = load_audio(audio_left_path, load_sr)
+        audio_left = audio_left.to(self.device)
+        audio_right = load_audio(audio_right_path, load_sr)
+        audio_right = audio_right.to(self.device)
+
+        # Quantize the surrounding audio into DVAE codebook indices
+        dvae_mel_spec_left = self.torch_mel_spectrogram_dvae(audio_left)
+        codes_left = self.dvae.get_codebook_indices(dvae_mel_spec_left)
+        dvae_mel_spec_right = self.torch_mel_spectrogram_dvae(audio_right)
+        codes_right = self.dvae.get_codebook_indices(dvae_mel_spec_right)
+
+        language = language.split("-")[0]  # remove the country code
+        length_scale = 1.0 / max(speed, 0.05)
+        gpt_cond_latent = gpt_cond_latent.to(self.device)
+        speaker_embedding = speaker_embedding.to(self.device)
+
+        wavs = []
+        with torch.no_grad():
+            sent = (text_left + text_edit).strip().lower()
+            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
+
+            sent_left = text_left.strip().lower()
+            text_left_tokens = (
+                torch.IntTensor(self.tokenizer.encode(sent_left, lang=language)).unsqueeze(0).to(self.device)
+            )
+
+            sent_right = (text_edit + text_right).strip().lower()
+            text_right_tokens = (
+                torch.IntTensor(self.tokenizer.encode(sent_right, lang=language)).unsqueeze(0).to(self.device)
+            )
+
+            assert (
+                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
+            ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
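+
+            # Build the GPT prompt for the edit: token embeddings of the
+            # (left + edit) text followed by the left-audio DVAE codes, so
+            # sampling continues the existing left context and only emits
+            # new codes for text_edit.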
+            gpt_inputs = self.gpt.compute_embeddings(gpt_cond_latent, text_tokens)
+            # add codes_left to inputs
+            gpt_inputs = torch.cat([gpt_inputs, codes_left], dim=1)
+
+            best_gpt_latents = None
+            best_logits = 0
+            for _ in range(generate_times):
+                gen = self.gpt.gpt_inference.generate(
+                    gpt_inputs,
+                    bos_token_id=self.gpt.start_audio_token,
+                    pad_token_id=self.gpt.stop_audio_token,
+                    eos_token_id=self.gpt.stop_audio_token,
+                    max_length=self.gpt.max_gen_mel_tokens + gpt_inputs.shape[-1],
+                    do_sample=do_sample,
+                    top_p=top_p,
+                    top_k=top_k,
+                    temperature=temperature,
+                    num_return_sequences=self.gpt_batch_size,
+                    num_beams=num_beams,
+                    length_penalty=length_penalty,
+                    repetition_penalty=repetition_penalty,
+                    output_attentions=False,
+                    **hf_generate_kwargs,
+                )
+                # strip the prompt and the trailing stop token, keeping only
+                # the newly generated codes for the edited span
+                gpt_codes = gen[:, gpt_inputs.shape[1] : -1]
+                gen_len = gpt_codes.shape[-1]
+
+                # Re-run the GPT on (generated codes + known right codes) in
+                # latent mode to score how well this sample leads into the
+                # right context
+                gpt_right_codes = torch.cat([gpt_codes, codes_right], dim=1)
+                expected_output_len = torch.tensor(
+                    [gpt_right_codes.shape[-1] * self.gpt.code_stride_len], device=text_right_tokens.device
+                )
+                text_len = torch.tensor([text_right_tokens.shape[-1]], device=self.device)
+
+                gpt_latents = self.gpt(
+                    text_right_tokens,
+                    text_len,
+                    gpt_right_codes,
+                    expected_output_len,
+                    cond_latents=gpt_cond_latent,
+                    return_attentions=False,
+                    return_latent=True,
+                )
+
+                # softmax over the latent dimension, gathered at the known
+                # right-context code indices and summed into a single score
+                right_logits = torch.gather(
+                    F.softmax(gpt_latents[:, gen_len:], dim=-1),
+                    -1,
+                    codes_right.unsqueeze(0).transpose(-1, -2),
+                )
+                sum_logits = torch.sum(right_logits, dim=-2)
+                sum_logits = float(sum_logits[0, 0])
+
+                if sum_logits > best_logits:
+                    best_logits = sum_logits
+                    best_gpt_latents = gpt_latents.detach()
+
+            # Latents for the left context, recomputed from its known codes
+            expected_left_len = torch.tensor(
+                [codes_left.shape[-1] * self.gpt.code_stride_len], device=text_left_tokens.device
+            )
+            text_left_len = torch.tensor([text_left_tokens.shape[-1]], device=self.device)
+            gpt_left_latents = self.gpt(
+                text_left_tokens,
+                text_left_len,
+                codes_left,
+                expected_left_len,
+                cond_latents=gpt_cond_latent,
+                return_attentions=False,
+                return_latent=True,
+            )
+
+            # if length_scale != 1.0:
+            #     best_gpt_latents = F.interpolate(
+            #         best_gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
+            #     ).transpose(1, 2)
+
+            # wavs.append(load_audio(audio_left_path, self.args.output_sample_rate).cpu().squeeze())
+            # wavs.append(self.hifigan_decoder(best_gpt_latents, g=speaker_embedding).cpu().squeeze())
+            # wavs.append(load_audio(audio_right_path, self.args.output_sample_rate).cpu().squeeze())
+            wavs.append(
+                self.hifigan_decoder(
+                    torch.cat([gpt_left_latents, best_gpt_latents], dim=1), g=speaker_embedding
+                )
+                .cpu()
+                .squeeze()
+            )
+
+        return {
+            "wav": torch.cat(wavs, dim=0).numpy(),
+        }
+
     def forward(self):
         raise NotImplementedError(
             "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
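
Usage sketch (not part of the patch): a minimal, hypothetical driver for the new
methods, assuming a standard XTTS-v2 checkpoint directory whose dvae.pth and
mel_stats.pth sit next to model.pth; all file paths and texts below are
illustrative only.

    import torch
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

    config = XttsConfig()
    config.load_json("XTTS-v2/config.json")
    model = Xtts.init_from_config(config)
    model.load_checkpoint(config, checkpoint_dir="XTTS-v2/", eval=True)
    if torch.cuda.is_available():
        model.cuda()

    # DVAE weights and mel normalization stats shipped with the checkpoint
    model.init_for_audio_edit("XTTS-v2/dvae.pth", "XTTS-v2/mel_stats.pth")

    # Speaker conditioning from a reference clip (existing XTTS API)
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=["reference.wav"]
    )

    # Replace the middle of an utterance: pass audio for the kept left/right
    # context plus the text of all three segments; the middle is resynthesized.
    out = model.audio_edit(
        text_left="The quick brown fox ",
        text_edit="leaps over",
        text_right=" the lazy dog.",
        audio_left_path="utt_left.wav",
        audio_right_path="utt_right.wav",
        language="en",
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        generate_times=5,
    )
    # out["wav"] holds the decoded left context followed by the edited segment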