From a16360af859732eedcb6c0faaa1a57081c33c9be Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Mon, 13 Nov 2023 13:00:08 +0100
Subject: [PATCH] Implement chunking gpt_cond

---
 TTS/tts/configs/xtts_config.py |  10 +++-
 TTS/tts/models/xtts.py         | 101 +++++++++++++++++++++++----------
 2 files changed, 78 insertions(+), 33 deletions(-)

diff --git a/TTS/tts/configs/xtts_config.py b/TTS/tts/configs/xtts_config.py
index 2d3edaf4..e8ab07da 100644
--- a/TTS/tts/configs/xtts_config.py
+++ b/TTS/tts/configs/xtts_config.py
@@ -43,7 +43,12 @@ class XttsConfig(BaseTTSConfig):
             Defaults to `16`.
 
         gpt_cond_len (int):
-            Secs audio to be used as conditioning for the autoregressive model. Defaults to `3`.
+            Secs audio to be used as conditioning for the autoregressive model. Defaults to `12`.
+
+        gpt_cond_chunk_len (int):
+            Audio chunk size in secs. The audio is split into chunks, latents are extracted for each chunk, and the
+            latents are averaged. Chunking improves stability. It must be <= `gpt_cond_len`.
+            If `gpt_cond_len == gpt_cond_chunk_len`, no chunking is applied. Defaults to `4`.
 
         max_ref_len (int):
             Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`.
@@ -95,6 +100,7 @@ class XttsConfig(BaseTTSConfig):
     num_gpt_outputs: int = 1
 
     # cloning
-    gpt_cond_len: int = 3
+    gpt_cond_len: int = 12
+    gpt_cond_chunk_len: int = 4
     max_ref_len: int = 10
     sound_norm_refs: bool = False
 
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index f41bcfb9..0f79ad69 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -255,39 +255,57 @@ class Xtts(BaseTTS):
         return next(self.parameters()).device
 
     @torch.inference_mode()
-    def get_gpt_cond_latents(self, audio, sr, length: int = 3):
+    def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
         """Compute the conditioning latents for the GPT model from the given audio.
 
         Args:
             audio (tensor): audio tensor.
             sr (int): Sample rate of the audio.
-            length (int): Length of the audio in seconds. Defaults to 3.
+            length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30.
+            chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio
+                is used without chunking. It must be <= `length`. Defaults to 6.
""" if sr != 22050: audio = torchaudio.functional.resample(audio, sr, 22050) - audio = audio[:, : 22050 * length] + if length > 0: + audio = audio[:, : 22050 * length] if self.args.gpt_use_perceiver_resampler: - n_fft = 2048 - hop_length = 256 - win_length = 1024 + style_embs = [] + for i in range(0, audio.shape[1], 22050 * chunk_length): + audio_chunk = audio[:, i : i + 22050 * chunk_length] + mel_chunk = wav_to_mel_cloning( + audio_chunk, + mel_norms=self.mel_stats.cpu(), + n_fft=2048, + hop_length=256, + win_length=1024, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + ) + style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None) + style_embs.append(style_emb) + + # mean style embedding + cond_latent = torch.stack(style_embs).mean(dim=0) else: - n_fft = 4096 - hop_length = 1024 - win_length = 4096 - mel = wav_to_mel_cloning( - audio, - mel_norms=self.mel_stats.cpu(), - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - power=2, - normalized=False, - sample_rate=22050, - f_min=0, - f_max=8000, - n_mels=80, - ) - cond_latent = self.gpt.get_style_emb(mel.to(self.device)) + mel = wav_to_mel_cloning( + audio, + mel_norms=self.mel_stats.cpu(), + n_fft=4096, + hop_length=1024, + win_length=4096, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + ) + cond_latent = self.gpt.get_style_emb(mel.to(self.device)) return cond_latent.transpose(1, 2) @torch.inference_mode() @@ -323,12 +341,24 @@ class Xtts(BaseTTS): def get_conditioning_latents( self, audio_path, + max_ref_length=30, gpt_cond_len=6, - max_ref_length=10, + gpt_cond_chunk_len=6, librosa_trim_db=None, sound_norm_refs=False, - load_sr=24000, + load_sr=22050, ): + """Get the conditioning latents for the GPT model from the given audio. + + Args: + audio_path (str or List[str]): Path to reference audio file(s). + max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30. + gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6. + gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= gpt_conf_len. Defaults to 6. + librosa_trim_db (int, optional): Trim the audio using this value. If None, not trimming. Defaults to None. + sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False. + load_sr (int, optional): Sample rate to load the audio. Defaults to 24000. 
+ """ # deal with multiples references if not isinstance(audio_path, list): audio_paths = [audio_path] @@ -349,14 +379,17 @@ class Xtts(BaseTTS): if librosa_trim_db is not None: audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0] + # compute latents for the decoder speaker_embedding = self.get_speaker_embedding(audio, load_sr) speaker_embeddings.append(speaker_embedding) audios.append(audio) - # use a merge of all references for gpt cond latents + # merge all the audios and compute the latents for the gpt full_audio = torch.cat(audios, dim=-1) - gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len) # [1, 1024, T] + gpt_cond_latents = self.get_gpt_cond_latents( + full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len + ) # [1, 1024, T] if speaker_embeddings: speaker_embedding = torch.stack(speaker_embeddings) @@ -397,6 +430,7 @@ class Xtts(BaseTTS): "top_k": config.top_k, "top_p": config.top_p, "gpt_cond_len": config.gpt_cond_len, + "gpt_cond_chunk_len": config.gpt_cond_chunk_len, "max_ref_len": config.max_ref_len, "sound_norm_refs": config.sound_norm_refs, } @@ -417,7 +451,8 @@ class Xtts(BaseTTS): top_p=0.85, do_sample=True, # Cloning - gpt_cond_len=6, + gpt_cond_len=30, + gpt_cond_chunk_len=6, max_ref_len=10, sound_norm_refs=False, **hf_generate_kwargs, @@ -448,7 +483,10 @@ class Xtts(BaseTTS): (aka boring) outputs. Defaults to 0.8. gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used - else the first `gpt_cond_len` secs is used. Defaults to 6 seconds. + else the first `gpt_cond_len` secs is used. Defaults to 30 seconds. + + gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`. + If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to 6 seconds. hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation @@ -461,6 +499,7 @@ class Xtts(BaseTTS): (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents( audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len, + gpt_cond_chunk_len=gpt_cond_chunk_len, max_ref_length=max_ref_len, sound_norm_refs=sound_norm_refs, ) @@ -566,7 +605,7 @@ class Xtts(BaseTTS): if overlap_len > len(wav_chunk): # wav_chunk is smaller than overlap_len, pass on last wav_gen if wav_gen_prev is not None: - wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):] + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :] else: # not expecting will hit here as problem happens on last chunk wav_chunk = wav_gen[-overlap_len:] @@ -576,7 +615,7 @@ class Xtts(BaseTTS): crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) wav_chunk[:overlap_len] += crossfade_wav - + wav_overlap = wav_gen[-overlap_len:] wav_gen_prev = wav_gen return wav_chunk, wav_gen_prev, wav_overlap