mirror of https://github.com/coqui-ai/TTS.git
Update gradio demo
parent af74cd4426, commit 8967fc7ef2
TTS/demos/xtts_ft_demo/utils/formatter.py

@@ -44,6 +44,7 @@ def list_files(basePath, validExts=None, contains=None):
             yield audioPath
 
 
 def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
+    audio_total_size = 0
     # make sure that ooutput file exists
     os.makedirs(out_path, exist_ok=True)

@@ -67,7 +68,9 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
             wav = torch.mean(wav, dim=0, keepdim=True)
 
         wav = wav.squeeze()
-        segments, info = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
+        audio_total_size += (wav.size(-1) / sr)
+
+        segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
         segments = list(segments)
         i = 0
         sentence = ""

@@ -101,9 +104,9 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                 sentence = sentence[1:]
                 # Expand number and abbreviations plus normalization
                 sentence = multilingual_cleaners(sentence, target_language)
-                audio_file_name, ext = os.path.splitext(os.path.basename(audio_path))
+                audio_file_name, _ = os.path.splitext(os.path.basename(audio_path))
 
-                audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}{ext}"
+                audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav"
 
                 # Check for the next word's existence
                 if word_idx + 1 < len(words_list):

@@ -125,8 +128,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                 if audio.size(-1) >= sr/3:
                     torchaudio.save(absoulte_path,
                         audio,
-                        sr,
-                        backend="sox",
+                        sr
                     )
                 else:
                     continue

@@ -150,4 +152,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
     df_eval = df_eval.sort_values('audio_file')
     df_eval.to_csv(eval_metadata_path, sep="|", index=False)
 
-    return train_metadata_path, eval_metadata_path
+    # deallocate VRAM
+    del asr_model
+
+    return train_metadata_path, eval_metadata_path, audio_total_size
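The net effect in format_audio_list: each clip's length in seconds is accumulated from its raw sample count, and the total is returned alongside the two metadata paths. A minimal sketch of the same bookkeeping, assuming `audio_files` is a hypothetical list of wav paths:

    import torchaudio

    audio_total_size = 0
    for audio_path in audio_files:  # hypothetical list of wav paths
        wav, sr = torchaudio.load(audio_path)  # (channels, samples) tensor plus sample rate
        # samples divided by sample rate gives the clip duration in seconds
        audio_total_size += wav.size(-1) / sr

The caller can then reject datasets that are too short, as preprocess_dataset below does with its 120-second check.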
TTS/demos/xtts_ft_demo/utils/gpt_train.py

@@ -164,4 +164,7 @@ def train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path
     longest_text_idx = samples_len.index(max(samples_len))
     speaker_ref = train_samples[longest_text_idx]["audio_file"]
 
+    # deallocate VRAM
+    del model, trainer
+
     return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer.output_path, speaker_ref
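Note that the unchanged return line still reads `trainer.output_path` after the newly added `del model, trainer`, which would raise a NameError once the function runs. A defensive variant (a sketch, not the committed code) captures the path before deallocating:

    # capture the path while the trainer still exists
    trainer_out_path = trainer.output_path

    # deallocate VRAM
    del model, trainer

    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref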
TTS/demos/xtts_ft_demo/xtts_demo.py

@@ -9,17 +9,23 @@ import numpy as np
 import os
 import torch
 import torchaudio
-from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list, list_audios
+from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
 from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
 
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
 
 
+def clear_gpu_cache():
+    # clear the GPU cache
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
 PORT = 5003
 
 XTTS_MODEL = None
 def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
+    clear_gpu_cache()
     global XTTS_MODEL
     config = XttsConfig()
     config.load_json(xtts_config)
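`clear_gpu_cache` only asks PyTorch's CUDA caching allocator to release unused cached blocks and is a no-op on CPU-only machines; memory still referenced by live tensors is untouched. A sketch of a stricter variant that drops unreachable Python objects first (the `gc.collect()` call is an assumption here, not part of this commit):

    import gc
    import torch

    def clear_gpu_cache():
        # free unreachable tensors so their blocks return to the cache, then
        # hand the cached, unused blocks back to the CUDA driver
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()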
@@ -144,13 +150,23 @@ with gr.Blocks() as demo:
             prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")
 
             def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
+                clear_gpu_cache()
                 out_path = os.path.join(out_path, "dataset")
                 os.makedirs(out_path, exist_ok=True)
                 if audio_path is None:
                     # ToDo: raise an error
                     pass
                 else:
-                    train_meta, eval_meta = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+                    train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
 
+                clear_gpu_cache()
+
+                # if audio total len is less than 2 minutes raise an error
+                if audio_total_size < 120:
+                    message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
+                    print(message)
+                    return message, " ", " "
+
                 print("Dataset Processed!")
                 return "Dataset Processed!", train_meta, eval_meta
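The `audio_path is None` branch still ends in `pass`, which leaves `train_meta` and `eval_meta` unbound at the final return, and the too-short case is only reported through return values. One possible alternative is Gradio's built-in exception, sketched here with an invented message for the first case:

    import gradio as gr

    if audio_path is None:
        # surfaces as an error toast in the UI instead of falling through
        raise gr.Error("Please provide at least one audio file!")
    if audio_total_size < 120:
        raise gr.Error("The sum of the duration of the audios that you provided should be at least 2 minutes!")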
@@ -173,7 +189,7 @@ with gr.Blocks() as demo:
                 minimum=2,
                 maximum=512,
                 step=1,
-                value=15,
+                value=16,
             )
             progress_train = gr.Label(
                 label="Progress:"
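Only the numeric arguments of this widget appear in the hunk; given `minimum=2, maximum=512` and the `batch_size` parameter of `train_model` just below, it is presumably the batch-size slider. A reconstruction under that assumption (the label text is guessed):

    batch_size = gr.Slider(
        label="Batch size:",  # assumed; not visible in the diff
        minimum=2,
        maximum=512,
        step=1,
        value=16,
    )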
@@ -186,8 +202,7 @@ with gr.Blocks() as demo:
             train_btn = gr.Button(value="Step 2 - Run the training")
 
             def train_model(language, train_csv, eval_csv, num_epochs, batch_size, output_path, progress=gr.Progress(track_tqdm=True)):
-                # train_csv = '/tmp/tmprh4k_vou/metadata_train.csv'
-                # eval_csv = '/tmp/tmprh4k_vou/metadata_eval.csv'
+                clear_gpu_cache()
 
                 config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path=output_path)
                 # copy original files to avoid parameters changes issues

@@ -196,6 +211,7 @@ with gr.Blocks() as demo:
 
                 ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
                 print("Model training done!")
+                clear_gpu_cache()
                 return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
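Taken together, the two demo steps now compose as follows; a hypothetical call sequence with illustrative paths and hyperparameters:

    # step 1: build the dataset and get its total duration in seconds
    train_meta, eval_meta, total_sec = format_audio_list(
        ["speaker.wav"], target_language="en", out_path="out/dataset")

    # step 2: fine-tune only if at least 2 minutes of audio were provided
    if total_sec >= 120:
        config_path, ckpt, vocab, exp_path, speaker_wav = train_gpt(
            "en", num_epochs=10, batch_size=16,
            train_csv=train_meta, eval_csv=eval_meta, output_path="out")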