mirror of https://github.com/coqui-ai/TTS.git
Update demo

parent 626d9e16fb
commit fa9bb26ebb

@@ -8,8 +8,6 @@ from tqdm import tqdm

 import torch
 import torchaudio
-from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
-from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
 # torch.set_num_threads(1)

 from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners

@@ -45,7 +43,7 @@ def list_files(basePath, validExts=None, contains=None):
             audioPath = os.path.join(rootDir, filename)
             yield audioPath

-def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.5, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
+def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
     # make sure that ooutput file exists
     os.makedirs(out_path, exist_ok=True)

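The only functional change above is the default buffer dropping from 0.5 to 0.2 (presumably the amount of audio, in seconds, kept around each extracted sentence). For context, a hypothetical call, with placeholder paths and assuming format_audio_list is imported from this module:

# Hypothetical usage sketch; the paths are placeholders, not part of the commit.
audio_files = ["/data/speaker/clip_01.wav", "/data/speaker/clip_02.wav"]
train_meta, eval_meta = format_audio_list(
    audio_files,
    target_language="en",
    out_path="/tmp/xtts_ft/dataset",  # metadata CSVs are written under out_path
)
print(train_meta, eval_meta)  # paths to the train/eval metadata files
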
@@ -121,10 +119,10 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                 audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
                 # if the audio is too short ignore it (i.e < 0.33 seconds)
                 if audio.size(-1) >= sr/3:
-                    torchaudio.backend.sox_io_backend.save(
-                        absoulte_path,
+                    torchaudio.save(absoulte_path,
                         audio,
-                        sr
+                        sr,
+                        backend="sox",
                     )
                 else:
                     continue

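This swaps the deprecated per-backend call (torchaudio.backend.sox_io_backend.save) for the dispatcher-style torchaudio.save(..., backend="sox"). A minimal sketch of the new call, assuming torchaudio >= 2.1 with the sox backend installed (synthetic tensor and placeholder path, not the demo's data):

import torch
import torchaudio

sr = 22050
audio = torch.zeros(1, sr)  # one second of silence, shape (channels, samples)
# backend="sox" routes the write through libsox, matching the diff above
torchaudio.save("/tmp/out.wav", audio, sr, backend="sox")
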
@@ -159,5 +159,9 @@ def train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path
     )
     trainer.fit()

+    # get the longest text audio file to use as speaker reference
+    samples_len = [len(item["text"].split(" ")) for item in train_samples]
+    longest_text_idx = samples_len.index(max(samples_len))
+    speaker_ref = train_samples[longest_text_idx]["audio_file"]

-    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer.output_path, train_samples[0]["audio_file"]
+    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer.output_path, speaker_ref

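The speaker reference is now the sample with the longest transcript rather than whatever train_samples[0] happened to be, which gives the conditioning step more material to work with. An equivalent selection (a sketch, not the committed code) using max() with a key function:

# Stand-in data; train_samples in the demo carries the same two keys.
train_samples = [
    {"text": "short clip", "audio_file": "a.wav"},
    {"text": "a much longer transcription with many more words", "audio_file": "b.wav"},
]
longest = max(train_samples, key=lambda s: len(s["text"].split()))
speaker_ref = longest["audio_file"]
print(speaker_ref)  # b.wav
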
@@ -18,31 +18,32 @@ from TTS.tts.models.xtts import Xtts

 PORT = 5003

+XTTS_MODEL = None
 def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
+    global XTTS_MODEL
     config = XttsConfig()
     config.load_json(xtts_config)
-    model = Xtts.init_from_config(config)
+    XTTS_MODEL = Xtts.init_from_config(config)
     print("Loading XTTS model! ")
-    model.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
+    XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
     if torch.cuda.is_available():
-        model.cuda()
-    return model
+        XTTS_MODEL.cuda()
+    print("Model Loaded!")
+    return "Model Loaded!"

-def run_tts(lang, tts_text, xtts_checkpoint, xtts_config, xtts_vocab, speaker_audio_file):
-    # ToDo: add the load in other function to fast inference
-    model = load_model(xtts_checkpoint, xtts_config, xtts_vocab)
-    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=model.config.gpt_cond_len, max_ref_length=model.config.max_ref_len, sound_norm_refs=model.config.sound_norm_refs)
-    speaker_embedding
-    out = model.inference(
+
+def run_tts(lang, tts_text, speaker_audio_file):
+    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
+    out = XTTS_MODEL.inference(
         text=tts_text,
         language=lang,
         gpt_cond_latent=gpt_cond_latent,
         speaker_embedding=speaker_embedding,
-        temperature=model.config.temperature, # Add custom parameters here
-        length_penalty=model.config.length_penalty,
-        repetition_penalty=model.config.repetition_penalty,
-        top_k=model.config.top_k,
-        top_p=model.config.top_p,
+        temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
+        length_penalty=XTTS_MODEL.config.length_penalty,
+        repetition_penalty=XTTS_MODEL.config.repetition_penalty,
+        top_k=XTTS_MODEL.config.top_k,
+        top_p=XTTS_MODEL.config.top_p,
     )

     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:

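The refactor separates loading from synthesis: load_model now caches the model in the module-level XTTS_MODEL global, and run_tts reuses it instead of reloading the checkpoint on every request. The general shape of that pattern, sketched with a stand-in loader rather than the real XTTS API:

# Caching pattern sketch; the lambda stands in for the expensive checkpoint load.
MODEL = None

def load_model(path: str) -> str:
    global MODEL
    MODEL = lambda text: f"audio for {text!r} from {path}"  # stand-in loader
    return "Model Loaded!"

def run_tts(text: str) -> str:
    # Callbacks reuse the cached model instead of reloading it per request.
    assert MODEL is not None, "call load_model() first"
    return MODEL(text)

load_model("/tmp/model.pth")
print(run_tts("hello"))
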
@@ -95,12 +96,19 @@ def read_logs():


 with gr.Blocks() as demo:
-    # state_vars = gr.State()
     with gr.Tab("Data processing"):
-        upload_file = gr.Audio(
-            sources="upload",
-            label="Select here the audio files that you want to use for XTTS trainining !",
-            type="filepath",
+        out_path = gr.Textbox(
+            label="Output path (where data and checkpoints will be saved):",
+            value="/tmp/xtts_ft/"
+        )
+        # upload_file = gr.Audio(
+        #     sources="upload",
+        #     label="Select here the audio files that you want to use for XTTS trainining !",
+        #     type="filepath",
+        # )
+        upload_file = gr.File(
+            file_count="multiple",
+            label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)",
         )
         lang = gr.Dropdown(
             label="Dataset Language",

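gr.Audio(type="filepath") accepts a single clip, while gr.File(file_count="multiple") hands the callback a list of uploads, which is what lets the data-processing step batch several files. A minimal sketch of that behaviour (assumes a recent Gradio; the exact payload type varies by version):

import gradio as gr

def list_uploads(files):
    # With file_count="multiple", `files` arrives as a list of file paths.
    return "\n".join(str(f) for f in files)

with gr.Blocks() as demo:
    uploads = gr.File(file_count="multiple", label="Audio files")
    received = gr.Textbox(label="Received")
    uploads.upload(list_uploads, inputs=uploads, outputs=received)

# demo.launch()  # uncomment to serve locally
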
@@ -135,18 +143,18 @@ with gr.Blocks() as demo:

         prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")

-        def preprocess_dataset(audio_path, language, progress=gr.Progress(track_tqdm=True)):
-            # create a temp directory to save the dataset
-            out_path = tempfile.TemporaryDirectory().name
+        def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
+            out_path = os.path.join(out_path, "dataset")
+            os.makedirs(out_path, exist_ok=True)
             if audio_path is None:
                 # ToDo: raise an error
                 pass
             else:
-                train_meta, eval_meta = format_audio_list([audio_path], target_language=language, out_path=out_path, gradio_progress=progress)
+                train_meta, eval_meta = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+            print("Dataset Processed!")
             return "Dataset Processed!", train_meta, eval_meta

-    with gr.Tab("Fine-tuning XTTS"):
+    with gr.Tab("Fine-tuning XTTS Encoder"):
         train_csv = gr.Textbox(
             label="Train CSV:",
         )

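Dataset output moves from a throwaway tempfile.TemporaryDirectory() to a user-chosen directory, so the generated CSVs and wavs survive into the training and inference steps. The path handling reduces to the following (the default value matches the textbox added above):

import os

out_path = os.path.join("/tmp/xtts_ft/", "dataset")
os.makedirs(out_path, exist_ok=True)  # idempotent, so Step 1 can be rerun safely
print(out_path)  # /tmp/xtts_ft/dataset
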
@@ -158,7 +166,7 @@ with gr.Blocks() as demo:
             minimum=1,
             maximum=100,
             step=1,
-            value=2,# 15
+            value=10,
         )
         batch_size = gr.Slider(
             label="batch_size",

@@ -177,7 +185,7 @@ with gr.Blocks() as demo:
         demo.load(read_logs, None, logs_tts_train, every=1)
         train_btn = gr.Button(value="Step 2 - Run the training")

-        def train_model(language, train_csv, eval_csv, num_epochs, batch_size, output_path="./", progress=gr.Progress(track_tqdm=True)):
+        def train_model(language, train_csv, eval_csv, num_epochs, batch_size, output_path, progress=gr.Progress(track_tqdm=True)):
             # train_csv = '/tmp/tmprh4k_vou/metadata_train.csv'
             # eval_csv = '/tmp/tmprh4k_vou/metadata_eval.csv'

@@ -187,15 +195,13 @@ with gr.Blocks() as demo:
             os.system(f"cp {vocab_file} {exp_path}")

             ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
-            # state_vars["config_path"] = config_path
-            # state_vars["original_xtts_checkpoint"] = original_xtts_checkpoint
-            # state_vars["vocab_file"] = vocab_file
-            # state_vars["ft_xtts_checkpoint"] = ft_xtts_checkpoint
-            # state_vars["speaker_audio_file"] = speaker_wav
+            print("Model training done!")
             return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav


     with gr.Tab("Inference"):
+        with gr.Row():
+            with gr.Column() as col1:
                 xtts_checkpoint = gr.Textbox(
                     label="XTTS checkpoint path:",
                     value="",

@@ -208,6 +214,12 @@ with gr.Blocks() as demo:
                     label="XTTS config path:",
                     value="",
                 )
+                progress_load = gr.Label(
+                    label="Progress:"
+                )
+                load_btn = gr.Button(value="Step 3 - Load Fine tuned XTTS model")

+            with gr.Column() as col2:
                 speaker_reference_audio = gr.Textbox(
                     label="Speaker reference audio:",
                     value="",

@@ -238,8 +250,9 @@ with gr.Blocks() as demo:
                     label="Input Text.",
                     value="This model sounds really good and above all, it's reasonably fast.",
                 )
-                tts_btn = gr.Button(value="Step 3 - Inference XTTS model")
+                tts_btn = gr.Button(value="Step 4 - Inference")

+            with gr.Column() as col3:
                 tts_output_audio = gr.Audio(label="Generated Audio.")
                 reference_audio = gr.Audio(label="Reference audio used.")

@@ -248,6 +261,7 @@ with gr.Blocks() as demo:
         inputs=[
             upload_file,
             lang,
+            out_path,
         ],
         outputs=[
             progress_data,

@@ -257,7 +271,6 @@ with gr.Blocks() as demo:
     )


-
     train_btn.click(
         fn=train_model,
         inputs=[

@@ -266,19 +279,26 @@ with gr.Blocks() as demo:
             eval_csv,
             num_epochs,
             batch_size,
+            out_path,
         ],
         outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
     )

+    load_btn.click(
+        fn=load_model,
+        inputs=[
+            xtts_checkpoint,
+            xtts_config,
+            xtts_vocab
+        ],
+        outputs=[progress_load],
+    )

     tts_btn.click(
         fn=run_tts,
         inputs=[
             tts_language,
             tts_text,
-            xtts_checkpoint,
-            xtts_config,
-            xtts_vocab,
             speaker_reference_audio,
         ],
         outputs=[tts_output_audio, reference_audio],
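
With loading split into its own event, the inference click no longer carries checkpoint inputs; load_btn.click handles them once and run_tts only needs the language, text, and reference audio. A generic sketch of the resulting two-event wiring (stand-in callbacks, not the demo's functions):

import gradio as gr

STATE = {"model": None}

def load_model(path):
    STATE["model"] = path.upper()  # stand-in for the real checkpoint load
    return "Model Loaded!"

def run_tts(text):
    return f"{STATE['model']} says: {text}"

with gr.Blocks() as demo:
    ckpt = gr.Textbox(label="Checkpoint path")
    status = gr.Label(label="Progress:")
    load_btn = gr.Button("Load")
    text = gr.Textbox(label="Text")
    out = gr.Textbox(label="Output")
    tts_btn = gr.Button("Speak")

    # Load once, then synthesize as many times as needed against the cached model.
    load_btn.click(fn=load_model, inputs=[ckpt], outputs=[status])
    tts_btn.click(fn=run_tts, inputs=[text], outputs=[out])

# demo.launch()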