mirror of https://github.com/coqui-ai/TTS.git
Add error messages
This commit is contained in:
parent
eaa5355c91
commit
c5cb7eb791
|
@ -10,6 +10,7 @@ import numpy as np
|
|||
import os
|
||||
import torch
|
||||
import torchaudio
|
||||
import traceback
|
||||
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
|
||||
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
|
||||
|
||||
|
@ -22,11 +23,12 @@ def clear_gpu_cache():
|
|||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
XTTS_MODEL = None
|
||||
def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
|
||||
clear_gpu_cache()
|
||||
global XTTS_MODEL
|
||||
clear_gpu_cache()
|
||||
if not xtts_checkpoint or not xtts_config or not xtts_vocab:
|
||||
return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
|
||||
config = XttsConfig()
|
||||
config.load_json(xtts_config)
|
||||
XTTS_MODEL = Xtts.init_from_config(config)
|
||||
|
@ -39,6 +41,9 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
|
|||
return "Model Loaded!"
|
||||
|
||||
def run_tts(lang, tts_text, speaker_audio_file):
|
||||
if XTTS_MODEL is None or not speaker_audio_file:
|
||||
return "You need to run the previous step to load the model !!", None, None
|
||||
|
||||
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
|
||||
out = XTTS_MODEL.inference(
|
||||
text=tts_text,
|
||||
|
@ -57,7 +62,7 @@ def run_tts(lang, tts_text, speaker_audio_file):
|
|||
out_path = fp.name
|
||||
torchaudio.save(out_path, out["wav"], 24000)
|
||||
|
||||
return out_path, speaker_audio_file
|
||||
return "Speech generated !", out_path, speaker_audio_file
|
||||
|
||||
|
||||
|
||||
|
@ -197,8 +202,7 @@ if __name__ == "__main__":
|
|||
out_path = os.path.join(out_path, "dataset")
|
||||
os.makedirs(out_path, exist_ok=True)
|
||||
if audio_path is None:
|
||||
# ToDo: raise an error
|
||||
pass
|
||||
return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""
|
||||
else:
|
||||
train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
|
||||
|
||||
|
@ -208,7 +212,7 @@ if __name__ == "__main__":
|
|||
if audio_total_size < 120:
|
||||
message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
|
||||
print(message)
|
||||
return message, " ", " "
|
||||
return message, "", ""
|
||||
|
||||
print("Dataset Processed!")
|
||||
return "Dataset Processed!", train_meta, eval_meta
|
||||
|
@ -253,8 +257,14 @@ if __name__ == "__main__":
|
|||
|
||||
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path):
|
||||
clear_gpu_cache()
|
||||
if not train_csv or not eval_csv:
|
||||
return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
|
||||
try:
|
||||
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return f"The training was interrupted due an error !! Please check the console to check the error message! Error summary: {e}", "", "", "", ""
|
||||
|
||||
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path)
|
||||
# copy original files to avoid parameters changes issues
|
||||
os.system(f"cp {config_path} {exp_path}")
|
||||
os.system(f"cp {vocab_file} {exp_path}")
|
||||
|
@ -276,8 +286,9 @@ if __name__ == "__main__":
|
|||
label="XTTS config path:",
|
||||
value="",
|
||||
)
|
||||
|
||||
xtts_vocab = gr.Textbox(
|
||||
label="XTTS config path:",
|
||||
label="XTTS vocab path:",
|
||||
value="",
|
||||
)
|
||||
progress_load = gr.Label(
|
||||
|
@ -319,6 +330,9 @@ if __name__ == "__main__":
|
|||
tts_btn = gr.Button(value="Step 4 - Inference")
|
||||
|
||||
with gr.Column() as col3:
|
||||
progress_gen = gr.Label(
|
||||
label="Progress:"
|
||||
)
|
||||
tts_output_audio = gr.Audio(label="Generated Audio.")
|
||||
reference_audio = gr.Audio(label="Reference audio used.")
|
||||
|
||||
|
@ -368,7 +382,7 @@ if __name__ == "__main__":
|
|||
tts_text,
|
||||
speaker_reference_audio,
|
||||
],
|
||||
outputs=[tts_output_audio, reference_audio],
|
||||
outputs=[progress_gen, tts_output_audio, reference_audio],
|
||||
)
|
||||
|
||||
demo.launch(
|
||||
|
|
Loading…
Reference in New Issue