Merge pull request #3207 from coqui-ai/update_xtts_cloning

Update XTTS cloning
Eren Gölge 2023-11-13 14:32:43 +01:00 committed by GitHub
commit f32a465711
8 changed files with 115 additions and 51 deletions

View File

@@ -82,7 +82,6 @@ class CS_API:
         },
     }

-
     SUPPORTED_LANGUAGES = ["en", "es", "de", "fr", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja"]

     def __init__(self, api_token=None, model="XTTS"):
@@ -308,7 +307,11 @@ if __name__ == "__main__":
     print(api.list_speakers_as_tts_models())

     ts = time.time()
-    wav, sr = api.tts("It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name)
+    wav, sr = api.tts(
+        "It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name
+    )
     print(f" [i] XTTS took {time.time() - ts:.2f}s")

-    filepath = api.tts_to_file(text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav")
+    filepath = api.tts_to_file(
+        text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav"
+    )

View File

@@ -43,7 +43,12 @@ class XttsConfig(BaseTTSConfig):
             Defaults to `16`.

         gpt_cond_len (int):
-            Secs audio to be used as conditioning for the autoregressive model. Defaults to `3`.
+            Secs audio to be used as conditioning for the autoregressive model. Defaults to `12`.
+
+        gpt_cond_chunk_len (int):
+            Audio chunk size in secs. Audio is split into chunks and latents are extracted for each chunk. Then the
+            latents are averaged. Chunking improves the stability. It must be <= gpt_cond_len.
+            If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to `4`.

         max_ref_len (int):
             Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`.
@@ -95,6 +100,7 @@ class XttsConfig(BaseTTSConfig):
     num_gpt_outputs: int = 1

     # cloning
-    gpt_cond_len: int = 3
+    gpt_cond_len: int = 12
+    gpt_cond_chunk_len: int = 4
     max_ref_len: int = 10
     sound_norm_refs: bool = False

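For reference, a minimal usage sketch of the new cloning fields (not part of the diff; it assumes the TTS package is installed and uses the module path shown in this repo):

from TTS.tts.configs.xtts_config import XttsConfig

config = XttsConfig()
config.gpt_cond_len = 12        # seconds of reference audio used for the GPT conditioning latents
config.gpt_cond_chunk_len = 4   # chunk size in seconds; must be <= gpt_cond_len (equal means no chunking)
config.max_ref_len = 10         # seconds of reference audio used for the decoder
config.sound_norm_refs = False  # optionally peak-normalize the reference audio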
View File

@@ -562,15 +562,21 @@ class DPM_Solver:
         if order == 3:
             K = steps // 3 + 1
             if steps % 3 == 0:
-                orders = [3,] * (
+                orders = [
+                    3,
+                ] * (
                     K - 2
                 ) + [2, 1]
             elif steps % 3 == 1:
-                orders = [3,] * (
+                orders = [
+                    3,
+                ] * (
                     K - 1
                 ) + [1]
             else:
-                orders = [3,] * (
+                orders = [
+                    3,
+                ] * (
                     K - 1
                 ) + [2]
         elif order == 2:
@@ -581,7 +587,9 @@ class DPM_Solver:
                 ] * K
             else:
                 K = steps // 2 + 1
-                orders = [2,] * (
+                orders = [
+                    2,
+                ] * (
                     K - 1
                 ) + [1]
         elif order == 1:
@@ -1440,7 +1448,10 @@ class DPM_Solver:
                 model_prev_list[-1] = self.model_fn(x, t)
         elif method in ["singlestep", "singlestep_fixed"]:
             if method == "singlestep":
-                (timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver(
+                (
+                    timesteps_outer,
+                    orders,
+                ) = self.get_orders_and_timesteps_for_singlestep_solver(
                     steps=steps,
                     order=order,
                     skip_type=skip_type,
@@ -1548,4 +1559,4 @@ def expand_dims(v, dims):
     Returns:
         a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
     """
-    return v[(...,) + (None,) * (dims - 1)]
+    return v[(...,) + (None,) * (dims - 1)]

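The `expand_dims` helper above relies on a small indexing trick; an illustrative sketch (not from the diff) of what it produces:

import torch

v = torch.ones(8)                   # shape [N]
expanded = v[(...,) + (None,) * 3]  # equivalent to v[..., None, None, None]
print(expanded.shape)               # torch.Size([8, 1, 1, 1]), i.e. dims == 4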
View File

@@ -1,6 +1,7 @@
 import json
 import os
 import re
+from functools import cached_property

 import pypinyin
 import torch
@@ -8,7 +9,6 @@ from hangul_romanize import Transliter
 from hangul_romanize.rule import academic
 from num2words import num2words
 from tokenizers import Tokenizer
-from functools import cached_property

 from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words

@@ -560,19 +560,22 @@ class VoiceBpeTokenizer:
     @cached_property
     def katsu(self):
         import cutlet
+
         return cutlet.Cutlet()

     def check_input_length(self, txt, lang):
         limit = self.char_limits.get(lang, 250)
         if len(txt) > limit:
-            print(f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio.")
+            print(
+                f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
+            )

     def preprocess_text(self, txt, lang):
         if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "zh-cn"}:
             txt = multilingual_cleaners(txt, lang)
             if lang in {"zh", "zh-cn"}:
                 txt = chinese_transliterate(txt)
         elif lang == "ja":
             txt = japanese_cleaners(txt, self.katsu)
         elif lang == "ko":
             txt = korean_cleaners(txt)

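The `katsu` property above lazily imports cutlet on first access and caches the resulting converter on the instance. A standalone sketch of the same pattern (the class and variable names here are illustrative, not from the repo):

from functools import cached_property

class LazyRomanizer:
    @cached_property
    def katsu(self):
        import cutlet  # imported only when Japanese text is actually processed

        return cutlet.Cutlet()  # built once, then reused on every later access

# romanizer = LazyRomanizer()
# print(romanizer.katsu.romaji("こんにちは"))  # requires the optional `cutlet` dependency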
View File

@@ -5,6 +5,7 @@ import sys
 import torch
 import torch.nn.functional as F
 import torch.utils.data
+from TTS.tts.models.xtts import load_audio

 torch.set_num_threads(1)

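The trainer dataset now reuses `load_audio` from the XTTS model module. Based on how it is called elsewhere in this PR (`load_audio(file_path, load_sr)` followed by a sample-count slice), a hedged usage sketch; the file name is a placeholder:

from TTS.tts.models.xtts import load_audio

load_sr = 22050
audio = load_audio("reference.wav", load_sr)  # waveform tensor shaped [channels, samples]
audio = audio[:, : load_sr * 30]              # keep at most 30 s, mirroring get_conditioning_latents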
View File

@@ -255,39 +255,57 @@ class Xtts(BaseTTS):
         return next(self.parameters()).device

     @torch.inference_mode()
-    def get_gpt_cond_latents(self, audio, sr, length: int = 3):
+    def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
         """Compute the conditioning latents for the GPT model from the given audio.

         Args:
             audio (tensor): audio tensor.
             sr (int): Sample rate of the audio.
-            length (int): Length of the audio in seconds. Defaults to 3.
+            length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30.
+            chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio
+                is being used without chunking. It must be < `length`. Defaults to 6.
         """
         if sr != 22050:
             audio = torchaudio.functional.resample(audio, sr, 22050)
-        audio = audio[:, : 22050 * length]
+        if length > 0:
+            audio = audio[:, : 22050 * length]
         if self.args.gpt_use_perceiver_resampler:
-            n_fft = 2048
-            hop_length = 256
-            win_length = 1024
+            style_embs = []
+            for i in range(0, audio.shape[1], 22050 * chunk_length):
+                audio_chunk = audio[:, i : i + 22050 * chunk_length]
+                mel_chunk = wav_to_mel_cloning(
+                    audio_chunk,
+                    mel_norms=self.mel_stats.cpu(),
+                    n_fft=2048,
+                    hop_length=256,
+                    win_length=1024,
+                    power=2,
+                    normalized=False,
+                    sample_rate=22050,
+                    f_min=0,
+                    f_max=8000,
+                    n_mels=80,
+                )
+                style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None)
+                style_embs.append(style_emb)
+
+            # mean style embedding
+            cond_latent = torch.stack(style_embs).mean(dim=0)
         else:
-            n_fft = 4096
-            hop_length = 1024
-            win_length = 4096
-        mel = wav_to_mel_cloning(
-            audio,
-            mel_norms=self.mel_stats.cpu(),
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            power=2,
-            normalized=False,
-            sample_rate=22050,
-            f_min=0,
-            f_max=8000,
-            n_mels=80,
-        )
-        cond_latent = self.gpt.get_style_emb(mel.to(self.device))
+            mel = wav_to_mel_cloning(
+                audio,
+                mel_norms=self.mel_stats.cpu(),
+                n_fft=4096,
+                hop_length=1024,
+                win_length=4096,
+                power=2,
+                normalized=False,
+                sample_rate=22050,
+                f_min=0,
+                f_max=8000,
+                n_mels=80,
+            )
+            cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)

     @torch.inference_mode()
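In the perceiver-resampler branch above, the reference audio is split into `chunk_length`-second pieces, a style embedding is computed per chunk, and the embeddings are averaged. A standalone sketch of that chunk-and-average loop; `embed` is a placeholder for the mel-spectrogram plus `gpt.get_style_emb` pipeline, not the real API:

import torch

def mean_style_latent(audio, embed, sr=22050, chunk_length=6):
    """audio: [1, T] waveform; embed: callable mapping a waveform chunk to a latent tensor."""
    style_embs = []
    for i in range(0, audio.shape[1], sr * chunk_length):
        audio_chunk = audio[:, i : i + sr * chunk_length]
        style_embs.append(embed(audio_chunk))
    # average the per-chunk latents, mirroring torch.stack(style_embs).mean(dim=0) above
    return torch.stack(style_embs).mean(dim=0)

# e.g. with a dummy embedder: mean_style_latent(torch.randn(1, 22050 * 30), lambda chunk: chunk.mean(dim=1))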
@@ -323,12 +341,24 @@ class Xtts(BaseTTS):
     def get_conditioning_latents(
         self,
         audio_path,
+        max_ref_length=30,
         gpt_cond_len=6,
-        max_ref_length=10,
+        gpt_cond_chunk_len=6,
         librosa_trim_db=None,
         sound_norm_refs=False,
-        load_sr=24000,
+        load_sr=22050,
     ):
+        """Get the conditioning latents for the GPT model from the given audio.
+
+        Args:
+            audio_path (str or List[str]): Path to reference audio file(s).
+            max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30.
+            gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6.
+            gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= gpt_conf_len. Defaults to 6.
+            librosa_trim_db (int, optional): Trim the audio using this value. If None, not trimming. Defaults to None.
+            sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False.
+            load_sr (int, optional): Sample rate to load the audio. Defaults to 24000.
+        """
         # deal with multiples references
         if not isinstance(audio_path, list):
             audio_paths = [audio_path]
@@ -339,24 +369,24 @@ class Xtts(BaseTTS):
         audios = []
         speaker_embedding = None
         for file_path in audio_paths:
-            # load the audio in 24khz to avoid issued with multiple sr references
             audio = load_audio(file_path, load_sr)
             audio = audio[:, : load_sr * max_ref_length].to(self.device)
-            if audio.shape[0] > 1:
-                audio = audio.mean(0, keepdim=True)
             if sound_norm_refs:
                 audio = (audio / torch.abs(audio).max()) * 0.75
             if librosa_trim_db is not None:
                 audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

+            # compute latents for the decoder
             speaker_embedding = self.get_speaker_embedding(audio, load_sr)
             speaker_embeddings.append(speaker_embedding)

             audios.append(audio)

-        # use a merge of all references for gpt cond latents
+        # merge all the audios and compute the latents for the gpt
         full_audio = torch.cat(audios, dim=-1)
-        gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len)  # [1, 1024, T]
+        gpt_cond_latents = self.get_gpt_cond_latents(
+            full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len
+        )  # [1, 1024, T]

         if speaker_embeddings:
             speaker_embedding = torch.stack(speaker_embeddings)
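End to end, the new arguments are passed when computing the conditioning latents before inference. A hedged usage sketch; the checkpoint paths and reference file names are placeholders, and the loading calls follow the documented XTTS usage pattern rather than anything added in this PR:

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)

gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=["speaker_ref_1.wav", "speaker_ref_2.wav"],  # one or more references, merged before encoding
    max_ref_length=30,
    gpt_cond_len=6,
    gpt_cond_chunk_len=6,  # equal to gpt_cond_len -> effectively no chunking
)
out = model.inference("Hello world!", "en", gpt_cond_latent, speaker_embedding)  # dict with a "wav" array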
@@ -397,6 +427,7 @@ class Xtts(BaseTTS):
             "top_k": config.top_k,
             "top_p": config.top_p,
             "gpt_cond_len": config.gpt_cond_len,
+            "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
             "max_ref_len": config.max_ref_len,
             "sound_norm_refs": config.sound_norm_refs,
         }
@@ -417,7 +448,8 @@ class Xtts(BaseTTS):
         top_p=0.85,
         do_sample=True,
         # Cloning
-        gpt_cond_len=6,
+        gpt_cond_len=30,
+        gpt_cond_chunk_len=6,
         max_ref_len=10,
         sound_norm_refs=False,
         **hf_generate_kwargs,
@@ -448,7 +480,10 @@ class Xtts(BaseTTS):
                 (aka boring) outputs. Defaults to 0.8.

             gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
-                else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
+                else the first `gpt_cond_len` secs is used. Defaults to 30 seconds.
+
+            gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`.
+                If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to 6 seconds.

             hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
                 transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
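With the new full_inference defaults, the gpt_cond_len=30 seconds of conditioning audio are processed in gpt_cond_chunk_len=6 second chunks, so five per-chunk latents are averaged; when the two values are equal the loop runs once. Illustrative arithmetic only:

import math

sr = 22050
gpt_cond_len, gpt_cond_chunk_len = 30, 6  # new full_inference defaults
print(math.ceil((sr * gpt_cond_len) / (sr * gpt_cond_chunk_len)))  # 5 chunks, latents averaged
print(math.ceil((sr * 6) / (sr * 6)))                              # 1 -> no chunking when the values match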
@@ -461,6 +496,7 @@ class Xtts(BaseTTS):
         (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
             audio_path=ref_audio_path,
             gpt_cond_len=gpt_cond_len,
+            gpt_cond_chunk_len=gpt_cond_chunk_len,
             max_ref_length=max_ref_len,
             sound_norm_refs=sound_norm_refs,
         )
@@ -566,7 +602,7 @@ class Xtts(BaseTTS):
             if overlap_len > len(wav_chunk):
                 # wav_chunk is smaller than overlap_len, pass on last wav_gen
                 if wav_gen_prev is not None:
-                    wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):]
+                    wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :]
                 else:
                     # not expecting will hit here as problem happens on last chunk
                     wav_chunk = wav_gen[-overlap_len:]
@@ -576,7 +612,7 @@ class Xtts(BaseTTS):
                 crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
                 wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
                 wav_chunk[:overlap_len] += crossfade_wav

-        wav_overlap = wav_gen[-overlap_len:]
+        wav_overlap = wav_gen[-overlap_len:]
         wav_gen_prev = wav_gen
         return wav_chunk, wav_gen_prev, wav_overlap

View File

@@ -60,7 +60,9 @@ XTTS_CHECKPOINT = None  # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_s

 # Training sentences generations
-SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"]  # speaker reference to be used in training test sentences
+SPEAKER_REFERENCE = [
+    "tests/data/ljspeech/wavs/LJ001-0002.wav"
+]  # speaker reference to be used in training test sentences
 LANGUAGE = config_dataset.language

View File

@@ -58,7 +58,9 @@ XTTS_CHECKPOINT = None  # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_s

 # Training sentences generations
-SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"]  # speaker reference to be used in training test sentences
+SPEAKER_REFERENCE = [
+    "tests/data/ljspeech/wavs/LJ001-0002.wav"
+]  # speaker reference to be used in training test sentences
 LANGUAGE = config_dataset.language
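Since get_conditioning_latents now merges every reference before computing the GPT latents, a recipe could also list several clips; a hypothetical variant (the second file name is invented for illustration):

SPEAKER_REFERENCE = [
    "tests/data/ljspeech/wavs/LJ001-0002.wav",
    "tests/data/ljspeech/wavs/LJ001-0003.wav",  # invented second reference, for illustration only
]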