mirror of https://github.com/coqui-ai/TTS.git
Fix demo freezing issue
This commit is contained in:
parent
7cc348ed76
commit
626d9e16fb
|
@ -1 +1,2 @@
|
||||||
faster_whisper
|
faster_whisper==0.9.0
|
||||||
|
gradio==4.7.1
|
|
@ -28,7 +28,7 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
|
||||||
model.cuda()
|
model.cuda()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def run_tts(lang, tts_text, xtts_checkpoint, xtts_config, xtts_vocab, speaker_audio_file, state_vars):
|
def run_tts(lang, tts_text, xtts_checkpoint, xtts_config, xtts_vocab, speaker_audio_file):
|
||||||
# ToDo: add the load in other function to fast inference
|
# ToDo: add the load in other function to fast inference
|
||||||
model = load_model(xtts_checkpoint, xtts_config, xtts_vocab)
|
model = load_model(xtts_checkpoint, xtts_config, xtts_vocab)
|
||||||
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=model.config.gpt_cond_len, max_ref_length=model.config.max_ref_len, sound_norm_refs=model.config.sound_norm_refs)
|
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=model.config.gpt_cond_len, max_ref_length=model.config.max_ref_len, sound_norm_refs=model.config.sound_norm_refs)
|
||||||
|
@ -95,7 +95,7 @@ def read_logs():
|
||||||
|
|
||||||
|
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
state_vars = gr.State()
|
# state_vars = gr.State()
|
||||||
with gr.Tab("Data processing"):
|
with gr.Tab("Data processing"):
|
||||||
upload_file = gr.Audio(
|
upload_file = gr.Audio(
|
||||||
sources="upload",
|
sources="upload",
|
||||||
|
@ -135,7 +135,7 @@ with gr.Blocks() as demo:
|
||||||
|
|
||||||
prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")
|
prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")
|
||||||
|
|
||||||
def preprocess_dataset(audio_path, language, state_vars, progress=gr.Progress(track_tqdm=True)):
|
def preprocess_dataset(audio_path, language, progress=gr.Progress(track_tqdm=True)):
|
||||||
# create a temp directory to save the dataset
|
# create a temp directory to save the dataset
|
||||||
out_path = tempfile.TemporaryDirectory().name
|
out_path = tempfile.TemporaryDirectory().name
|
||||||
if audio_path is None:
|
if audio_path is None:
|
||||||
|
@ -144,27 +144,15 @@ with gr.Blocks() as demo:
|
||||||
else:
|
else:
|
||||||
train_meta, eval_meta = format_audio_list([audio_path], target_language=language, out_path=out_path, gradio_progress=progress)
|
train_meta, eval_meta = format_audio_list([audio_path], target_language=language, out_path=out_path, gradio_progress=progress)
|
||||||
|
|
||||||
state_vars = {}
|
return "Dataset Processed!", train_meta, eval_meta
|
||||||
state_vars["train_csv"] = train_meta
|
|
||||||
state_vars["eval_csv"] = eval_meta
|
|
||||||
print(state_vars)
|
|
||||||
return "Dataset Processed!", state_vars
|
|
||||||
|
|
||||||
prompt_compute_btn.click(
|
|
||||||
fn=preprocess_dataset,
|
|
||||||
inputs=[
|
|
||||||
upload_file,
|
|
||||||
lang,
|
|
||||||
state_vars,
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
progress_data,
|
|
||||||
state_vars,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
with gr.Tab("Fine-tuning XTTS"):
|
with gr.Tab("Fine-tuning XTTS"):
|
||||||
|
train_csv = gr.Textbox(
|
||||||
|
label="Train CSV:",
|
||||||
|
)
|
||||||
|
eval_csv = gr.Textbox(
|
||||||
|
label="Eval CSV:",
|
||||||
|
)
|
||||||
num_epochs = gr.Slider(
|
num_epochs = gr.Slider(
|
||||||
label="num_epochs",
|
label="num_epochs",
|
||||||
minimum=1,
|
minimum=1,
|
||||||
|
@ -189,21 +177,22 @@ with gr.Blocks() as demo:
|
||||||
demo.load(read_logs, None, logs_tts_train, every=1)
|
demo.load(read_logs, None, logs_tts_train, every=1)
|
||||||
train_btn = gr.Button(value="Step 2 - Run the training")
|
train_btn = gr.Button(value="Step 2 - Run the training")
|
||||||
|
|
||||||
def train_model(language, num_epochs, batch_size, state_vars, output_path="./", progress=gr.Progress(track_tqdm=True)):
|
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, output_path="./", progress=gr.Progress(track_tqdm=True)):
|
||||||
# state_vars = {'train_csv': '/tmp/tmprh4k_vou/metadata_train.csv', 'eval_csv': '/tmp/tmprh4k_vou/metadata_eval.csv'}
|
# train_csv = '/tmp/tmprh4k_vou/metadata_train.csv'
|
||||||
|
# eval_csv = '/tmp/tmprh4k_vou/metadata_eval.csv'
|
||||||
|
|
||||||
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, state_vars["train_csv"], state_vars["eval_csv"], output_path=output_path)
|
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path=output_path)
|
||||||
# copy original files to avoid parameters changes issues
|
# copy original files to avoid parameters changes issues
|
||||||
os.system(f"cp {config_path} {exp_path}")
|
os.system(f"cp {config_path} {exp_path}")
|
||||||
os.system(f"cp {vocab_file} {exp_path}")
|
os.system(f"cp {vocab_file} {exp_path}")
|
||||||
|
|
||||||
ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
|
ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
|
||||||
state_vars["config_path"] = config_path
|
# state_vars["config_path"] = config_path
|
||||||
state_vars["original_xtts_checkpoint"] = original_xtts_checkpoint
|
# state_vars["original_xtts_checkpoint"] = original_xtts_checkpoint
|
||||||
state_vars["vocab_file"] = vocab_file
|
# state_vars["vocab_file"] = vocab_file
|
||||||
state_vars["ft_xtts_checkpoint"] = ft_xtts_checkpoint
|
# state_vars["ft_xtts_checkpoint"] = ft_xtts_checkpoint
|
||||||
state_vars["speaker_audio_file"] = speaker_wav
|
# state_vars["speaker_audio_file"] = speaker_wav
|
||||||
return "Model training done!", state_vars, config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
|
return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
|
||||||
|
|
||||||
|
|
||||||
with gr.Tab("Inference"):
|
with gr.Tab("Inference"):
|
||||||
|
@ -254,16 +243,31 @@ with gr.Blocks() as demo:
|
||||||
tts_output_audio = gr.Audio(label="Generated Audio.")
|
tts_output_audio = gr.Audio(label="Generated Audio.")
|
||||||
reference_audio = gr.Audio(label="Reference audio used.")
|
reference_audio = gr.Audio(label="Reference audio used.")
|
||||||
|
|
||||||
|
prompt_compute_btn.click(
|
||||||
|
fn=preprocess_dataset,
|
||||||
|
inputs=[
|
||||||
|
upload_file,
|
||||||
|
lang,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
progress_data,
|
||||||
|
train_csv,
|
||||||
|
eval_csv,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
train_btn.click(
|
train_btn.click(
|
||||||
fn=train_model,
|
fn=train_model,
|
||||||
inputs=[
|
inputs=[
|
||||||
lang,
|
lang,
|
||||||
|
train_csv,
|
||||||
|
eval_csv,
|
||||||
num_epochs,
|
num_epochs,
|
||||||
batch_size,
|
batch_size,
|
||||||
state_vars,
|
|
||||||
],
|
],
|
||||||
outputs=[progress_train, state_vars, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
|
outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -276,7 +280,6 @@ with gr.Blocks() as demo:
|
||||||
xtts_config,
|
xtts_config,
|
||||||
xtts_vocab,
|
xtts_vocab,
|
||||||
speaker_reference_audio,
|
speaker_reference_audio,
|
||||||
state_vars,
|
|
||||||
],
|
],
|
||||||
outputs=[tts_output_audio, reference_audio],
|
outputs=[tts_output_audio, reference_audio],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue