Implement chunking gpt_cond

Eren Gölge 2023-11-13 13:00:08 +01:00
parent 6f1cba2f81
commit a16360af85
2 changed files with 78 additions and 33 deletions

TTS/tts/configs/xtts_config.py

@@ -43,7 +43,12 @@ class XttsConfig(BaseTTSConfig):
             Defaults to `16`.

         gpt_cond_len (int):
-            Secs audio to be used as conditioning for the autoregressive model. Defaults to `3`.
+            Secs audio to be used as conditioning for the autoregressive model. Defaults to `12`.
+
+        gpt_cond_chunk_len (int):
+            Audio chunk size in secs. Audio is split into chunks and latents are extracted for each chunk. Then the
+            latents are averaged. Chunking improves the stability. It must be <= gpt_cond_len.
+            If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to `4`.

         max_ref_len (int):
             Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`.
@@ -95,6 +100,7 @@ class XttsConfig(BaseTTSConfig):
     num_gpt_outputs: int = 1

     # cloning
-    gpt_cond_len: int = 3
+    gpt_cond_len: int = 12
+    gpt_cond_chunk_len: int = 4
     max_ref_len: int = 10
     sound_norm_refs: bool = False
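As a quick orientation for the config change above, here is a minimal sketch of how the two cloning fields interact when set from user code. The import path is an assumption about where XttsConfig normally lives; only the field names and their new defaults come from the diff itself.

# Minimal sketch of the new cloning options (import path is an assumption).
from TTS.tts.configs.xtts_config import XttsConfig

config = XttsConfig()
config.gpt_cond_len = 12       # seconds of reference audio used for GPT conditioning
config.gpt_cond_chunk_len = 4  # split those seconds into 4 s chunks, average the latents

# Chunking is a no-op when the chunk length equals gpt_cond_len.
assert config.gpt_cond_chunk_len <= config.gpt_cond_len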

TTS/tts/models/xtts.py

@@ -255,31 +255,49 @@ class Xtts(BaseTTS):
         return next(self.parameters()).device

     @torch.inference_mode()
-    def get_gpt_cond_latents(self, audio, sr, length: int = 3):
+    def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
         """Compute the conditioning latents for the GPT model from the given audio.

         Args:
             audio (tensor): audio tensor.
             sr (int): Sample rate of the audio.
-            length (int): Length of the audio in seconds. Defaults to 3.
+            length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30.
+            chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio
+                is being used without chunking. It must be <= `length`. Defaults to 6.
         """
         if sr != 22050:
             audio = torchaudio.functional.resample(audio, sr, 22050)
-        audio = audio[:, : 22050 * length]
+        if length > 0:
+            audio = audio[:, : 22050 * length]
         if self.args.gpt_use_perceiver_resampler:
-            n_fft = 2048
-            hop_length = 256
-            win_length = 1024
+            style_embs = []
+            for i in range(0, audio.shape[1], 22050 * chunk_length):
+                audio_chunk = audio[:, i : i + 22050 * chunk_length]
+                mel_chunk = wav_to_mel_cloning(
+                    audio_chunk,
+                    mel_norms=self.mel_stats.cpu(),
+                    n_fft=2048,
+                    hop_length=256,
+                    win_length=1024,
+                    power=2,
+                    normalized=False,
+                    sample_rate=22050,
+                    f_min=0,
+                    f_max=8000,
+                    n_mels=80,
+                )
+                style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None)
+                style_embs.append(style_emb)
+
+            # mean style embedding
+            cond_latent = torch.stack(style_embs).mean(dim=0)
         else:
-            n_fft = 4096
-            hop_length = 1024
-            win_length = 4096
-        mel = wav_to_mel_cloning(
-            audio,
-            mel_norms=self.mel_stats.cpu(),
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            power=2,
-            normalized=False,
-            sample_rate=22050,
+            mel = wav_to_mel_cloning(
+                audio,
+                mel_norms=self.mel_stats.cpu(),
+                n_fft=4096,
+                hop_length=1024,
+                win_length=4096,
+                power=2,
+                normalized=False,
+                sample_rate=22050,
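The heart of this hunk is the chunk-and-average scheme in the perceiver-resampler branch: the reference audio is cut into `chunk_length`-second pieces, a style embedding is computed per piece, and the embeddings are averaged. Below is a small standalone sketch of that idea; the dummy encoder merely stands in for the mel extraction plus `gpt.get_style_emb` call and is not part of the library.

import torch

SR = 22050  # conditioning sample rate used in the code above

def chunked_cond_latent(audio: torch.Tensor, encoder, chunk_secs: int = 6) -> torch.Tensor:
    """Average per-chunk latents, mirroring the loop in get_gpt_cond_latents.

    audio: [1, T] waveform at SR.
    encoder: any callable mapping a waveform chunk to a fixed-shape latent
             (stand-in for mel extraction + style embedding).
    """
    latents = []
    for start in range(0, audio.shape[1], SR * chunk_secs):
        chunk = audio[:, start : start + SR * chunk_secs]
        latents.append(encoder(chunk))
    # averaging over chunks is what stabilizes conditioning on long references
    return torch.stack(latents).mean(dim=0)

# Toy usage with a dummy encoder that simply pools over time.
dummy_encoder = lambda wav: wav.mean(dim=-1, keepdim=True)
latent = chunked_cond_latent(torch.randn(1, SR * 30), dummy_encoder, chunk_secs=6)
print(latent.shape)  # torch.Size([1, 1])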
@@ -323,12 +341,24 @@ class Xtts(BaseTTS):
     def get_conditioning_latents(
         self,
         audio_path,
+        max_ref_length=30,
         gpt_cond_len=6,
-        max_ref_length=10,
+        gpt_cond_chunk_len=6,
         librosa_trim_db=None,
         sound_norm_refs=False,
-        load_sr=24000,
+        load_sr=22050,
     ):
+        """Get the conditioning latents for the GPT model from the given audio.
+
+        Args:
+            audio_path (str or List[str]): Path to reference audio file(s).
+            max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30.
+            gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6.
+            gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= gpt_cond_len. Defaults to 6.
+            librosa_trim_db (int, optional): Trim the audio using this value. If None, not trimming. Defaults to None.
+            sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False.
+            load_sr (int, optional): Sample rate to load the audio. Defaults to 22050.
+        """
         # deal with multiples references
         if not isinstance(audio_path, list):
             audio_paths = [audio_path]
@@ -349,14 +379,17 @@ class Xtts(BaseTTS):
             if librosa_trim_db is not None:
                 audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

+            # compute latents for the decoder
             speaker_embedding = self.get_speaker_embedding(audio, load_sr)
             speaker_embeddings.append(speaker_embedding)

             audios.append(audio)

-        # use a merge of all references for gpt cond latents
+        # merge all the audios and compute the latents for the gpt
         full_audio = torch.cat(audios, dim=-1)
-        gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len)  # [1, 1024, T]
+        gpt_cond_latents = self.get_gpt_cond_latents(
+            full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len
+        )  # [1, 1024, T]

         if speaker_embeddings:
             speaker_embedding = torch.stack(speaker_embeddings)
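For reference, a rough usage sketch of the updated get_conditioning_latents signature. It assumes `model` is an already-loaded Xtts instance and uses placeholder wav paths; the argument names, defaults, and the [1, 1024, T] latent shape are taken from the diff above.

# Rough usage sketch; `model` is assumed to be a loaded Xtts instance.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=["ref_1.wav", "ref_2.wav"],  # multiple references are concatenated
    max_ref_length=30,
    gpt_cond_len=6,
    gpt_cond_chunk_len=6,  # equal to gpt_cond_len, so no chunking here
)
# gpt_cond_latent: [1, 1024, T] GPT conditioning; speaker_embedding feeds the decoder.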
@@ -397,6 +430,7 @@ class Xtts(BaseTTS):
             "top_k": config.top_k,
             "top_p": config.top_p,
             "gpt_cond_len": config.gpt_cond_len,
+            "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
             "max_ref_len": config.max_ref_len,
             "sound_norm_refs": config.sound_norm_refs,
         }
@@ -417,7 +451,8 @@ class Xtts(BaseTTS):
         top_p=0.85,
         do_sample=True,
         # Cloning
-        gpt_cond_len=6,
+        gpt_cond_len=30,
+        gpt_cond_chunk_len=6,
         max_ref_len=10,
         sound_norm_refs=False,
         **hf_generate_kwargs,
@@ -448,7 +483,10 @@ class Xtts(BaseTTS):
                 (aka boring) outputs. Defaults to 0.8.

             gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
-                else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
+                else the first `gpt_cond_len` secs is used. Defaults to 30 seconds.
+
+            gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`.
+                If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to 6 seconds.

             hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
                 transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
@@ -461,6 +499,7 @@ class Xtts(BaseTTS):
         (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
             audio_path=ref_audio_path,
             gpt_cond_len=gpt_cond_len,
+            gpt_cond_chunk_len=gpt_cond_chunk_len,
             max_ref_length=max_ref_len,
             sound_norm_refs=sound_norm_refs,
         )
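Finally, an end-to-end sketch of how the new kwargs reach the cloning entry point whose defaults changed above. The method name full_inference, the loading calls, and the file paths are assumptions about the surrounding XTTS API; only gpt_cond_len and gpt_cond_chunk_len come from this commit.

import torch
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Paths are placeholders; point them at a downloaded XTTS checkpoint.
config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)

with torch.inference_mode():
    out = model.full_inference(
        "Hello world.",
        ref_audio_path="reference.wav",
        language="en",
        gpt_cond_len=30,       # new default: up to 30 s of reference audio
        gpt_cond_chunk_len=6,  # latents averaged over 6 s chunks
    )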