mirror of https://github.com/coqui-ai/TTS.git
Implement chunking gpt_cond
parent 6f1cba2f81
commit a16360af85
TTS/tts/configs/xtts_config.py

@@ -43,7 +43,12 @@ class XttsConfig(BaseTTSConfig):
             Defaults to `16`.

         gpt_cond_len (int):
-            Secs audio to be used as conditioning for the autoregressive model. Defaults to `3`.
+            Secs of audio to be used as conditioning for the autoregressive model. Defaults to `12`.
+
+        gpt_cond_chunk_len (int):
+            Audio chunk size in secs. The audio is split into chunks and latents are extracted for each chunk. Then
+            the latents are averaged. Chunking improves stability. It must be <= `gpt_cond_len`.
+            If `gpt_cond_len == gpt_cond_chunk_len`, no chunking is applied. Defaults to `4`.

         max_ref_len (int):
             Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`.
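The chunk-and-average scheme described in the new docstring, as a minimal sketch (here `extract_latents` is a hypothetical stand-in for the model's conditioning encoder, not a function from the repo):

import torch

def chunked_cond_latents(audio, sr, cond_len=12, chunk_len=4):
    # keep at most `cond_len` seconds of the reference
    audio = audio[:, : sr * cond_len]
    latents = []
    # encode up to `chunk_len` seconds at a time
    for i in range(0, audio.shape[1], sr * chunk_len):
        chunk = audio[:, i : i + sr * chunk_len]
        latents.append(extract_latents(chunk))  # hypothetical encoder call
    # averaging the per-chunk latents is what stabilizes the conditioning
    return torch.stack(latents).mean(dim=0)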
@@ -95,6 +100,7 @@ class XttsConfig(BaseTTSConfig):
     num_gpt_outputs: int = 1

     # cloning
-    gpt_cond_len: int = 3
+    gpt_cond_len: int = 12
+    gpt_cond_chunk_len: int = 4
     max_ref_len: int = 10
     sound_norm_refs: bool = False
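The new defaults can be overridden when building the config; a small example, assuming the usual import path for `XttsConfig`:

from TTS.tts.configs.xtts_config import XttsConfig

# 12 s of conditioning audio, encoded and averaged in 4 s chunks;
# setting gpt_cond_chunk_len == gpt_cond_len disables chunking
config = XttsConfig(gpt_cond_len=12, gpt_cond_chunk_len=4)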
TTS/tts/models/xtts.py

@@ -255,39 +255,57 @@ class Xtts(BaseTTS):
         return next(self.parameters()).device

     @torch.inference_mode()
-    def get_gpt_cond_latents(self, audio, sr, length: int = 3):
+    def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
         """Compute the conditioning latents for the GPT model from the given audio.

         Args:
             audio (tensor): audio tensor.
             sr (int): Sample rate of the audio.
-            length (int): Length of the audio in seconds. Defaults to 3.
+            length (int): Length of the audio in seconds. If < 0, the whole audio is used. Defaults to 30.
+            chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio
+                is used without chunking. It must be <= `length`. Defaults to 6.
         """
         if sr != 22050:
             audio = torchaudio.functional.resample(audio, sr, 22050)
-        audio = audio[:, : 22050 * length]
+        if length > 0:
+            audio = audio[:, : 22050 * length]
         if self.args.gpt_use_perceiver_resampler:
-            n_fft = 2048
-            hop_length = 256
-            win_length = 1024
+            style_embs = []
+            for i in range(0, audio.shape[1], 22050 * chunk_length):
+                audio_chunk = audio[:, i : i + 22050 * chunk_length]
+                mel_chunk = wav_to_mel_cloning(
+                    audio_chunk,
+                    mel_norms=self.mel_stats.cpu(),
+                    n_fft=2048,
+                    hop_length=256,
+                    win_length=1024,
+                    power=2,
+                    normalized=False,
+                    sample_rate=22050,
+                    f_min=0,
+                    f_max=8000,
+                    n_mels=80,
+                )
+                style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None)
+                style_embs.append(style_emb)
+
+            # mean style embedding
+            cond_latent = torch.stack(style_embs).mean(dim=0)
         else:
-            n_fft = 4096
-            hop_length = 1024
-            win_length = 4096
-        mel = wav_to_mel_cloning(
-            audio,
-            mel_norms=self.mel_stats.cpu(),
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            power=2,
-            normalized=False,
-            sample_rate=22050,
-            f_min=0,
-            f_max=8000,
-            n_mels=80,
-        )
-        cond_latent = self.gpt.get_style_emb(mel.to(self.device))
+            mel = wav_to_mel_cloning(
+                audio,
+                mel_norms=self.mel_stats.cpu(),
+                n_fft=4096,
+                hop_length=1024,
+                win_length=4096,
+                power=2,
+                normalized=False,
+                sample_rate=22050,
+                f_min=0,
+                f_max=8000,
+                n_mels=80,
+            )
+            cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)

     @torch.inference_mode()
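A usage sketch for the new signature (assumes a loaded `Xtts` model and a reference loaded with `torchaudio`; the method resamples internally if the rate is not 22050 Hz):

import torchaudio

wav, sr = torchaudio.load("reference.wav")  # [channels, samples]
# take up to 30 s of conditioning audio and encode it in 6 s chunks,
# so at most 5 chunk latents are averaged
cond_latents = model.get_gpt_cond_latents(wav, sr, length=30, chunk_length=6)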
@@ -323,12 +341,24 @@ class Xtts(BaseTTS):
     def get_conditioning_latents(
         self,
         audio_path,
-        max_ref_length=30,
-        gpt_cond_len=6,
+        max_ref_length=10,
+        gpt_cond_len=6,
+        gpt_cond_chunk_len=6,
         librosa_trim_db=None,
         sound_norm_refs=False,
-        load_sr=24000,
+        load_sr=22050,
     ):
+        """Get the conditioning latents for the GPT model from the given audio.
+
+        Args:
+            audio_path (str or List[str]): Path to reference audio file(s).
+            max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 10.
+            gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6.
+            gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= `gpt_cond_len`. Defaults to 6.
+            librosa_trim_db (int, optional): Trim the audio using this value. If None, no trimming. Defaults to None.
+            sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False.
+            load_sr (int, optional): Sample rate to load the audio. Defaults to 22050.
+        """
         # deal with multiple references
         if not isinstance(audio_path, list):
             audio_paths = [audio_path]
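Since `audio_path` accepts a single path or a list, a call exercising the new chunking parameter might look like this (file names are placeholders):

gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=["speaker_a_1.wav", "speaker_a_2.wav"],  # one or more references
    gpt_cond_len=6,
    gpt_cond_chunk_len=6,  # equal to gpt_cond_len, so effectively no chunking
    max_ref_length=10,
)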
@@ -349,14 +379,17 @@ class Xtts(BaseTTS):
             if librosa_trim_db is not None:
                 audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

+            # compute latents for the decoder
             speaker_embedding = self.get_speaker_embedding(audio, load_sr)
             speaker_embeddings.append(speaker_embedding)

             audios.append(audio)

-        # use a merge of all references for gpt cond latents
+        # merge all the audios and compute the latents for the gpt
         full_audio = torch.cat(audios, dim=-1)
-        gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len)  # [1, 1024, T]
+        gpt_cond_latents = self.get_gpt_cond_latents(
+            full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len
+        )  # [1, 1024, T]

         if speaker_embeddings:
             speaker_embedding = torch.stack(speaker_embeddings)
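The merge is a plain concatenation along the time axis, so the chunking loop then slides over all references as one continuous signal. A toy illustration of the shapes involved:

import torch

a = torch.zeros(1, 22050 * 4)  # 4 s reference at 22050 Hz
b = torch.zeros(1, 22050 * 5)  # 5 s reference
full = torch.cat([a, b], dim=-1)
print(full.shape)  # torch.Size([1, 198450]), i.e. 9 s of audio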
@@ -397,6 +430,7 @@ class Xtts(BaseTTS):
             "top_k": config.top_k,
             "top_p": config.top_p,
             "gpt_cond_len": config.gpt_cond_len,
+            "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
             "max_ref_len": config.max_ref_len,
             "sound_norm_refs": config.sound_norm_refs,
         }
@@ -417,7 +451,8 @@ class Xtts(BaseTTS):
         top_p=0.85,
         do_sample=True,
         # Cloning
-        gpt_cond_len=6,
+        gpt_cond_len=30,
+        gpt_cond_chunk_len=6,
         max_ref_len=10,
         sound_norm_refs=False,
         **hf_generate_kwargs,
@@ -448,7 +483,10 @@ class Xtts(BaseTTS):
                 (aka boring) outputs. Defaults to 0.8.

             gpt_cond_len: (int) Length of the audio used for cloning. If the audio is shorter, the whole audio is
-                used; otherwise the first `gpt_cond_len` secs are used. Defaults to 6 seconds.
+                used; otherwise the first `gpt_cond_len` secs are used. Defaults to 30 seconds.
+
+            gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`.
+                If `gpt_cond_len == gpt_cond_chunk_len`, no chunking is applied. Defaults to 6 seconds.

             hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
                 transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
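Putting the new arguments together, an end-to-end call might look like the following sketch (`text`, `ref_audio_path`, and `language` follow the method's existing leading parameters, which this commit does not change):

out = model.full_inference(
    text="Chunked conditioning makes cloning more stable.",
    ref_audio_path="reference.wav",
    language="en",
    gpt_cond_len=30,       # use up to 30 s of the reference
    gpt_cond_chunk_len=6,  # encode it in 6 s chunks and average the latents
)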
@@ -461,6 +499,7 @@ class Xtts(BaseTTS):
         (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
             audio_path=ref_audio_path,
             gpt_cond_len=gpt_cond_len,
+            gpt_cond_chunk_len=gpt_cond_chunk_len,
             max_ref_length=max_ref_len,
             sound_norm_refs=sound_norm_refs,
         )
@@ -566,7 +605,7 @@ class Xtts(BaseTTS):
                 if overlap_len > len(wav_chunk):
                     # wav_chunk is smaller than overlap_len, pass on last wav_gen
                     if wav_gen_prev is not None:
-                        wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):]
+                        wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :]
                     else:
                         # not expecting to hit here, as the problem happens on the last chunk
                         wav_chunk = wav_gen[-overlap_len:]
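The functional content here is unchanged (the edit is black-style spacing in the slice); for context, a toy version of the overlap bookkeeping with made-up lengths:

import torch

wav_gen_prev = torch.zeros(1000)  # samples emitted up to the previous step
wav_gen = torch.zeros(1300)       # cumulative samples after the current step
overlap_len = 64
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :]
print(wav_chunk.shape)  # torch.Size([364]): 300 new samples plus 64 re-emitted for the overlap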