mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #3149 from coqui-ai/fixup_xtts_v2
Bug fixes and add support for multiples speaker references on XTTS inference
This commit is contained in:
commit
5e992d8704
|
@ -10,34 +10,22 @@
|
|||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5"
|
||||
],
|
||||
"model_hash": "ae9e4b39e095fd5728fe7f7931eccoqui",
|
||||
"default_vocoder": null,
|
||||
"commit": "480a6cdf7",
|
||||
"license": "CPML",
|
||||
"contact": "info@coqui.ai",
|
||||
"tos_required": true
|
||||
},
|
||||
"xtts_v1": {
|
||||
"description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
|
||||
"hf_url": [
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/model.pth",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/config.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/vocab.json"
|
||||
],
|
||||
"default_vocoder": null,
|
||||
"commit": "e5140314",
|
||||
"license": "CPML",
|
||||
"contact": "info@coqui.ai",
|
||||
"tos_required": true
|
||||
},
|
||||
"xtts_v1.1": {
|
||||
"description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.",
|
||||
"hf_url": [
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/config.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/hash.md5"
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/model.pth",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/config.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/vocab.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/hash.md5"
|
||||
],
|
||||
"model_hash": "ae9e4b39e095fd5728fe7f7931ec66ad",
|
||||
"model_hash": "7c62beaf58d39b729de287330dc254e7b515677416839b649a50e7cf74c3df59",
|
||||
"default_vocoder": null,
|
||||
"commit": "82910a63",
|
||||
"license": "CPML",
|
||||
|
|
|
@ -2,13 +2,10 @@ import os
|
|||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
import torchaudio
|
||||
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
|
||||
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
|
||||
from TTS.tts.models.xtts import load_audio
|
||||
|
||||
torch.set_num_threads(1)
|
||||
|
||||
|
@ -50,31 +47,6 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
|
|||
return rel_clip, rel_clip.shape[-1], cond_idxs
|
||||
|
||||
|
||||
def load_audio(audiopath, sampling_rate):
|
||||
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
|
||||
if audiopath[-4:] == ".mp3":
|
||||
# it uses torchaudio with sox backend to load mp3
|
||||
audio, lsr = torchaudio_sox_load(audiopath)
|
||||
else:
|
||||
# it uses torchaudio soundfile backend to load all the others data type
|
||||
audio, lsr = torchaudio_soundfile_load(audiopath)
|
||||
|
||||
# stereo to mono if needed
|
||||
if audio.size(0) != 1:
|
||||
audio = torch.mean(audio, dim=0, keepdim=True)
|
||||
|
||||
if lsr != sampling_rate:
|
||||
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
|
||||
|
||||
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
|
||||
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
|
||||
if torch.any(audio > 10) or not torch.any(audio < 0):
|
||||
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
|
||||
# clip audio invalid values
|
||||
audio.clip_(-1, 1)
|
||||
return audio
|
||||
|
||||
|
||||
class XTTSDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
|
||||
self.config = config
|
||||
|
|
|
@ -238,7 +238,6 @@ class GPTTrainer(BaseTTS):
|
|||
s_info["speaker_wav"],
|
||||
s_info["language"],
|
||||
gpt_cond_len=3,
|
||||
decoder="ne_hifigan",
|
||||
)["wav"]
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
|
||||
|
|
|
@ -67,6 +67,31 @@ def wav_to_mel_cloning(
|
|||
return mel
|
||||
|
||||
|
||||
def load_audio(audiopath, sampling_rate):
|
||||
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
|
||||
if audiopath[-4:] == ".mp3":
|
||||
# it uses torchaudio with sox backend to load mp3
|
||||
audio, lsr = torchaudio.backend.sox_io_backend.load(audiopath)
|
||||
else:
|
||||
# it uses torchaudio soundfile backend to load all the others data type
|
||||
audio, lsr = torchaudio.backend.soundfile_backend.load(audiopath)
|
||||
|
||||
# stereo to mono if needed
|
||||
if audio.size(0) != 1:
|
||||
audio = torch.mean(audio, dim=0, keepdim=True)
|
||||
|
||||
if lsr != sampling_rate:
|
||||
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
|
||||
|
||||
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
|
||||
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
|
||||
if torch.any(audio > 10) or not torch.any(audio < 0):
|
||||
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
|
||||
# clip audio invalid values
|
||||
audio.clip_(-1, 1)
|
||||
return audio
|
||||
|
||||
|
||||
def pad_or_truncate(t, length):
|
||||
"""
|
||||
Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
|
||||
|
@ -86,78 +111,6 @@ def pad_or_truncate(t, length):
|
|||
return tp
|
||||
|
||||
|
||||
def load_discrete_vocoder_diffuser(
|
||||
trained_diffusion_steps=4000,
|
||||
desired_diffusion_steps=200,
|
||||
cond_free=True,
|
||||
cond_free_k=1,
|
||||
sampler="ddim",
|
||||
):
|
||||
"""
|
||||
Load a GaussianDiffusion instance configured for use as a decoder.
|
||||
|
||||
Args:
|
||||
trained_diffusion_steps (int): The number of diffusion steps used during training.
|
||||
desired_diffusion_steps (int): The number of diffusion steps to use during inference.
|
||||
cond_free (bool): Whether to use a conditioning-free model.
|
||||
cond_free_k (int): The number of samples to use for conditioning-free models.
|
||||
sampler (str): The name of the sampler to use.
|
||||
|
||||
Returns:
|
||||
A SpacedDiffusion instance configured with the given parameters.
|
||||
"""
|
||||
return SpacedDiffusion(
|
||||
use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
|
||||
model_mean_type="epsilon",
|
||||
model_var_type="learned_range",
|
||||
loss_type="mse",
|
||||
betas=get_named_beta_schedule("linear", trained_diffusion_steps),
|
||||
conditioning_free=cond_free,
|
||||
conditioning_free_k=cond_free_k,
|
||||
sampler=sampler,
|
||||
)
|
||||
|
||||
|
||||
def do_spectrogram_diffusion(
|
||||
diffusion_model,
|
||||
diffuser,
|
||||
latents,
|
||||
conditioning_latents,
|
||||
temperature=1,
|
||||
):
|
||||
"""
|
||||
Generate a mel-spectrogram using a diffusion model and a diffuser.
|
||||
|
||||
Args:
|
||||
diffusion_model (nn.Module): A diffusion model that converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
|
||||
diffuser (Diffuser): A diffuser that generates a mel-spectrogram from noise.
|
||||
latents (torch.Tensor): A tensor of shape (batch_size, seq_len, code_size) containing the input spectrogram codes.
|
||||
conditioning_latents (torch.Tensor): A tensor of shape (batch_size, code_size) containing the conditioning codes.
|
||||
temperature (float, optional): The temperature of the noise used by the diffuser. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A tensor of shape (batch_size, mel_channels, mel_seq_len) containing the generated mel-spectrogram.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
output_seq_len = (
|
||||
latents.shape[1] * 4 * 24000 // 22050
|
||||
) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
|
||||
output_shape = (latents.shape[0], 100, output_seq_len)
|
||||
precomputed_embeddings = diffusion_model.timestep_independent(
|
||||
latents, conditioning_latents, output_seq_len, False
|
||||
)
|
||||
|
||||
noise = torch.randn(output_shape, device=latents.device) * temperature
|
||||
mel = diffuser.sample_loop(
|
||||
diffusion_model,
|
||||
output_shape,
|
||||
noise=noise,
|
||||
model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
|
||||
progress=False,
|
||||
)
|
||||
return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
|
||||
|
||||
|
||||
@dataclass
|
||||
class XttsAudioConfig(Coqpit):
|
||||
"""
|
||||
|
@ -336,7 +289,7 @@ class Xtts(BaseTTS):
|
|||
"""Compute the conditioning latents for the GPT model from the given audio.
|
||||
|
||||
Args:
|
||||
audio_path (str): Path to the audio file.
|
||||
audio (tensor): audio tensor.
|
||||
sr (int): Sample rate of the audio.
|
||||
length (int): Length of the audio in seconds. Defaults to 3.
|
||||
"""
|
||||
|
@ -404,11 +357,21 @@ class Xtts(BaseTTS):
|
|||
max_ref_length=10,
|
||||
librosa_trim_db=None,
|
||||
sound_norm_refs=False,
|
||||
load_sr=24000,
|
||||
):
|
||||
speaker_embedding = None
|
||||
# deal with multiples references
|
||||
if not isinstance(audio_path, list):
|
||||
audio_paths = [audio_path]
|
||||
else:
|
||||
audio_paths = audio_path
|
||||
|
||||
audio, sr = torchaudio.load(audio_path)
|
||||
audio = audio[:, : sr * max_ref_length].to(self.device)
|
||||
speaker_embeddings = []
|
||||
audios = []
|
||||
speaker_embedding = None
|
||||
for file_path in audio_paths:
|
||||
# load the audio in 24khz to avoid issued with multiple sr references
|
||||
audio = load_audio(file_path, load_sr)
|
||||
audio = audio[:, : load_sr * max_ref_length].to(self.device)
|
||||
if audio.shape[0] > 1:
|
||||
audio = audio.mean(0, keepdim=True)
|
||||
if sound_norm_refs:
|
||||
|
@ -416,8 +379,19 @@ class Xtts(BaseTTS):
|
|||
if librosa_trim_db is not None:
|
||||
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
|
||||
|
||||
speaker_embedding = self.get_speaker_embedding(audio, sr)
|
||||
gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len) # [1, 1024, T]
|
||||
speaker_embedding = self.get_speaker_embedding(audio, load_sr)
|
||||
speaker_embeddings.append(speaker_embedding)
|
||||
|
||||
audios.append(audio)
|
||||
|
||||
# use a merge of all references for gpt cond latents
|
||||
full_audio = torch.cat(audios, dim=-1)
|
||||
gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len) # [1, 1024, T]
|
||||
|
||||
if speaker_embeddings:
|
||||
speaker_embedding = torch.stack(speaker_embeddings)
|
||||
speaker_embedding = speaker_embedding.mean(dim=0)
|
||||
|
||||
return gpt_cond_latents, speaker_embedding
|
||||
|
||||
def synthesize(self, text, config, speaker_wav, language, **kwargs):
|
||||
|
@ -426,7 +400,7 @@ class Xtts(BaseTTS):
|
|||
Args:
|
||||
text (str): Input text.
|
||||
config (XttsConfig): Config with inference parameters.
|
||||
speaker_wav (str): Path to the speaker audio file for cloning.
|
||||
speaker_wav (list): List of paths to the speaker audio files to be used for cloning.
|
||||
language (str): Language ID of the speaker.
|
||||
**kwargs: Inference settings. See `inference()`.
|
||||
|
||||
|
@ -436,11 +410,6 @@ class Xtts(BaseTTS):
|
|||
as latents used at inference.
|
||||
|
||||
"""
|
||||
|
||||
# Make the synthesizer happy 🥳
|
||||
if isinstance(speaker_wav, list):
|
||||
speaker_wav = speaker_wav[0]
|
||||
|
||||
return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)
|
||||
|
||||
def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
|
||||
|
@ -522,27 +491,6 @@ class Xtts(BaseTTS):
|
|||
gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
|
||||
else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
|
||||
|
||||
decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has
|
||||
more chances to iteratively refine the output, which should theoretically mean a higher quality output.
|
||||
Generally a value above 250 is not noticeably better, however. Defaults to 100.
|
||||
|
||||
cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion
|
||||
performs two forward passes for each diffusion step: one with the outputs of the autoregressive model
|
||||
and one with no conditioning priors. The output of the two is blended according to the cond_free_k
|
||||
value below. Conditioning-free diffusion is the real deal, and dramatically improves realism.
|
||||
Defaults to True.
|
||||
|
||||
cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the
|
||||
conditioning-present signal. [0,inf]. As cond_free_k increases, the output becomes dominated by the
|
||||
conditioning-free signal. Defaults to 2.0.
|
||||
|
||||
diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1].
|
||||
Values at 0 re the "mean" prediction of the diffusion network and will sound bland and smeared.
|
||||
Defaults to 1.0.
|
||||
|
||||
decoder: (str) Selects the decoder to use between ("hifigan", "diffusion")
|
||||
Defaults to hifigan
|
||||
|
||||
hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
|
||||
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
|
||||
here: https://huggingface.co/docs/transformers/internal/generation_utils
|
||||
|
@ -569,12 +517,6 @@ class Xtts(BaseTTS):
|
|||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
do_sample=do_sample,
|
||||
decoder_iterations=decoder_iterations,
|
||||
cond_free=cond_free,
|
||||
cond_free_k=cond_free_k,
|
||||
diffusion_temperature=diffusion_temperature,
|
||||
decoder_sampler=decoder_sampler,
|
||||
decoder=decoder,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
|
@ -592,13 +534,6 @@ class Xtts(BaseTTS):
|
|||
top_k=50,
|
||||
top_p=0.85,
|
||||
do_sample=True,
|
||||
# Decoder inference
|
||||
decoder_iterations=100,
|
||||
cond_free=True,
|
||||
cond_free_k=2,
|
||||
diffusion_temperature=1.0,
|
||||
decoder_sampler="ddim",
|
||||
decoder="hifigan",
|
||||
num_beams=1,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
|
@ -693,8 +628,6 @@ class Xtts(BaseTTS):
|
|||
top_k=50,
|
||||
top_p=0.85,
|
||||
do_sample=True,
|
||||
# Decoder inference
|
||||
decoder="hifigan",
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
text = text.strip().lower()
|
||||
|
|
|
@ -39,6 +39,7 @@ You can also mail us at info@coqui.ai.
|
|||
### Inference
|
||||
#### 🐸TTS API
|
||||
|
||||
##### Single reference
|
||||
```python
|
||||
from TTS.api import TTS
|
||||
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
|
||||
|
@ -46,12 +47,25 @@ tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
|
|||
# generate speech by cloning a voice using default settings
|
||||
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
file_path="output.wav",
|
||||
speaker_wav="/path/to/target/speaker.wav",
|
||||
speaker_wav=["/path/to/target/speaker.wav"],
|
||||
language="en")
|
||||
```
|
||||
|
||||
##### Multiple references
|
||||
```python
|
||||
from TTS.api import TTS
|
||||
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
|
||||
|
||||
# generate speech by cloning a voice using default settings
|
||||
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
file_path="output.wav",
|
||||
speaker_wav=["/path/to/target/speaker.wav", "/path/to/target/speaker_2.wav", "/path/to/target/speaker_3.wav"],
|
||||
language="en")
|
||||
```
|
||||
|
||||
#### 🐸TTS Command line
|
||||
|
||||
##### Single reference
|
||||
```console
|
||||
tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
|
||||
--text "Bugün okula gitmek istemiyorum." \
|
||||
|
@ -60,6 +74,25 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
|
|||
--use_cuda true
|
||||
```
|
||||
|
||||
##### Multiple references
|
||||
```console
|
||||
tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
|
||||
--text "Bugün okula gitmek istemiyorum." \
|
||||
--speaker_wav /path/to/target/speaker.wav /path/to/target/speaker_2.wav /path/to/target/speaker_3.wav \
|
||||
--language_idx tr \
|
||||
--use_cuda true
|
||||
```
|
||||
or for all wav files in a directory you can use:
|
||||
|
||||
```console
|
||||
tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
|
||||
--text "Bugün okula gitmek istemiyorum." \
|
||||
--speaker_wav /path/to/target/*.wav \
|
||||
--language_idx tr \
|
||||
--use_cuda true
|
||||
```
|
||||
|
||||
|
||||
#### model directly
|
||||
|
||||
If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first.
|
||||
|
@ -83,7 +116,7 @@ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=Tru
|
|||
model.cuda()
|
||||
|
||||
print("Computing speaker latents...")
|
||||
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
|
||||
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
|
||||
|
||||
print("Inference...")
|
||||
out = model.inference(
|
||||
|
@ -120,7 +153,7 @@ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=Tru
|
|||
model.cuda()
|
||||
|
||||
print("Computing speaker latents...")
|
||||
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
|
||||
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
|
||||
|
||||
print("Inference...")
|
||||
t0 = time.time()
|
||||
|
@ -177,7 +210,7 @@ model.load_checkpoint(config, checkpoint_path=XTTS_CHECKPOINT, vocab_path=TOKENI
|
|||
model.cuda()
|
||||
|
||||
print("Computing speaker latents...")
|
||||
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=SPEAKER_REFERENCE)
|
||||
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=[SPEAKER_REFERENCE])
|
||||
|
||||
print("Inference...")
|
||||
out = model.inference(
|
||||
|
|
|
@ -41,8 +41,8 @@ os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
|
|||
|
||||
|
||||
# DVAE files
|
||||
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/dvae.pth"
|
||||
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth"
|
||||
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/dvae.pth"
|
||||
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/mel_stats.pth"
|
||||
|
||||
# Set the path to the downloaded files
|
||||
DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, DVAE_CHECKPOINT_LINK.split("/")[-1])
|
||||
|
@ -55,8 +55,8 @@ if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
|
|||
|
||||
|
||||
# Download XTTS v1.1 checkpoint if needed
|
||||
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json"
|
||||
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth"
|
||||
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/vocab.json"
|
||||
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/model.pth"
|
||||
|
||||
# XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
|
||||
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file
|
||||
|
@ -71,9 +71,9 @@ if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
|
|||
|
||||
|
||||
# Training sentences generations
|
||||
SPEAKER_REFERENCE = (
|
||||
SPEAKER_REFERENCE = [
|
||||
"./tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
|
||||
)
|
||||
]
|
||||
LANGUAGE = config_dataset.language
|
||||
|
||||
|
||||
|
|
|
@ -40,6 +40,7 @@ CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
|
|||
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
|
||||
|
||||
|
||||
# ToDo: update DVAE checkpoint
|
||||
# DVAE files
|
||||
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/dvae.pth"
|
||||
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth"
|
||||
|
@ -53,15 +54,14 @@ if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
|
|||
print(" > Downloading DVAE files!")
|
||||
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
|
||||
|
||||
# ToDo: Update links for XTTS v2.0
|
||||
|
||||
# Download XTTS v2.0 checkpoint if needed
|
||||
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v2.0/vocab.json"
|
||||
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v2.0/model.pth"
|
||||
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
|
||||
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"
|
||||
|
||||
# XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
|
||||
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file
|
||||
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1]) # model.pth file
|
||||
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file
|
||||
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file
|
||||
|
||||
# download XTTS v2.0 files if needed
|
||||
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
|
||||
|
@ -72,9 +72,9 @@ if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
|
|||
|
||||
|
||||
# Training sentences generations
|
||||
SPEAKER_REFERENCE = (
|
||||
SPEAKER_REFERENCE = [
|
||||
"./tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
|
||||
)
|
||||
]
|
||||
LANGUAGE = config_dataset.language
|
||||
|
||||
|
||||
|
@ -90,9 +90,9 @@ def main():
|
|||
dvae_checkpoint=DVAE_CHECKPOINT,
|
||||
xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
|
||||
tokenizer_file=TOKENIZER_FILE,
|
||||
gpt_num_audio_tokens=8194,
|
||||
gpt_start_audio_token=8192,
|
||||
gpt_stop_audio_token=8193,
|
||||
gpt_num_audio_tokens=1024,
|
||||
gpt_start_audio_token=1025,
|
||||
gpt_stop_audio_token=1026,
|
||||
gpt_use_masking_gt_prompt_approach=True,
|
||||
gpt_use_perceiver_resampler=True,
|
||||
)
|
||||
|
|
|
@ -60,7 +60,7 @@ XTTS_CHECKPOINT = None # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_s
|
|||
|
||||
|
||||
# Training sentences generations
|
||||
SPEAKER_REFERENCE = "tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
|
||||
SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences
|
||||
LANGUAGE = config_dataset.language
|
||||
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ XTTS_CHECKPOINT = None # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_s
|
|||
|
||||
|
||||
# Training sentences generations
|
||||
SPEAKER_REFERENCE = "tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
|
||||
SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences
|
||||
LANGUAGE = config_dataset.language
|
||||
|
||||
|
||||
|
@ -87,7 +87,9 @@ model_args = GPTArgs(
|
|||
gpt_use_masking_gt_prompt_approach=True,
|
||||
gpt_use_perceiver_resampler=True,
|
||||
)
|
||||
|
||||
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
|
||||
|
||||
config = GPTTrainerConfig(
|
||||
epochs=1,
|
||||
output_path=OUT_PATH,
|
||||
|
|
|
@ -101,7 +101,9 @@ def test_xtts_streaming():
|
|||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.models.xtts import Xtts
|
||||
|
||||
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
|
||||
speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")]
|
||||
speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav")
|
||||
speaker_wav.append(speaker_wav_2)
|
||||
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
|
||||
config = XttsConfig()
|
||||
config.load_json(os.path.join(model_path, "config.json"))
|
||||
|
@ -131,20 +133,21 @@ def test_xtts_v2():
|
|||
"""XTTS is too big to run on github actions. We need to test it locally"""
|
||||
output_path = os.path.join(get_tests_output_path(), "output.wav")
|
||||
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
|
||||
speaker_wav_2 = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0002.wav")
|
||||
use_gpu = torch.cuda.is_available()
|
||||
if use_gpu:
|
||||
run_cli(
|
||||
"yes | "
|
||||
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
|
||||
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
|
||||
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
|
||||
f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" "--language_idx "en"'
|
||||
)
|
||||
else:
|
||||
run_cli(
|
||||
"yes | "
|
||||
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
|
||||
f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
|
||||
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
|
||||
f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
|
||||
)
|
||||
|
||||
|
||||
|
@ -153,7 +156,7 @@ def test_xtts_v2_streaming():
|
|||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.models.xtts import Xtts
|
||||
|
||||
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
|
||||
speaker_wav = [os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")]
|
||||
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v2")
|
||||
config = XttsConfig()
|
||||
config.load_json(os.path.join(model_path, "config.json"))
|
||||
|
|
Loading…
Reference in New Issue