Add erros messages

This commit is contained in:
Edresson Casanova 2023-11-27 10:41:09 -03:00
parent eaa5355c91
commit c5cb7eb791
1 changed files with 23 additions and 9 deletions

View File

@ -10,6 +10,7 @@ import numpy as np
import os
import torch
import torchaudio
import traceback
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
@ -22,11 +23,12 @@ def clear_gpu_cache():
if torch.cuda.is_available():
torch.cuda.empty_cache()
XTTS_MODEL = None
def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
clear_gpu_cache()
global XTTS_MODEL
clear_gpu_cache()
if not xtts_checkpoint or not xtts_config or not xtts_vocab:
return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
config = XttsConfig()
config.load_json(xtts_config)
XTTS_MODEL = Xtts.init_from_config(config)
@ -39,6 +41,9 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
return "Model Loaded!"
def run_tts(lang, tts_text, speaker_audio_file):
if XTTS_MODEL is None or not speaker_audio_file:
return "You need to run the previous step to load the model !!", None, None
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
out = XTTS_MODEL.inference(
text=tts_text,
@ -57,7 +62,7 @@ def run_tts(lang, tts_text, speaker_audio_file):
out_path = fp.name
torchaudio.save(out_path, out["wav"], 24000)
return out_path, speaker_audio_file
return "Speech generated !", out_path, speaker_audio_file
@ -197,8 +202,7 @@ if __name__ == "__main__":
out_path = os.path.join(out_path, "dataset")
os.makedirs(out_path, exist_ok=True)
if audio_path is None:
# ToDo: raise an error
pass
return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""
else:
train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
@ -208,7 +212,7 @@ if __name__ == "__main__":
if audio_total_size < 120:
message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
print(message)
return message, " ", " "
return message, "", ""
print("Dataset Processed!")
return "Dataset Processed!", train_meta, eval_meta
@ -253,8 +257,14 @@ if __name__ == "__main__":
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path):
clear_gpu_cache()
if not train_csv or not eval_csv:
return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
try:
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path)
except Exception as e:
traceback.print_exc()
return f"The training was interrupted due an error !! Please check the console to check the error message! Error summary: {e}", "", "", "", ""
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path)
# copy original files to avoid parameters changes issues
os.system(f"cp {config_path} {exp_path}")
os.system(f"cp {vocab_file} {exp_path}")
@ -276,8 +286,9 @@ if __name__ == "__main__":
label="XTTS config path:",
value="",
)
xtts_vocab = gr.Textbox(
label="XTTS config path:",
label="XTTS vocab path:",
value="",
)
progress_load = gr.Label(
@ -319,6 +330,9 @@ if __name__ == "__main__":
tts_btn = gr.Button(value="Step 4 - Inference")
with gr.Column() as col3:
progress_gen = gr.Label(
label="Progress:"
)
tts_output_audio = gr.Audio(label="Generated Audio.")
reference_audio = gr.Audio(label="Reference audio used.")
@ -368,7 +382,7 @@ if __name__ == "__main__":
tts_text,
speaker_reference_audio,
],
outputs=[tts_output_audio, reference_audio],
outputs=[progress_gen, tts_output_audio, reference_audio],
)
demo.launch(