From 8967fc7ef2de54c3c00e532753d72e66cb1406c7 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Fri, 24 Nov 2023 14:26:26 -0300
Subject: [PATCH] Update gradio demo

---
 TTS/demos/xtts_ft_demo/utils/formatter.py | 17 +++++++++++------
 TTS/demos/xtts_ft_demo/utils/gpt_train.py |  6 +++++-
 TTS/demos/xtts_ft_demo/xtts_demo.py       | 26 +++++++++++++++++++------
 3 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/TTS/demos/xtts_ft_demo/utils/formatter.py b/TTS/demos/xtts_ft_demo/utils/formatter.py
index 6497b0d7..e49d2426 100644
--- a/TTS/demos/xtts_ft_demo/utils/formatter.py
+++ b/TTS/demos/xtts_ft_demo/utils/formatter.py
@@ -44,6 +44,7 @@ def list_files(basePath, validExts=None, contains=None):
             yield audioPath


 def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
+    audio_total_size = 0
     # make sure that ooutput file exists
     os.makedirs(out_path, exist_ok=True)
@@ -67,7 +68,9 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
             wav = torch.mean(wav, dim=0, keepdim=True)
         wav = wav.squeeze()

-        segments, info = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
+        audio_total_size += (wav.size(-1) / sr)
+
+        segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
         segments = list(segments)
         i = 0
         sentence = ""
@@ -101,9 +104,9 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                 sentence = sentence[1:]
                 # Expand number and abbreviations plus normalization
                 sentence = multilingual_cleaners(sentence, target_language)
-                audio_file_name, ext = os.path.splitext(os.path.basename(audio_path))
+                audio_file_name, _ = os.path.splitext(os.path.basename(audio_path))

-                audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}{ext}"
+                audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav"

                 # Check for the next word's existence
                 if word_idx + 1 < len(words_list):
@@ -125,8 +128,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                 if audio.size(-1) >= sr/3:
                     torchaudio.save(absoulte_path,
                         audio,
-                        sr,
-                        backend="sox",
+                        sr
                     )
                 else:
                     continue
@@ -150,4 +152,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
     df_eval = df_eval.sort_values('audio_file')
     df_eval.to_csv(eval_metadata_path, sep="|", index=False)

-    return train_metadata_path, eval_metadata_path
\ No newline at end of file
+    # deallocate VRAM
+    del asr_model
+
+    return train_metadata_path, eval_metadata_path, audio_total_size
\ No newline at end of file
diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py
index 1e7d5f36..4d33d6fc 100644
--- a/TTS/demos/xtts_ft_demo/utils/gpt_train.py
+++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py
@@ -164,4 +164,8 @@ def train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path
     longest_text_idx = samples_len.index(max(samples_len))
     speaker_ref = train_samples[longest_text_idx]["audio_file"]

-    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer.output_path, speaker_ref
+    # read the output path before the trainer is deleted, then deallocate VRAM
+    trainer_out_path = trainer.output_path
+    del model, trainer
+
+    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref
diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py
index 9dcaefce..24a449ec 100644
--- a/TTS/demos/xtts_ft_demo/xtts_demo.py
+++ b/TTS/demos/xtts_ft_demo/xtts_demo.py
@@ -9,17 +9,23 @@ import numpy as np
 import os
 import torch
 import torchaudio
-from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list, list_audios
+from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
 from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt

 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts

+def clear_gpu_cache():
+    # clear the GPU cache
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
 PORT = 5003

 XTTS_MODEL = None

 def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
+    clear_gpu_cache()
     global XTTS_MODEL
     config = XttsConfig()
     config.load_json(xtts_config)
@@ -144,13 +150,23 @@ with gr.Blocks() as demo:
             prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")

         def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
+            clear_gpu_cache()
+
             out_path = os.path.join(out_path, "dataset")
             os.makedirs(out_path, exist_ok=True)
             if audio_path is None:
                 # ToDo: raise an error
                 pass
             else:
-                train_meta, eval_meta = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+                train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+
+            clear_gpu_cache()
+
+            # if the total audio length is less than 2 minutes, raise an error
+            if audio_total_size < 120:
+                message = "The total duration of the provided audio should be at least 2 minutes!"
+                print(message)
+                return message, " ", " "

             print("Dataset Processed!")
             return "Dataset Processed!", train_meta, eval_meta
@@ -173,7 +189,7 @@
                 minimum=2,
                 maximum=512,
                 step=1,
-                value=15,
+                value=16,
             )
             progress_train = gr.Label(
                 label="Progress:"
             )
@@ -186,8 +202,7 @@
         train_btn = gr.Button(value="Step 2 - Run the training")

         def train_model(language, train_csv, eval_csv, num_epochs, batch_size, output_path, progress=gr.Progress(track_tqdm=True)):
-            # train_csv = '/tmp/tmprh4k_vou/metadata_train.csv'
-            # eval_csv = '/tmp/tmprh4k_vou/metadata_eval.csv'
+            clear_gpu_cache()
             config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path=output_path)

             # copy original files to avoid parameters changes issues
@@ -196,6 +211,7 @@

             ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
             print("Model training done!")
+            clear_gpu_cache()
             return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
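
The demo runs as one long-lived Gradio process, so VRAM held by a finished
step (the ASR model in format_audio_list, the trainer in train_gpt) is never
returned to CUDA on its own; hence the pattern applied above of dropping the
last Python reference and then flushing PyTorch's caching allocator. Below is
a minimal, self-contained sketch of that pattern and of the new 2-minute
duration gate. It assumes only standard PyTorch/torchaudio APIs;
total_audio_seconds, the Linear stand-in model, and the temporary wav files
are illustrative and not part of the demo code.

import gc
import os
import tempfile

import torch
import torchaudio


def clear_gpu_cache():
    # Same helper the patch adds to xtts_demo.py: hand PyTorch's cached,
    # unused CUDA blocks back to the driver so the next step starts clean.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def total_audio_seconds(audio_files):
    # Illustrative helper mirroring the audio_total_size accumulator added to
    # format_audio_list: a [channels, frames] tensor lasts frames / sample_rate
    # seconds.
    total = 0.0
    for path in audio_files:
        wav, sr = torchaudio.load(path)
        total += wav.size(-1) / sr
    return total


if __name__ == "__main__":
    # Synthesize two short silent clips so the sketch runs without real data.
    tmp_dir = tempfile.mkdtemp()
    files = []
    for idx, seconds in enumerate((30, 95)):
        path = os.path.join(tmp_dir, f"clip_{idx}.wav")
        torchaudio.save(path, torch.zeros(1, 22050 * seconds), 22050)
        files.append(path)

    # The 2-minute gate preprocess_dataset now enforces (30 + 95 s passes).
    if total_audio_seconds(files) < 120:
        print("The total duration of the provided audio should be at least 2 minutes!")

    # The VRAM-release pattern: read anything you still need first (as with
    # trainer.output_path above), drop the last reference, collect garbage,
    # then flush the caching allocator.
    model = torch.nn.Linear(4096, 4096)  # stand-in for the XTTS model
    if torch.cuda.is_available():
        model = model.cuda()
    del model
    gc.collect()
    clear_gpu_cache()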