diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 1b80035a..aacdd6f1 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import torch import torch.nn.functional as F import torchaudio +import librosa from coqpit import Coqpit from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel @@ -34,12 +35,8 @@ def load_audio(audiopath, sr=22050): """ audio, sampling_rate = torchaudio.load(audiopath) - if len(audio.shape) > 1: - if audio.shape[0] < 5: - audio = audio[0] - else: - assert audio.shape[1] < 5 - audio = audio[:, 0] + if audio.shape[0] > 1: + audio = audio.mean(0, keepdim=True) if sampling_rate != sr: resampler = torchaudio.transforms.Resample(sampling_rate, sr) @@ -376,7 +373,7 @@ class Xtts(BaseTTS): return next(self.parameters()).device @torch.inference_mode() - def get_gpt_cond_latents(self, audio_path: str, length: int = 3): + def get_gpt_cond_latents(self, audio, sr, length: int = 3): """Compute the conditioning latents for the GPT model from the given audio. Args: @@ -384,24 +381,21 @@ class Xtts(BaseTTS): length (int): Length of the audio in seconds. Defaults to 3. """ - audio = load_audio(audio_path) - audio = audio[:, : 22050 * length] - mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu()) + audio_22k = torchaudio.functional.resample(audio, sr, 22050) + audio_22k = audio_22k[:, : 22050 * length] + mel = wav_to_mel_cloning(audio_22k, mel_norms=self.mel_stats.cpu()) cond_latent = self.gpt.get_style_emb(mel.to(self.device)) return cond_latent.transpose(1, 2) @torch.inference_mode() - def get_diffusion_cond_latents( - self, - audio_path, - ): + def get_diffusion_cond_latents(self, audio, sr): from math import ceil diffusion_conds = [] CHUNK_SIZE = 102400 - audio = load_audio(audio_path, 24000) - for chunk in range(ceil(audio.shape[1] / CHUNK_SIZE)): - current_sample = audio[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE] + audio_24k = torchaudio.functional.resample(audio, sr, 24000) + for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)): + current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE] current_sample = pad_or_truncate(current_sample, CHUNK_SIZE) cond_mel = wav_to_univnet_mel( current_sample.to(self.device), @@ -414,27 +408,38 @@ class Xtts(BaseTTS): return diffusion_latent @torch.inference_mode() - def get_speaker_embedding(self, audio_path): - audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"]) - speaker_embedding = ( - self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True) - .unsqueeze(-1) - .to(self.device) - ) - return speaker_embedding - + def get_speaker_embedding(self, audio, sr): + audio_16k = torchaudio.functional.resample(audio, sr, 16000) + return self.hifigan_decoder.speaker_encoder.forward( + audio_16k.to(self.device), l2_norm=True + ).unsqueeze(-1).to(self.device) + + @torch.inference_mode() def get_conditioning_latents( self, audio_path, - gpt_cond_len=3, - ): + gpt_cond_len=6, + max_ref_length=10, + librosa_trim_db=None, + sound_norm_refs=False, + ): speaker_embedding = None diffusion_cond_latents = None + + audio, sr = torchaudio.load(audio_path) + audio = audio[:, : sr * max_ref_length].to(self.device) + if audio.shape[0] > 1: + audio = audio.mean(0, keepdim=True) + if sound_norm_refs: + audio = (audio / torch.abs(audio).max()) * 0.75 + if librosa_trim_db is not None: + audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0] + if self.args.use_hifigan or self.args.use_ne_hifigan: - speaker_embedding = self.get_speaker_embedding(audio_path) + speaker_embedding = self.get_speaker_embedding(audio, sr) else: - diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path) - gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T] + diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr) + gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len) # [1, 1024, T] return gpt_cond_latents, diffusion_cond_latents, speaker_embedding def synthesize(self, text, config, speaker_wav, language, **kwargs): @@ -494,7 +499,7 @@ class Xtts(BaseTTS): repetition_penalty=2.0, top_k=50, top_p=0.85, - gpt_cond_len=4, + gpt_cond_len=6, do_sample=True, # Decoder inference decoder_iterations=100, @@ -531,7 +536,7 @@ class Xtts(BaseTTS): (aka boring) outputs. Defaults to 0.8. gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used - else the first `gpt_cond_len` secs is used. Defaults to 3 seconds. + else the first `gpt_cond_len` secs is used. Defaults to 6 seconds. decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine the output, which should theoretically mean a higher quality output.