Update gradio demo

This commit is contained in:
Edresson Casanova 2023-11-24 14:26:26 -03:00
parent af74cd4426
commit 8967fc7ef2
3 changed files with 35 additions and 11 deletions

View File

@ -44,6 +44,7 @@ def list_files(basePath, validExts=None, contains=None):
yield audioPath
def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
audio_total_size = 0
# make sure that ooutput file exists
os.makedirs(out_path, exist_ok=True)
@ -67,7 +68,9 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
wav = torch.mean(wav, dim=0, keepdim=True)
wav = wav.squeeze()
segments, info = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
audio_total_size += (wav.size(-1) / sr)
segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
segments = list(segments)
i = 0
sentence = ""
@ -101,9 +104,9 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
sentence = sentence[1:]
# Expand number and abbreviations plus normalization
sentence = multilingual_cleaners(sentence, target_language)
audio_file_name, ext = os.path.splitext(os.path.basename(audio_path))
audio_file_name, _ = os.path.splitext(os.path.basename(audio_path))
audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}{ext}"
audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav"
# Check for the next word's existence
if word_idx + 1 < len(words_list):
@ -125,8 +128,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
if audio.size(-1) >= sr/3:
torchaudio.save(absoulte_path,
audio,
sr,
backend="sox",
sr
)
else:
continue
@ -150,4 +152,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
df_eval = df_eval.sort_values('audio_file')
df_eval.to_csv(eval_metadata_path, sep="|", index=False)
return train_metadata_path, eval_metadata_path
# deallocate VRAM
del asr_model
return train_metadata_path, eval_metadata_path, audio_total_size

View File

@ -164,4 +164,7 @@ def train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path
longest_text_idx = samples_len.index(max(samples_len))
speaker_ref = train_samples[longest_text_idx]["audio_file"]
# deallocate VRAM
del model, trainer
return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer.output_path, speaker_ref

View File

@ -9,17 +9,23 @@ import numpy as np
import os
import torch
import torchaudio
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list, list_audios
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
def clear_gpu_cache():
# clear the GPU cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
PORT = 5003
XTTS_MODEL = None
def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
clear_gpu_cache()
global XTTS_MODEL
config = XttsConfig()
config.load_json(xtts_config)
@ -144,13 +150,23 @@ with gr.Blocks() as demo:
prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")
def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
clear_gpu_cache()
out_path = os.path.join(out_path, "dataset")
os.makedirs(out_path, exist_ok=True)
if audio_path is None:
# ToDo: raise an error
pass
else:
train_meta, eval_meta = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
clear_gpu_cache()
# if audio total len is less than 2 minutes raise an error
if audio_total_size < 120:
message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
print(message)
return message, " ", " "
print("Dataset Processed!")
return "Dataset Processed!", train_meta, eval_meta
@ -173,7 +189,7 @@ with gr.Blocks() as demo:
minimum=2,
maximum=512,
step=1,
value=15,
value=16,
)
progress_train = gr.Label(
label="Progress:"
@ -186,8 +202,7 @@ with gr.Blocks() as demo:
train_btn = gr.Button(value="Step 2 - Run the training")
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, output_path, progress=gr.Progress(track_tqdm=True)):
# train_csv = '/tmp/tmprh4k_vou/metadata_train.csv'
# eval_csv = '/tmp/tmprh4k_vou/metadata_eval.csv'
clear_gpu_cache()
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path=output_path)
# copy original files to avoid parameters changes issues
@ -196,6 +211,7 @@ with gr.Blocks() as demo:
ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
print("Model training done!")
clear_gpu_cache()
return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav