From 6fa46d197d5b7ce9620ac2690d4566d55ec580af Mon Sep 17 00:00:00 2001
From: WeberJulian
Date: Tue, 24 Oct 2023 10:57:15 +0200
Subject: [PATCH 1/4] Fix get_conditioning_latents when using only ne_hifigan

---
 TTS/tts/models/xtts.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 87ba3285..33f612f7 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -430,7 +430,7 @@ class Xtts(BaseTTS):
     ):
         speaker_embedding = None
         diffusion_cond_latents = None
-        if self.args.use_hifigan:
+        if self.args.use_hifigan or self.args.use_ne_hifigan:
             speaker_embedding = self.get_speaker_embedding(audio_path)
         else:
             diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
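Both HiFi-GAN decoder variants consume a speaker embedding rather than diffusion conditioning latents, so the dispatch above must fire when either flag is set; previously a config with only `use_ne_hifigan` enabled fell through to the diffusion branch and left `speaker_embedding` as None. The stand-alone sketch below replays that dispatch outside the class: `Flags`, `embed_speaker`, `diffusion_latents`, and `pick_conditioning` are hypothetical stand-ins for `self.args`, `get_speaker_embedding`, `get_diffusion_cond_latents`, and the method body, used purely for illustration.

from dataclasses import dataclass

@dataclass
class Flags:  # stand-in for the relevant XttsArgs flags
    use_hifigan: bool = False
    use_ne_hifigan: bool = False

embed_speaker = lambda path: f"speaker-embedding({path})"      # stand-in
diffusion_latents = lambda path: f"diffusion-latents({path})"  # stand-in

def pick_conditioning(args, audio_path):
    speaker_embedding = None
    diffusion_cond_latents = None
    if args.use_hifigan or args.use_ne_hifigan:  # the patched condition
        speaker_embedding = embed_speaker(audio_path)
    else:
        diffusion_cond_latents = diffusion_latents(audio_path)
    return speaker_embedding, diffusion_cond_latents

# With only ne_hifigan enabled, the speaker-embedding branch is now taken:
print(pick_conditioning(Flags(use_ne_hifigan=True), "ref.wav"))
# -> ('speaker-embedding(ref.wav)', None)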
- text = f"[{language}]{text.strip().lower()}" + text = text.strip().lower() text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) fake_inputs = self.gpt.compute_embeddings( From d4e08c8d6c84859541af0231b6b5bde4c52da375 Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Wed, 25 Oct 2023 09:34:08 +0200 Subject: [PATCH 3/4] Add features to get_conditioning_latents --- TTS/tts/models/xtts.py | 71 ++++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 33 deletions(-) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 1b80035a..aacdd6f1 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import torch import torch.nn.functional as F import torchaudio +import librosa from coqpit import Coqpit from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel @@ -34,12 +35,8 @@ def load_audio(audiopath, sr=22050): """ audio, sampling_rate = torchaudio.load(audiopath) - if len(audio.shape) > 1: - if audio.shape[0] < 5: - audio = audio[0] - else: - assert audio.shape[1] < 5 - audio = audio[:, 0] + if audio.shape[0] > 1: + audio = audio.mean(0, keepdim=True) if sampling_rate != sr: resampler = torchaudio.transforms.Resample(sampling_rate, sr) @@ -376,7 +373,7 @@ class Xtts(BaseTTS): return next(self.parameters()).device @torch.inference_mode() - def get_gpt_cond_latents(self, audio_path: str, length: int = 3): + def get_gpt_cond_latents(self, audio, sr, length: int = 3): """Compute the conditioning latents for the GPT model from the given audio. Args: @@ -384,24 +381,21 @@ class Xtts(BaseTTS): length (int): Length of the audio in seconds. Defaults to 3. """ - audio = load_audio(audio_path) - audio = audio[:, : 22050 * length] - mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu()) + audio_22k = torchaudio.functional.resample(audio, sr, 22050) + audio_22k = audio_22k[:, : 22050 * length] + mel = wav_to_mel_cloning(audio_22k, mel_norms=self.mel_stats.cpu()) cond_latent = self.gpt.get_style_emb(mel.to(self.device)) return cond_latent.transpose(1, 2) @torch.inference_mode() - def get_diffusion_cond_latents( - self, - audio_path, - ): + def get_diffusion_cond_latents(self, audio, sr): from math import ceil diffusion_conds = [] CHUNK_SIZE = 102400 - audio = load_audio(audio_path, 24000) - for chunk in range(ceil(audio.shape[1] / CHUNK_SIZE)): - current_sample = audio[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE] + audio_24k = torchaudio.functional.resample(audio, sr, 24000) + for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)): + current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE] current_sample = pad_or_truncate(current_sample, CHUNK_SIZE) cond_mel = wav_to_univnet_mel( current_sample.to(self.device), @@ -414,27 +408,38 @@ class Xtts(BaseTTS): return diffusion_latent @torch.inference_mode() - def get_speaker_embedding(self, audio_path): - audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"]) - speaker_embedding = ( - self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True) - .unsqueeze(-1) - .to(self.device) - ) - return speaker_embedding - + def get_speaker_embedding(self, audio, sr): + audio_16k = torchaudio.functional.resample(audio, sr, 16000) + return self.hifigan_decoder.speaker_encoder.forward( + audio_16k.to(self.device), l2_norm=True + ).unsqueeze(-1).to(self.device) + + @torch.inference_mode() def 
From d4e08c8d6c84859541af0231b6b5bde4c52da375 Mon Sep 17 00:00:00 2001
From: WeberJulian
Date: Wed, 25 Oct 2023 09:34:08 +0200
Subject: [PATCH 3/4] Add features to get_conditioning_latents

---
 TTS/tts/models/xtts.py | 71 ++++++++++++++++++++++---------------
 1 file changed, 38 insertions(+), 33 deletions(-)

diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 1b80035a..aacdd6f1 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 import torch
 import torch.nn.functional as F
 import torchaudio
+import librosa
 from coqpit import Coqpit

 from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
@@ -34,12 +35,8 @@ """
     audio, sampling_rate = torchaudio.load(audiopath)

-    if len(audio.shape) > 1:
-        if audio.shape[0] < 5:
-            audio = audio[0]
-        else:
-            assert audio.shape[1] < 5
-            audio = audio[:, 0]
+    if audio.shape[0] > 1:
+        audio = audio.mean(0, keepdim=True)

     if sampling_rate != sr:
         resampler = torchaudio.transforms.Resample(sampling_rate, sr)
         audio = resampler(audio)
@@ -376,7 +373,7 @@ class Xtts(BaseTTS):
         return next(self.parameters()).device

     @torch.inference_mode()
-    def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
+    def get_gpt_cond_latents(self, audio, sr, length: int = 3):
         """Compute the conditioning latents for the GPT model from the given audio.

         Args:
@@ -384,24 +381,21 @@
             length (int): Length of the audio in seconds. Defaults to 3.
         """

-        audio = load_audio(audio_path)
-        audio = audio[:, : 22050 * length]
-        mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
+        audio_22k = torchaudio.functional.resample(audio, sr, 22050)
+        audio_22k = audio_22k[:, : 22050 * length]
+        mel = wav_to_mel_cloning(audio_22k, mel_norms=self.mel_stats.cpu())
         cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)

     @torch.inference_mode()
-    def get_diffusion_cond_latents(
-        self,
-        audio_path,
-    ):
+    def get_diffusion_cond_latents(self, audio, sr):
         from math import ceil

         diffusion_conds = []
         CHUNK_SIZE = 102400
-        audio = load_audio(audio_path, 24000)
-        for chunk in range(ceil(audio.shape[1] / CHUNK_SIZE)):
-            current_sample = audio[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
+        audio_24k = torchaudio.functional.resample(audio, sr, 24000)
+        for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)):
+            current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
             current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
             cond_mel = wav_to_univnet_mel(
                 current_sample.to(self.device),
@@ -414,27 +408,38 @@ class Xtts(BaseTTS):
         return diffusion_latent

     @torch.inference_mode()
-    def get_speaker_embedding(self, audio_path):
-        audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
-        speaker_embedding = (
-            self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True)
-            .unsqueeze(-1)
-            .to(self.device)
-        )
-        return speaker_embedding
-
+    def get_speaker_embedding(self, audio, sr):
+        audio_16k = torchaudio.functional.resample(audio, sr, 16000)
+        return self.hifigan_decoder.speaker_encoder.forward(
+            audio_16k.to(self.device), l2_norm=True
+        ).unsqueeze(-1).to(self.device)
+
+    @torch.inference_mode()
     def get_conditioning_latents(
         self,
         audio_path,
-        gpt_cond_len=3,
-    ):
+        gpt_cond_len=6,
+        max_ref_length=10,
+        librosa_trim_db=None,
+        sound_norm_refs=False,
+    ):
         speaker_embedding = None
         diffusion_cond_latents = None
+
+        audio, sr = torchaudio.load(audio_path)
+        audio = audio[:, : sr * max_ref_length].to(self.device)
+        if audio.shape[0] > 1:
+            audio = audio.mean(0, keepdim=True)
+        if sound_norm_refs:
+            audio = (audio / torch.abs(audio).max()) * 0.75
+        if librosa_trim_db is not None:
+            audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
+
         if self.args.use_hifigan or self.args.use_ne_hifigan:
-            speaker_embedding = self.get_speaker_embedding(audio_path)
+            speaker_embedding = self.get_speaker_embedding(audio, sr)
         else:
-            diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
-        gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len)  # [1, 1024, T]
+            diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
+        gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len)  # [1, 1024, T]
         return gpt_cond_latents, diffusion_cond_latents, speaker_embedding

     def synthesize(self, text, config, speaker_wav, language, **kwargs):
@@ -494,7 +499,7 @@ class Xtts(BaseTTS):
         repetition_penalty=2.0,
         top_k=50,
         top_p=0.85,
-        gpt_cond_len=4,
+        gpt_cond_len=6,
         do_sample=True,
         # Decoder inference
         decoder_iterations=100,
@@ -531,7 +536,7 @@ class Xtts(BaseTTS):
             (aka boring) outputs. Defaults to 0.8.

             gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
-                else the first `gpt_cond_len` secs is used. Defaults to 3 seconds.
+                else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.

             decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has
                 more chances to iteratively refine the output, which should theoretically mean a higher quality output.

From 1c988213598c3193e3e167ebe1ff58ecd5d9b221 Mon Sep 17 00:00:00 2001
From: WeberJulian
Date: Fri, 27 Oct 2023 22:27:18 +0200
Subject: [PATCH 4/4] Remove unused load_audio function

---
 TTS/tts/models/xtts.py | 24 ------------------------
 1 file changed, 24 deletions(-)

diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index aacdd6f1..60af2d1e 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -22,30 +22,6 @@ from TTS.utils.io import load_fsspec
 init_stream_support()


-def load_audio(audiopath, sr=22050):
-    """
-    Load an audio file from disk and resample it to the specified sampling rate.
-
-    Args:
-        audiopath (str): Path to the audio file.
-        sr (int): Target sampling rate.
-
-    Returns:
-        Tensor: Audio waveform tensor with shape (1, T), where T is the number of samples.
-    """
-    audio, sampling_rate = torchaudio.load(audiopath)
-
-    if audio.shape[0] > 1:
-        audio = audio.mean(0, keepdim=True)
-
-    if sampling_rate != sr:
-        resampler = torchaudio.transforms.Resample(sampling_rate, sr)
-        audio = resampler(audio)
-
-    audio = audio.clamp_(-1, 1)
-    return audio.unsqueeze(0)
-
-
 def wav_to_mel_cloning(
     wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
 ):
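PATCH 3/4 loads the reference audio once in `get_conditioning_latents` and hands the raw waveform to each consumer, which resamples it to its own rate (22.05 kHz for the GPT conditioning mel, 24 kHz for the diffusion mel, 16 kHz for the speaker encoder); the method can therefore cap the reference at `max_ref_length` seconds, optionally peak-normalize it to 0.75, and optionally trim silence before any latents are computed, and PATCH 4/4 then removes the `load_audio` helper left without callers. One caveat: `librosa.effects.trim` operates on NumPy arrays, so the `librosa_trim_db` option as written presumes a CPU/NumPy conversion (e.g. `audio.cpu().numpy()`) or a tensor-aware substitute. Below is a call-site sketch of the extended signature; model loading follows the usual XTTS checkpoint flow, and all paths are placeholders.

import torch
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("model_dir/config.json")                   # placeholder path
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="model_dir/")  # placeholder path
model.cuda()

with torch.inference_mode():
    gpt_cond, diffusion_cond, spk_emb = model.get_conditioning_latents(
        audio_path="reference.wav",  # placeholder reference clip
        gpt_cond_len=6,              # new default: first 6 s of the reference
        max_ref_length=10,           # hard cap on the reference length, in seconds
        librosa_trim_db=None,        # set e.g. 60 to trim silence (see caveat above)
        sound_norm_refs=False,       # True peak-normalizes the reference to 0.75
    )
# gpt_cond has shape [1, 1024, T]; spk_emb is set when (ne_)hifigan is enabled,
# diffusion_cond otherwise.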