Merge pull request #3149 from coqui-ai/fixup_xtts_v2

Bug fixes and support for multiple speaker references in XTTS inference
Eren Gölge 2023-11-07 10:36:20 +01:00 committed by GitHub
commit 5e992d8704
10 changed files with 129 additions and 199 deletions
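
For context, the documentation changes in this commit show how the new multi-reference cloning is meant to be used from the Python API. A minimal sketch based on those docs (the paths are placeholders, not files from the repository):

```python
from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)

# speaker_wav now also accepts a list of reference wavs; a single path still works
tts.tts_to_file(
    text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
    file_path="output.wav",
    speaker_wav=["/path/to/target/speaker.wav", "/path/to/target/speaker_2.wav"],
    language="en",
)
```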

View File

@@ -10,34 +10,22 @@
                 "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
                 "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5"
             ],
-            "model_hash": "ae9e4b39e095fd5728fe7f7931eccoqui",
             "default_vocoder": null,
             "commit": "480a6cdf7",
             "license": "CPML",
             "contact": "info@coqui.ai",
             "tos_required": true
         },
-        "xtts_v1": {
-            "description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
-            "hf_url": [
-                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/model.pth",
-                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/config.json",
-                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/vocab.json"
-            ],
-            "default_vocoder": null,
-            "commit": "e5140314",
-            "license": "CPML",
-            "contact": "info@coqui.ai",
-            "tos_required": true
-        },
         "xtts_v1.1": {
             "description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.",
             "hf_url": [
-                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth",
-                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/config.json",
-                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json",
-                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/hash.md5"
+                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/model.pth",
+                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/config.json",
+                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/vocab.json",
+                "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/hash.md5"
             ],
-            "model_hash": "ae9e4b39e095fd5728fe7f7931ec66ad",
+            "model_hash": "7c62beaf58d39b729de287330dc254e7b515677416839b649a50e7cf74c3df59",
             "default_vocoder": null,
             "commit": "82910a63",
             "license": "CPML",

View File

@@ -2,13 +2,10 @@ import os
 import random
 import sys

-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.data
-import torchaudio
-from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
-from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
+from TTS.tts.models.xtts import load_audio

 torch.set_num_threads(1)

@@ -50,31 +47,6 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
     return rel_clip, rel_clip.shape[-1], cond_idxs


-def load_audio(audiopath, sampling_rate):
-    # better load setting following: https://github.com/faroit/python_audio_loading_benchmark
-    if audiopath[-4:] == ".mp3":
-        # it uses torchaudio with sox backend to load mp3
-        audio, lsr = torchaudio_sox_load(audiopath)
-    else:
-        # it uses torchaudio soundfile backend to load all the others data type
-        audio, lsr = torchaudio_soundfile_load(audiopath)
-
-    # stereo to mono if needed
-    if audio.size(0) != 1:
-        audio = torch.mean(audio, dim=0, keepdim=True)
-
-    if lsr != sampling_rate:
-        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
-
-    # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
-    # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
-    if torch.any(audio > 10) or not torch.any(audio < 0):
-        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
-    # clip audio invalid values
-    audio.clip_(-1, 1)
-    return audio
-
-
 class XTTSDataset(torch.utils.data.Dataset):
     def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
         self.config = config
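
With this change the dataset reuses the shared audio loader from `TTS.tts.models.xtts` instead of keeping its own copy. A rough usage sketch of that helper, with a placeholder wav path:

```python
from TTS.tts.models.xtts import load_audio

# returns a mono torch tensor of shape [1, T], resampled to the requested rate
audio = load_audio("speaker_reference.wav", sampling_rate=22050)
print(audio.shape)
```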

View File

@@ -238,7 +238,6 @@ class GPTTrainer(BaseTTS):
                     s_info["speaker_wav"],
                     s_info["language"],
                     gpt_cond_len=3,
-                    decoder="ne_hifigan",
                 )["wav"]
                 test_audios["{}-audio".format(idx)] = wav

View File

@@ -67,6 +67,31 @@ def wav_to_mel_cloning(
     return mel


+def load_audio(audiopath, sampling_rate):
+    # better load setting following: https://github.com/faroit/python_audio_loading_benchmark
+    if audiopath[-4:] == ".mp3":
+        # it uses torchaudio with sox backend to load mp3
+        audio, lsr = torchaudio.backend.sox_io_backend.load(audiopath)
+    else:
+        # it uses torchaudio soundfile backend to load all the others data type
+        audio, lsr = torchaudio.backend.soundfile_backend.load(audiopath)
+
+    # stereo to mono if needed
+    if audio.size(0) != 1:
+        audio = torch.mean(audio, dim=0, keepdim=True)
+
+    if lsr != sampling_rate:
+        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
+
+    # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
+    # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
+    if torch.any(audio > 10) or not torch.any(audio < 0):
+        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
+    # clip audio invalid values
+    audio.clip_(-1, 1)
+    return audio
+
+
 def pad_or_truncate(t, length):
     """
     Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
@@ -86,78 +111,6 @@ def pad_or_truncate(t, length):
     return tp


-def load_discrete_vocoder_diffuser(
-    trained_diffusion_steps=4000,
-    desired_diffusion_steps=200,
-    cond_free=True,
-    cond_free_k=1,
-    sampler="ddim",
-):
-    """
-    Load a GaussianDiffusion instance configured for use as a decoder.
-
-    Args:
-        trained_diffusion_steps (int): The number of diffusion steps used during training.
-        desired_diffusion_steps (int): The number of diffusion steps to use during inference.
-        cond_free (bool): Whether to use a conditioning-free model.
-        cond_free_k (int): The number of samples to use for conditioning-free models.
-        sampler (str): The name of the sampler to use.
-
-    Returns:
-        A SpacedDiffusion instance configured with the given parameters.
-    """
-    return SpacedDiffusion(
-        use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
-        model_mean_type="epsilon",
-        model_var_type="learned_range",
-        loss_type="mse",
-        betas=get_named_beta_schedule("linear", trained_diffusion_steps),
-        conditioning_free=cond_free,
-        conditioning_free_k=cond_free_k,
-        sampler=sampler,
-    )
-
-
-def do_spectrogram_diffusion(
-    diffusion_model,
-    diffuser,
-    latents,
-    conditioning_latents,
-    temperature=1,
-):
-    """
-    Generate a mel-spectrogram using a diffusion model and a diffuser.
-
-    Args:
-        diffusion_model (nn.Module): A diffusion model that converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
-        diffuser (Diffuser): A diffuser that generates a mel-spectrogram from noise.
-        latents (torch.Tensor): A tensor of shape (batch_size, seq_len, code_size) containing the input spectrogram codes.
-        conditioning_latents (torch.Tensor): A tensor of shape (batch_size, code_size) containing the conditioning codes.
-        temperature (float, optional): The temperature of the noise used by the diffuser. Defaults to 1.
-
-    Returns:
-        torch.Tensor: A tensor of shape (batch_size, mel_channels, mel_seq_len) containing the generated mel-spectrogram.
-    """
-    with torch.no_grad():
-        output_seq_len = (
-            latents.shape[1] * 4 * 24000 // 22050
-        )  # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
-        output_shape = (latents.shape[0], 100, output_seq_len)
-        precomputed_embeddings = diffusion_model.timestep_independent(
-            latents, conditioning_latents, output_seq_len, False
-        )
-
-        noise = torch.randn(output_shape, device=latents.device) * temperature
-        mel = diffuser.sample_loop(
-            diffusion_model,
-            output_shape,
-            noise=noise,
-            model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
-            progress=False,
-        )
-        return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
-
-
 @dataclass
 class XttsAudioConfig(Coqpit):
     """
@@ -336,7 +289,7 @@ class Xtts(BaseTTS):
         """Compute the conditioning latents for the GPT model from the given audio.

         Args:
-            audio_path (str): Path to the audio file.
+            audio (tensor): audio tensor.
             sr (int): Sample rate of the audio.
             length (int): Length of the audio in seconds. Defaults to 3.
         """
@@ -404,20 +357,41 @@ class Xtts(BaseTTS):
         max_ref_length=10,
         librosa_trim_db=None,
         sound_norm_refs=False,
+        load_sr=24000,
     ):
+        # deal with multiples references
+        if not isinstance(audio_path, list):
+            audio_paths = [audio_path]
+        else:
+            audio_paths = audio_path
+
+        speaker_embeddings = []
+        audios = []
         speaker_embedding = None
+        for file_path in audio_paths:
+            # load the audio in 24khz to avoid issued with multiple sr references
+            audio = load_audio(file_path, load_sr)
+            audio = audio[:, : load_sr * max_ref_length].to(self.device)
+            if audio.shape[0] > 1:
+                audio = audio.mean(0, keepdim=True)
+            if sound_norm_refs:
+                audio = (audio / torch.abs(audio).max()) * 0.75
+            if librosa_trim_db is not None:
+                audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

-        audio, sr = torchaudio.load(audio_path)
-        audio = audio[:, : sr * max_ref_length].to(self.device)
-        if audio.shape[0] > 1:
-            audio = audio.mean(0, keepdim=True)
-        if sound_norm_refs:
-            audio = (audio / torch.abs(audio).max()) * 0.75
-        if librosa_trim_db is not None:
-            audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
+            speaker_embedding = self.get_speaker_embedding(audio, load_sr)
+            speaker_embeddings.append(speaker_embedding)
+
+            audios.append(audio)
+
+        # use a merge of all references for gpt cond latents
+        full_audio = torch.cat(audios, dim=-1)
+        gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len)  # [1, 1024, T]
+
+        if speaker_embeddings:
+            speaker_embedding = torch.stack(speaker_embeddings)
+            speaker_embedding = speaker_embedding.mean(dim=0)

-        speaker_embedding = self.get_speaker_embedding(audio, sr)
-        gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len)  # [1, 1024, T]
         return gpt_cond_latents, speaker_embedding

     def synthesize(self, text, config, speaker_wav, language, **kwargs):
@@ -426,7 +400,7 @@ class Xtts(BaseTTS):
         Args:
             text (str): Input text.
             config (XttsConfig): Config with inference parameters.
-            speaker_wav (str): Path to the speaker audio file for cloning.
+            speaker_wav (list): List of paths to the speaker audio files to be used for cloning.
             language (str): Language ID of the speaker.
             **kwargs: Inference settings. See `inference()`.

@@ -436,11 +410,6 @@
             as latents used at inference.

         """

-        # Make the synthesizer happy 🥳
-        if isinstance(speaker_wav, list):
-            speaker_wav = speaker_wav[0]
-
         return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)

     def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
@@ -522,27 +491,6 @@ class Xtts(BaseTTS):
             gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
                 else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.

-            decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has
-                more chances to iteratively refine the output, which should theoretically mean a higher quality output.
-                Generally a value above 250 is not noticeably better, however. Defaults to 100.
-
-            cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion
-                performs two forward passes for each diffusion step: one with the outputs of the autoregressive model
-                and one with no conditioning priors. The output of the two is blended according to the cond_free_k
-                value below. Conditioning-free diffusion is the real deal, and dramatically improves realism.
-                Defaults to True.
-
-            cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the
-                conditioning-present signal. [0,inf]. As cond_free_k increases, the output becomes dominated by the
-                conditioning-free signal. Defaults to 2.0.
-
-            diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1].
-                Values at 0 re the "mean" prediction of the diffusion network and will sound bland and smeared.
-                Defaults to 1.0.
-
-            decoder: (str) Selects the decoder to use between ("hifigan", "diffusion")
-                Defaults to hifigan
-
             hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
                 transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
                 here: https://huggingface.co/docs/transformers/internal/generation_utils
@@ -569,12 +517,6 @@ class Xtts(BaseTTS):
             top_k=top_k,
             top_p=top_p,
             do_sample=do_sample,
-            decoder_iterations=decoder_iterations,
-            cond_free=cond_free,
-            cond_free_k=cond_free_k,
-            diffusion_temperature=diffusion_temperature,
-            decoder_sampler=decoder_sampler,
-            decoder=decoder,
             **hf_generate_kwargs,
         )
@@ -592,13 +534,6 @@ class Xtts(BaseTTS):
         top_k=50,
         top_p=0.85,
         do_sample=True,
-        # Decoder inference
-        decoder_iterations=100,
-        cond_free=True,
-        cond_free_k=2,
-        diffusion_temperature=1.0,
-        decoder_sampler="ddim",
-        decoder="hifigan",
         num_beams=1,
         **hf_generate_kwargs,
     ):
@@ -693,8 +628,6 @@ class Xtts(BaseTTS):
         top_k=50,
         top_p=0.85,
         do_sample=True,
-        # Decoder inference
-        decoder="hifigan",
         **hf_generate_kwargs,
     ):
         text = text.strip().lower()
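
The main behavioral change in this file is in `get_conditioning_latents`: each reference wav now contributes its own speaker embedding, the embeddings are averaged, and the raw reference audios are concatenated before computing the GPT conditioning latents. A standalone sketch of that merge step, using dummy tensors rather than real model outputs (the shapes are illustrative assumptions):

```python
import torch

# dummy stand-ins for per-reference speaker embeddings (one per reference wav)
speaker_embeddings = [torch.randn(1, 512, 1) for _ in range(3)]
# dummy stand-ins for the loaded 24 kHz reference audios, shape [1, T_i]
audios = [torch.randn(1, 24000 * n) for n in (2, 3, 4)]

# average the per-reference embeddings into a single speaker embedding
speaker_embedding = torch.stack(speaker_embeddings).mean(dim=0)

# merge all references into one waveform before computing GPT conditioning latents
full_audio = torch.cat(audios, dim=-1)

print(speaker_embedding.shape, full_audio.shape)  # torch.Size([1, 512, 1]) torch.Size([1, 216000])
```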

View File

@@ -39,6 +39,7 @@ You can also mail us at info@coqui.ai.
 ### Inference

 #### 🐸TTS API
+##### Single reference
 ```python
 from TTS.api import TTS
 tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
@@ -46,12 +47,25 @@ tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
 # generate speech by cloning a voice using default settings
 tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                 file_path="output.wav",
-                speaker_wav="/path/to/target/speaker.wav",
+                speaker_wav=["/path/to/target/speaker.wav"],
+                language="en")
+```
+
+##### Multiple references
+```python
+from TTS.api import TTS
+tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
+
+# generate speech by cloning a voice using default settings
+tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+                file_path="output.wav",
+                speaker_wav=["/path/to/target/speaker.wav", "/path/to/target/speaker_2.wav", "/path/to/target/speaker_3.wav"],
                 language="en")
 ```

 #### 🐸TTS Command line
+##### Single reference
 ```console
 tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
      --text "Bugün okula gitmek istemiyorum." \
@@ -60,6 +74,25 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
      --use_cuda true
 ```

+##### Multiple references
+```console
+tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
+     --text "Bugün okula gitmek istemiyorum." \
+     --speaker_wav /path/to/target/speaker.wav /path/to/target/speaker_2.wav /path/to/target/speaker_3.wav \
+     --language_idx tr \
+     --use_cuda true
+```
+or for all wav files in a directory you can use:
+
+```console
+tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
+     --text "Bugün okula gitmek istemiyorum." \
+     --speaker_wav /path/to/target/*.wav \
+     --language_idx tr \
+     --use_cuda true
+```
+
 #### model directly

 If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first.
@@ -83,7 +116,7 @@ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=Tru
 model.cuda()

 print("Computing speaker latents...")
-gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
+gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])

 print("Inference...")
 out = model.inference(
@@ -120,7 +153,7 @@ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=Tru
 model.cuda()

 print("Computing speaker latents...")
-gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
+gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])

 print("Inference...")
 t0 = time.time()
@@ -177,7 +210,7 @@ model.load_checkpoint(config, checkpoint_path=XTTS_CHECKPOINT, vocab_path=TOKENI
 model.cuda()

 print("Computing speaker latents...")
-gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=SPEAKER_REFERENCE)
+gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=[SPEAKER_REFERENCE])

 print("Inference...")
 out = model.inference(

View File

@@ -41,8 +41,8 @@ os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)

 # DVAE files
-DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/dvae.pth"
-MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth"
+DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/dvae.pth"
+MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/mel_stats.pth"

 # Set the path to the downloaded files
 DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, DVAE_CHECKPOINT_LINK.split("/")[-1])
@@ -55,8 +55,8 @@ if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):

 # Download XTTS v1.1 checkpoint if needed
-TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json"
-XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth"
+TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/vocab.json"
+XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/model.pth"

 # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
 TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1])  # vocab.json file
@@ -71,9 +71,9 @@ if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):

 # Training sentences generations
-SPEAKER_REFERENCE = (
-    "./tests/data/ljspeech/wavs/LJ001-0002.wav"  # speaker reference to be used in training test sentences
-)
+SPEAKER_REFERENCE = [
+    "./tests/data/ljspeech/wavs/LJ001-0002.wav"  # speaker reference to be used in training test sentences
+]
 LANGUAGE = config_dataset.language

View File

@@ -40,6 +40,7 @@ CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
 os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)

+# ToDo: update DVAE checkpoint
 # DVAE files
 DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/dvae.pth"
 MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth"
@@ -53,15 +54,14 @@ if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
     print(" > Downloading DVAE files!")
     ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)

-# ToDo: Update links for XTTS v2.0
 # Download XTTS v2.0 checkpoint if needed
-TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v2.0/vocab.json"
-XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v2.0/model.pth"
+TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
+XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"

 # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
-TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1])  # vocab.json file
-XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1])  # model.pth file
+TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK))  # vocab.json file
+XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK))  # model.pth file

 # download XTTS v2.0 files if needed
 if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
@@ -72,9 +72,9 @@ if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):

 # Training sentences generations
-SPEAKER_REFERENCE = (
-    "./tests/data/ljspeech/wavs/LJ001-0002.wav"  # speaker reference to be used in training test sentences
-)
+SPEAKER_REFERENCE = [
+    "./tests/data/ljspeech/wavs/LJ001-0002.wav"  # speaker reference to be used in training test sentences
+]
 LANGUAGE = config_dataset.language
@@ -90,9 +90,9 @@ def main():
         dvae_checkpoint=DVAE_CHECKPOINT,
         xtts_checkpoint=XTTS_CHECKPOINT,  # checkpoint path of the model that you want to fine-tune
         tokenizer_file=TOKENIZER_FILE,
-        gpt_num_audio_tokens=8194,
-        gpt_start_audio_token=8192,
-        gpt_stop_audio_token=8193,
+        gpt_num_audio_tokens=1024,
+        gpt_start_audio_token=1025,
+        gpt_stop_audio_token=1026,
         gpt_use_masking_gt_prompt_approach=True,
         gpt_use_perceiver_resampler=True,
     )

View File

@@ -60,7 +60,7 @@ XTTS_CHECKPOINT = None  # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_s

 # Training sentences generations
-SPEAKER_REFERENCE = "tests/data/ljspeech/wavs/LJ001-0002.wav"  # speaker reference to be used in training test sentences
+SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"]  # speaker reference to be used in training test sentences
 LANGUAGE = config_dataset.language

View File

@@ -58,7 +58,7 @@ XTTS_CHECKPOINT = None  # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_s

 # Training sentences generations
-SPEAKER_REFERENCE = "tests/data/ljspeech/wavs/LJ001-0002.wav"  # speaker reference to be used in training test sentences
+SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"]  # speaker reference to be used in training test sentences
 LANGUAGE = config_dataset.language

@@ -87,7 +87,9 @@ model_args = GPTArgs(
     gpt_use_masking_gt_prompt_approach=True,
     gpt_use_perceiver_resampler=True,
 )
+
 audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
+
 config = GPTTrainerConfig(
     epochs=1,
     output_path=OUT_PATH,

View File

@@ -101,7 +101,9 @@ def test_xtts_streaming():
     from TTS.tts.configs.xtts_config import XttsConfig
     from TTS.tts.models.xtts import Xtts

-    speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
+    speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")]
+    speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav")
+    speaker_wav.append(speaker_wav_2)
     model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
     config = XttsConfig()
     config.load_json(os.path.join(model_path, "config.json"))
@@ -131,20 +133,21 @@ def test_xtts_v2():
     """XTTS is too big to run on github actions. We need to test it locally"""
     output_path = os.path.join(get_tests_output_path(), "output.wav")
     speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
+    speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav")
     use_gpu = torch.cuda.is_available()
     if use_gpu:
         run_cli(
             "yes | "
             f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
             f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
-            f'--speaker_wav "{speaker_wav}" --language_idx "en"'
+            f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
         )
     else:
         run_cli(
             "yes | "
             f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
             f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
-            f'--speaker_wav "{speaker_wav}" --language_idx "en"'
+            f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
         )
@@ -153,7 +156,7 @@ def test_xtts_v2_streaming():
     from TTS.tts.configs.xtts_config import XttsConfig
     from TTS.tts.models.xtts import Xtts

-    speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
+    speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")]
     model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v2")
     config = XttsConfig()
     config.load_json(os.path.join(model_path, "config.json"))