From 626d9e16fb35061c3840276a26c08077f1fab309 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Fri, 24 Nov 2023 08:44:21 -0300
Subject: [PATCH] Fix demo freezing issue

---
 TTS/demos/xtts_ft_demo/requirements.txt |  3 +-
 TTS/demos/xtts_ft_demo/xtts_demo.py     | 71 +++++++++++++------------
 2 files changed, 39 insertions(+), 35 deletions(-)

diff --git a/TTS/demos/xtts_ft_demo/requirements.txt b/TTS/demos/xtts_ft_demo/requirements.txt
index 8360accf..cb5b16f6 100644
--- a/TTS/demos/xtts_ft_demo/requirements.txt
+++ b/TTS/demos/xtts_ft_demo/requirements.txt
@@ -1 +1,2 @@
-faster_whisper
\ No newline at end of file
+faster_whisper==0.9.0
+gradio==4.7.1
\ No newline at end of file
diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py
index 016a929e..6fee1a50 100644
--- a/TTS/demos/xtts_ft_demo/xtts_demo.py
+++ b/TTS/demos/xtts_ft_demo/xtts_demo.py
@@ -28,7 +28,7 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
         model.cuda()
     return model
 
-def run_tts(lang, tts_text, xtts_checkpoint, xtts_config, xtts_vocab, speaker_audio_file, state_vars):
+def run_tts(lang, tts_text, xtts_checkpoint, xtts_config, xtts_vocab, speaker_audio_file):
     # ToDo: add the load in other function to fast inference
     model = load_model(xtts_checkpoint, xtts_config, xtts_vocab)
     gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=model.config.gpt_cond_len, max_ref_length=model.config.max_ref_len, sound_norm_refs=model.config.sound_norm_refs)
@@ -95,7 +95,7 @@ def read_logs():
 
 
 with gr.Blocks() as demo:
-    state_vars = gr.State()
+    # state_vars = gr.State()
     with gr.Tab("Data processing"):
         upload_file = gr.Audio(
             sources="upload",
@@ -135,7 +135,7 @@ with gr.Blocks() as demo:
 
         prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")
 
-        def preprocess_dataset(audio_path, language, state_vars, progress=gr.Progress(track_tqdm=True)):
+        def preprocess_dataset(audio_path, language, progress=gr.Progress(track_tqdm=True)):
            # create a temp directory to save the dataset
            out_path = tempfile.TemporaryDirectory().name
            if audio_path is None:
@@ -144,27 +144,15 @@ with gr.Blocks() as demo:
            else:
                train_meta, eval_meta = format_audio_list([audio_path], target_language=language, out_path=out_path, gradio_progress=progress)
 
-            state_vars = {}
-            state_vars["train_csv"] = train_meta
-            state_vars["eval_csv"] = eval_meta
-            print(state_vars)
-            return "Dataset Processed!", state_vars
-
-        prompt_compute_btn.click(
-            fn=preprocess_dataset,
-            inputs=[
-                upload_file,
-                lang,
-                state_vars,
-            ],
-            outputs=[
-                progress_data,
-                state_vars,
-            ],
-        )
-
+            return "Dataset Processed!", train_meta, eval_meta
 
     with gr.Tab("Fine-tuning XTTS"):
+        train_csv = gr.Textbox(
+            label="Train CSV:",
+        )
+        eval_csv = gr.Textbox(
+            label="Eval CSV:",
+        )
         num_epochs = gr.Slider(
             label="num_epochs",
             minimum=1,
@@ -189,21 +177,22 @@ with gr.Blocks() as demo:
         demo.load(read_logs, None, logs_tts_train, every=1)
 
         train_btn = gr.Button(value="Step 2 - Run the training")
 
-        def train_model(language, num_epochs, batch_size, state_vars, output_path="./", progress=gr.Progress(track_tqdm=True)):
-            # state_vars = {'train_csv': '/tmp/tmprh4k_vou/metadata_train.csv', 'eval_csv': '/tmp/tmprh4k_vou/metadata_eval.csv'}
+        def train_model(language, train_csv, eval_csv, num_epochs, batch_size, output_path="./", progress=gr.Progress(track_tqdm=True)):
+            # train_csv = '/tmp/tmprh4k_vou/metadata_train.csv'
+            # eval_csv = '/tmp/tmprh4k_vou/metadata_eval.csv'
 
-            config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, state_vars["train_csv"], state_vars["eval_csv"], output_path=output_path)
+            config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path=output_path)
             # copy original files to avoid parameters changes issues
             os.system(f"cp {config_path} {exp_path}")
             os.system(f"cp {vocab_file} {exp_path}")
             ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
-            state_vars["config_path"] = config_path
-            state_vars["original_xtts_checkpoint"] = original_xtts_checkpoint
-            state_vars["vocab_file"] = vocab_file
-            state_vars["ft_xtts_checkpoint"] = ft_xtts_checkpoint
-            state_vars["speaker_audio_file"] = speaker_wav
-            return "Model training done!", state_vars, config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
+            # state_vars["config_path"] = config_path
+            # state_vars["original_xtts_checkpoint"] = original_xtts_checkpoint
+            # state_vars["vocab_file"] = vocab_file
+            # state_vars["ft_xtts_checkpoint"] = ft_xtts_checkpoint
+            # state_vars["speaker_audio_file"] = speaker_wav
+            return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
 
 
     with gr.Tab("Inference"):
@@ -254,16 +243,31 @@ with gr.Blocks() as demo:
                 tts_output_audio = gr.Audio(label="Generated Audio.")
                 reference_audio = gr.Audio(label="Reference audio used.")
 
+    prompt_compute_btn.click(
+        fn=preprocess_dataset,
+        inputs=[
+            upload_file,
+            lang,
+        ],
+        outputs=[
+            progress_data,
+            train_csv,
+            eval_csv,
+        ],
+    )
+
+
     train_btn.click(
         fn=train_model,
         inputs=[
             lang,
+            train_csv,
+            eval_csv,
             num_epochs,
             batch_size,
-            state_vars,
         ],
-        outputs=[progress_train, state_vars, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
+        outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
     )
@@ -276,7 +280,6 @@ with gr.Blocks() as demo:
             xtts_config,
             xtts_vocab,
             speaker_reference_audio,
-            state_vars,
         ],
         outputs=[tts_output_audio, reference_audio],
     )
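
Note: the core of this patch is dropping the shared gr.State object and instead carrying values (the train/eval CSV paths, checkpoint paths) between steps through ordinary components: each button's .click() writes its results into gr.Textbox outputs, and the next button reads those same components as inputs. Below is a standalone sketch of that wiring pattern, not part of the diff; the step1/step2 functions, component names, and toy return values are illustrative placeholders only.

# Minimal sketch of the component-wiring pattern used by this patch:
# each step writes its outputs into Textbox components and the next
# step reads them back as inputs, so no gr.State is threaded through
# the callbacks. step1/step2 and their payloads are placeholders.
import gradio as gr

def step1(name):
    # Stand-in for preprocess_dataset: produce a status plus two "paths".
    return "Done!", f"/tmp/{name}/metadata_train.csv", f"/tmp/{name}/metadata_eval.csv"

def step2(train_csv, eval_csv):
    # Stand-in for train_model: consume the paths produced by step 1.
    return f"Trained with {train_csv} and {eval_csv}"

with gr.Blocks() as demo:
    name = gr.Textbox(label="Dataset name")
    status1 = gr.Textbox(label="Step 1 status")
    train_csv = gr.Textbox(label="Train CSV:")
    eval_csv = gr.Textbox(label="Eval CSV:")
    status2 = gr.Textbox(label="Step 2 status")
    step1_btn = gr.Button("Step 1")
    step2_btn = gr.Button("Step 2")

    step1_btn.click(fn=step1, inputs=[name], outputs=[status1, train_csv, eval_csv])
    step2_btn.click(fn=step2, inputs=[train_csv, eval_csv], outputs=[status2])

demo.launch()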