mirror of https://github.com/coqui-ai/TTS.git
Update gradio demo
parent af74cd4426
commit 8967fc7ef2
TTS/demos/xtts_ft_demo/utils/formatter.py

@@ -44,6 +44,7 @@ def list_files(basePath, validExts=None, contains=None):
                 yield audioPath


 def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
+    audio_total_size = 0
     # make sure that output file exists
     os.makedirs(out_path, exist_ok=True)
@@ -67,7 +68,9 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
             wav = torch.mean(wav, dim=0, keepdim=True)

         wav = wav.squeeze()
-        segments, info = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
+        audio_total_size += (wav.size(-1) / sr)

+        segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
+        segments = list(segments)
         i = 0
         sentence = ""
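The new audio_total_size bookkeeping works because a torchaudio waveform has shape (channels, num_samples): after the mono downmix and squeeze, wav.size(-1) / sr is the clip length in seconds. A minimal standalone sketch of that accumulation (the helper name is hypothetical, not part of the demo):

    import torch
    import torchaudio

    def total_audio_seconds(paths):
        # sum the durations of a list of audio files, mirroring the diff's logic
        total = 0.0
        for path in paths:
            wav, sr = torchaudio.load(path)                 # wav: (channels, num_samples)
            if wav.size(0) != 1:
                wav = torch.mean(wav, dim=0, keepdim=True)  # downmix to mono
            wav = wav.squeeze()
            total += wav.size(-1) / sr                      # samples / sample rate = seconds
        return total

The switch to segments, _ = asr_model.transcribe(...) followed by segments = list(segments) also matters if the ASR model is faster-whisper's WhisperModel, as the transcribe signature suggests: its transcribe() returns a lazy generator, and list() forces the full transcription pass before the word-level loop begins.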
@@ -101,9 +104,9 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                     sentence = sentence[1:]
                     # Expand number and abbreviations plus normalization
                     sentence = multilingual_cleaners(sentence, target_language)
-                    audio_file_name, ext = os.path.splitext(os.path.basename(audio_path))
+                    audio_file_name, _ = os.path.splitext(os.path.basename(audio_path))

-                    audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}{ext}"
+                    audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav"

                     # Check for the next word's existence
                     if word_idx + 1 < len(words_list):
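Both replaced lines implement one fix: the source clip may be .mp3 or .flac, but every exported segment is written as a wav file, so the extension is now hardcoded instead of inherited from the input. A quick illustration with a hypothetical input path:

    import os

    audio_path = "uploads/interview.mp3"  # hypothetical input, any audio container
    audio_file_name, _ = os.path.splitext(os.path.basename(audio_path))
    for i in range(2):
        print(f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav")
    # wavs/interview_00000000.wav
    # wavs/interview_00000001.wav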
@@ -125,8 +128,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                     if audio.size(-1) >= sr/3:
                         torchaudio.save(absoulte_path,
                             audio,
-                            sr,
-                            backend="sox",
+                            sr
                         )
                     else:
                         continue
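Dropping backend="sox" leaves the portable positional form save(path, tensor, sample_rate), which works across torchaudio versions; a backend keyword is only accepted by some releases. A self-contained sketch of the surviving call (file name is illustrative; torchaudio.save expects a 2-D (channels, samples) tensor):

    import os
    import torch
    import torchaudio

    sr = 22050
    audio = torch.zeros(1, sr)       # one second of mono silence, shape (1, samples)
    os.makedirs("wavs", exist_ok=True)
    if audio.size(-1) >= sr / 3:     # same guard as the diff: skip clips shorter than 1/3 s
        torchaudio.save("wavs/example_00000000.wav", audio, sr)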
@@ -150,4 +152,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
     df_eval = df_eval.sort_values('audio_file')
     df_eval.to_csv(eval_metadata_path, sep="|", index=False)

-    return train_metadata_path, eval_metadata_path
+    # deallocate VRAM
+    del asr_model
+
+    return train_metadata_path, eval_metadata_path, audio_total_size
TTS/demos/xtts_ft_demo/utils/gpt_train.py

@@ -164,4 +164,9 @@ def train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path
     longest_text_idx = samples_len.index(max(samples_len))
     speaker_ref = train_samples[longest_text_idx]["audio_file"]

-    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer.output_path, speaker_ref
+    trainer_out_path = trainer.output_path
+
+    # deallocate VRAM
+    del model, trainer
+
+    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref
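Reading trainer.output_path into a local before the del is what keeps the return statement valid: del only unbinds the name, so any later attribute access on trainer would raise a NameError, while the captured string survives. A toy illustration with hypothetical stand-in names:

    class Trainer:
        output_path = "/tmp/run-0"           # stand-in for the real trainer attribute

    trainer = Trainer()
    trainer_out_path = trainer.output_path   # capture before deleting
    del trainer                              # unbind the name so its memory can be reclaimed
    print(trainer_out_path)                  # fine; trainer.output_path here would raise NameError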
TTS/demos/xtts_ft_demo/xtts_demo.py

@@ -9,17 +9,23 @@ import numpy as np
 import os
 import torch
 import torchaudio
-from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list, list_audios
+from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
 from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt

 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts


+def clear_gpu_cache():
+    # clear the GPU cache
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
 PORT = 5003

 XTTS_MODEL = None
 def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
+    clear_gpu_cache()
     global XTTS_MODEL
     config = XttsConfig()
     config.load_json(xtts_config)
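clear_gpu_cache() is the pattern the rest of this commit threads through every step: torch.cuda.empty_cache() returns PyTorch's cached but unused CUDA blocks to the driver, yet it cannot free tensors that are still referenced, which is why the demo pairs it with del elsewhere. A small sketch, guarded so it also runs on CPU-only machines:

    import torch

    def clear_gpu_cache():
        # release cached, unreferenced CUDA memory back to the driver
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if torch.cuda.is_available():
        x = torch.zeros(1024, 1024, device="cuda")
        del x           # drop the last reference first...
    clear_gpu_cache()   # ...then empty_cache() can actually reclaim the blocks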
@@ -144,13 +150,23 @@ with gr.Blocks() as demo:
             prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")

     def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
+        clear_gpu_cache()
         out_path = os.path.join(out_path, "dataset")
         os.makedirs(out_path, exist_ok=True)
         if audio_path is None:
             # ToDo: raise an error
             pass
         else:
-            train_meta, eval_meta = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+            train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+
+        clear_gpu_cache()
+
+        # if audio total len is less than 2 minutes raise an error
+        if audio_total_size < 120:
+            message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
+            print(message)
+            return message, " ", " "
+

         print("Dataset Processed!")
         return "Dataset Processed!", train_meta, eval_meta
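The early return message, " ", " " has the same arity as the success return because Gradio maps a handler's returned tuple positionally onto the outputs registered for the click event. A minimal sketch of that wiring, with illustrative component names rather than the demo's exact ones:

    import gradio as gr

    with gr.Blocks() as demo:
        status = gr.Label(label="Progress:")
        train_csv = gr.Textbox(label="Train CSV:")
        eval_csv = gr.Textbox(label="Eval CSV:")
        btn = gr.Button(value="Step 1 - Create dataset.")

        def fake_preprocess():
            # every code path must yield one value per registered output
            return "Dataset Processed!", "/tmp/metadata_train.csv", "/tmp/metadata_eval.csv"

        btn.click(fn=fake_preprocess, inputs=[], outputs=[status, train_csv, eval_csv])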
@@ -173,7 +189,7 @@ with gr.Blocks() as demo:
                 minimum=2,
                 maximum=512,
                 step=1,
-                value=15,
+                value=16,
             )
             progress_train = gr.Label(
                 label="Progress:"
@@ -186,8 +202,7 @@ with gr.Blocks() as demo:
             train_btn = gr.Button(value="Step 2 - Run the training")

     def train_model(language, train_csv, eval_csv, num_epochs, batch_size, output_path, progress=gr.Progress(track_tqdm=True)):
-        # train_csv = '/tmp/tmprh4k_vou/metadata_train.csv'
-        # eval_csv = '/tmp/tmprh4k_vou/metadata_eval.csv'
+        clear_gpu_cache()

         config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path=output_path)
         # copy original files to avoid parameters changes issues
@@ -196,6 +211,7 @@ with gr.Blocks() as demo:

         ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
         print("Model training done!")
+        clear_gpu_cache()
         return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
