Update gradio demo

Edresson Casanova 2023-11-24 14:26:26 -03:00
parent af74cd4426
commit 8967fc7ef2
3 changed files with 35 additions and 11 deletions

TTS/demos/xtts_ft_demo/utils/formatter.py

@@ -44,6 +44,7 @@ def list_files(basePath, validExts=None, contains=None):
             yield audioPath
 
 
 def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
+    audio_total_size = 0
     # make sure that ooutput file exists
     os.makedirs(out_path, exist_ok=True)
@@ -67,7 +68,9 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
             wav = torch.mean(wav, dim=0, keepdim=True)
 
         wav = wav.squeeze()
-        segments, info = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
+        audio_total_size += (wav.size(-1) / sr)
+
+        segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
         segments = list(segments)
         i = 0
         sentence = ""
@@ -101,9 +104,9 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                 sentence = sentence[1:]
                 # Expand number and abbreviations plus normalization
                 sentence = multilingual_cleaners(sentence, target_language)
 
-                audio_file_name, ext = os.path.splitext(os.path.basename(audio_path))
-                audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}{ext}"
+                audio_file_name, _ = os.path.splitext(os.path.basename(audio_path))
+                audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav"
 
                 # Check for the next word's existence
                 if word_idx + 1 < len(words_list):
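Hard-coding the `.wav` extension makes sense here: segments are re-encoded through torchaudio's WAV writer, so an `.mp3` input must not leak its extension into the output name, which the old `{ext}` interpolation allowed. A small sketch of the naming scheme:

    import os

    def segment_name(audio_path, index):
        # wavs/<source stem>_<8-digit index>.wav, regardless of input format
        stem, _ = os.path.splitext(os.path.basename(audio_path))
        return f"wavs/{stem}_{str(index).zfill(8)}.wav"

    print(segment_name("/data/speaker.mp3", 3))  # wavs/speaker_00000003.wav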
@@ -125,8 +128,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                 if audio.size(-1) >= sr/3:
                     torchaudio.save(absoulte_path,
                         audio,
-                        sr,
-                        backend="sox",
+                        sr
                     )
                 else:
                     continue
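Removing `backend="sox"` lets torchaudio fall back to whatever backend is installed, avoiding hard failures on systems without SoX; the surrounding guard still drops segments shorter than about a third of a second. A hedged sketch of the same save path (the helper name is mine):

    import torchaudio

    def save_segment(path, audio, sr):
        # audio: (channels, samples) tensor; skip clips shorter than ~1/3 s
        if audio.size(-1) < sr / 3:
            return False
        torchaudio.save(path, audio, sr)  # default backend, format inferred from path
        return True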
@@ -150,4 +152,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
     df_eval = df_eval.sort_values('audio_file')
     df_eval.to_csv(eval_metadata_path, sep="|", index=False)
 
-    return train_metadata_path, eval_metadata_path
+    # deallocate VRAM
+    del asr_model
+
+    return train_metadata_path, eval_metadata_path, audio_total_size
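`del asr_model` alone only drops the Python reference to the Whisper model; the CUDA memory is not returned to the driver until garbage collection runs and PyTorch's allocator cache is flushed. A common pattern for a full release (a sketch, not the commit's exact code):

    import gc
    import torch

    model = torch.nn.Linear(1024, 1024)
    if torch.cuda.is_available():
        model = model.cuda()

    # ... use the model ...

    del model                     # drop the last reference
    gc.collect()                  # reclaim the Python-side objects
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached CUDA blocks back to the driver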

TTS/demos/xtts_ft_demo/utils/gpt_train.py

@@ -164,4 +164,7 @@ def train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path
     longest_text_idx = samples_len.index(max(samples_len))
     speaker_ref = train_samples[longest_text_idx]["audio_file"]
 
+    # deallocate VRAM
+    del model, trainer
+
     return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer.output_path, speaker_ref
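Note a latent bug in this hunk: the return reads `trainer.output_path` after `del model, trainer` has already unbound the name, which raises a NameError at the return. The fix is to capture the path first, e.g. `trainer_out_path = trainer.output_path` before the `del`, and return `trainer_out_path` instead.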

TTS/demos/xtts_ft_demo/xtts_demo.py

@@ -9,17 +9,23 @@ import numpy as np
 import os
 import torch
 import torchaudio
-from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list, list_audios
+from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
 from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
 
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
 
 
+def clear_gpu_cache():
+    # clear the GPU cache
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
 PORT = 5003
 XTTS_MODEL = None
 
 def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
+    clear_gpu_cache()
     global XTTS_MODEL
     config = XttsConfig()
     config.load_json(xtts_config)
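`clear_gpu_cache()` is defined once and called at the start of each step, so allocations left over from a previous run are released before a new model loads. Note that `torch.cuda.empty_cache()` only frees blocks PyTorch has cached, never tensors that are still referenced, which is why the formatter and trainer also `del` their models. A minimal sketch of the call pattern (the wrapper is hypothetical):

    import torch

    def clear_gpu_cache():
        # release CUDA blocks cached by PyTorch; no-op on CPU-only machines
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def run_step(fn, *args):
        # hypothetical wrapper: clear before and after a pipeline step
        clear_gpu_cache()
        try:
            return fn(*args)
        finally:
            clear_gpu_cache()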
@@ -144,13 +150,23 @@ with gr.Blocks() as demo:
             prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")
 
         def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
+            clear_gpu_cache()
             out_path = os.path.join(out_path, "dataset")
             os.makedirs(out_path, exist_ok=True)
             if audio_path is None:
                 # ToDo: raise an error
                 pass
             else:
-                train_meta, eval_meta = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+                train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+
+            clear_gpu_cache()
+
+            # if audio total len is less than 2 minutes raise an error
+            if audio_total_size < 120:
+                message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
+                print(message)
+                return message, " ", " "
+
 
             print("Dataset Processed!")
             return "Dataset Processed!", train_meta, eval_meta
@@ -173,7 +189,7 @@ with gr.Blocks() as demo:
                 minimum=2,
                 maximum=512,
                 step=1,
-                value=15,
+                value=16,
             )
             progress_train = gr.Label(
                 label="Progress:"
@@ -186,8 +202,7 @@ with gr.Blocks() as demo:
             train_btn = gr.Button(value="Step 2 - Run the training")
 
         def train_model(language, train_csv, eval_csv, num_epochs, batch_size, output_path, progress=gr.Progress(track_tqdm=True)):
-            # train_csv = '/tmp/tmprh4k_vou/metadata_train.csv'
-            # eval_csv = '/tmp/tmprh4k_vou/metadata_eval.csv'
+            clear_gpu_cache()
 
             config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path=output_path)
             # copy original files to avoid parameters changes issues
# copy original files to avoid parameters changes issues # copy original files to avoid parameters changes issues
@@ -196,6 +211,7 @@ with gr.Blocks() as demo:
 
             ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
 
             print("Model training done!")
+            clear_gpu_cache()
 
             return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav