From fa9bb26ebb2cb4ecb1e37a3b4fad608ee9ddc96c Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Fri, 24 Nov 2023 10:22:12 -0300
Subject: [PATCH] Update demo

---
 TTS/demos/xtts_ft_demo/utils/formatter.py |  10 +-
 TTS/demos/xtts_ft_demo/utils/gpt_train.py |   6 +-
 TTS/demos/xtts_ft_demo/xtts_demo.py       | 186 ++++++++++++----------
 3 files changed, 112 insertions(+), 90 deletions(-)

diff --git a/TTS/demos/xtts_ft_demo/utils/formatter.py b/TTS/demos/xtts_ft_demo/utils/formatter.py
index 95bb0f1b..03db6c2c 100644
--- a/TTS/demos/xtts_ft_demo/utils/formatter.py
+++ b/TTS/demos/xtts_ft_demo/utils/formatter.py
@@ -8,8 +8,6 @@ from tqdm import tqdm
 import torch
 import torchaudio
-from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
-from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
 # torch.set_num_threads(1)
 
 from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners
@@ -45,7 +43,7 @@ def list_files(basePath, validExts=None, contains=None):
                 audioPath = os.path.join(rootDir, filename)
                 yield audioPath
 
-def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.5, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
+def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
     # make sure that output file exists
     os.makedirs(out_path, exist_ok=True)
@@ -121,10 +119,10 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
             audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
             # if the audio is too short ignore it (i.e < 0.33 seconds)
             if audio.size(-1) >= sr/3:
-                torchaudio.backend.sox_io_backend.save(
-                    absoulte_path,
+                torchaudio.save(absoulte_path,
                     audio,
-                    sr
+                    sr,
+                    backend="sox",
                 )
             else:
                 continue
diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py
index a4f5cb9a..1e7d5f36 100644
--- a/TTS/demos/xtts_ft_demo/utils/gpt_train.py
+++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py
@@ -159,5 +159,9 @@ def train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path
     )
     trainer.fit()
 
+    # get the longest text audio file to use as speaker reference
+    samples_len = [len(item["text"].split(" ")) for item in train_samples]
+    longest_text_idx = samples_len.index(max(samples_len))
+    speaker_ref = train_samples[longest_text_idx]["audio_file"]
 
-    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer.output_path, train_samples[0]["audio_file"]
+    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer.output_path, speaker_ref
diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py
index 6fee1a50..9dcaefce 100644
--- a/TTS/demos/xtts_ft_demo/xtts_demo.py
+++ b/TTS/demos/xtts_ft_demo/xtts_demo.py
@@ -18,31 +18,32 @@ from TTS.tts.models.xtts import Xtts
 
 PORT = 5003
 
+XTTS_MODEL = None
 def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
+    global XTTS_MODEL
     config = XttsConfig()
     config.load_json(xtts_config)
-    model = Xtts.init_from_config(config)
+    XTTS_MODEL = Xtts.init_from_config(config)
     print("Loading XTTS model! ")
") - model.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False) + XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False) if torch.cuda.is_available(): - model.cuda() - return model + XTTS_MODEL.cuda() -def run_tts(lang, tts_text, xtts_checkpoint, xtts_config, xtts_vocab, speaker_audio_file): - # ToDo: add the load in other function to fast inference - model = load_model(xtts_checkpoint, xtts_config, xtts_vocab) - gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=model.config.gpt_cond_len, max_ref_length=model.config.max_ref_len, sound_norm_refs=model.config.sound_norm_refs) - speaker_embedding - out = model.inference( + print("Model Loaded!") + return "Model Loaded!" + +def run_tts(lang, tts_text, speaker_audio_file): + gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs) + out = XTTS_MODEL.inference( text=tts_text, language=lang, gpt_cond_latent=gpt_cond_latent, speaker_embedding=speaker_embedding, - temperature=model.config.temperature, # Add custom parameters here - length_penalty=model.config.length_penalty, - repetition_penalty=model.config.repetition_penalty, - top_k=model.config.top_k, - top_p=model.config.top_p, + temperature=XTTS_MODEL.config.temperature, # Add custom parameters here + length_penalty=XTTS_MODEL.config.length_penalty, + repetition_penalty=XTTS_MODEL.config.repetition_penalty, + top_k=XTTS_MODEL.config.top_k, + top_p=XTTS_MODEL.config.top_p, ) with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp: @@ -95,12 +96,19 @@ def read_logs(): with gr.Blocks() as demo: - # state_vars = gr.State() - with gr.Tab("Data processing"): - upload_file = gr.Audio( - sources="upload", - label="Select here the audio files that you want to use for XTTS trainining !", - type="filepath", + with gr.Tab("Data processing"): + out_path = gr.Textbox( + label="Output path (where data and checkpoints will be saved):", + value="/tmp/xtts_ft/" + ) + # upload_file = gr.Audio( + # sources="upload", + # label="Select here the audio files that you want to use for XTTS trainining !", + # type="filepath", + # ) + upload_file = gr.File( + file_count="multiple", + label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)", ) lang = gr.Dropdown( label="Dataset Language", @@ -135,18 +143,18 @@ with gr.Blocks() as demo: prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.") - def preprocess_dataset(audio_path, language, progress=gr.Progress(track_tqdm=True)): - # create a temp directory to save the dataset - out_path = tempfile.TemporaryDirectory().name + def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)): + out_path = os.path.join(out_path, "dataset") + os.makedirs(out_path, exist_ok=True) if audio_path is None: # ToDo: raise an error pass else: - train_meta, eval_meta = format_audio_list([audio_path], target_language=language, out_path=out_path, gradio_progress=progress) - + train_meta, eval_meta = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress) + print("Dataset Processed!") return "Dataset Processed!", train_meta, eval_meta - with gr.Tab("Fine-tuning XTTS"): + with 
gr.Tab("Fine-tuning XTTS Encoder"): train_csv = gr.Textbox( label="Train CSV:", ) @@ -158,7 +166,7 @@ with gr.Blocks() as demo: minimum=1, maximum=100, step=1, - value=2,# 15 + value=10, ) batch_size = gr.Slider( label="batch_size", @@ -177,7 +185,7 @@ with gr.Blocks() as demo: demo.load(read_logs, None, logs_tts_train, every=1) train_btn = gr.Button(value="Step 2 - Run the training") - def train_model(language, train_csv, eval_csv, num_epochs, batch_size, output_path="./", progress=gr.Progress(track_tqdm=True)): + def train_model(language, train_csv, eval_csv, num_epochs, batch_size, output_path, progress=gr.Progress(track_tqdm=True)): # train_csv = '/tmp/tmprh4k_vou/metadata_train.csv' # eval_csv = '/tmp/tmprh4k_vou/metadata_eval.csv' @@ -187,67 +195,73 @@ with gr.Blocks() as demo: os.system(f"cp {vocab_file} {exp_path}") ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth") - # state_vars["config_path"] = config_path - # state_vars["original_xtts_checkpoint"] = original_xtts_checkpoint - # state_vars["vocab_file"] = vocab_file - # state_vars["ft_xtts_checkpoint"] = ft_xtts_checkpoint - # state_vars["speaker_audio_file"] = speaker_wav + print("Model training done!") return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav with gr.Tab("Inference"): - xtts_checkpoint = gr.Textbox( - label="XTTS checkpoint path:", - value="", - ) - xtts_config = gr.Textbox( - label="XTTS config path:", - value="", - ) - xtts_vocab = gr.Textbox( - label="XTTS config path:", - value="", - ) - speaker_reference_audio = gr.Textbox( - label="Speaker reference audio:", - value="", - ) - tts_language = gr.Dropdown( - label="Language", - value="en", - choices=[ - "en", - "es", - "fr", - "de", - "it", - "pt", - "pl", - "tr", - "ru", - "nl", - "cs", - "ar", - "zh", - "hu", - "ko", - "ja", - ] - ) - tts_text = gr.Textbox( - label="Input Text.", - value="This model sounds really good and above all, it's reasonably fast.", - ) - tts_btn = gr.Button(value="Step 3 - Inference XTTS model") + with gr.Row(): + with gr.Column() as col1: + xtts_checkpoint = gr.Textbox( + label="XTTS checkpoint path:", + value="", + ) + xtts_config = gr.Textbox( + label="XTTS config path:", + value="", + ) + xtts_vocab = gr.Textbox( + label="XTTS config path:", + value="", + ) + progress_load = gr.Label( + label="Progress:" + ) + load_btn = gr.Button(value="Step 3 - Load Fine tuned XTTS model") - tts_output_audio = gr.Audio(label="Generated Audio.") - reference_audio = gr.Audio(label="Reference audio used.") + with gr.Column() as col2: + speaker_reference_audio = gr.Textbox( + label="Speaker reference audio:", + value="", + ) + tts_language = gr.Dropdown( + label="Language", + value="en", + choices=[ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh", + "hu", + "ko", + "ja", + ] + ) + tts_text = gr.Textbox( + label="Input Text.", + value="This model sounds really good and above all, it's reasonably fast.", + ) + tts_btn = gr.Button(value="Step 4 - Inference") + + with gr.Column() as col3: + tts_output_audio = gr.Audio(label="Generated Audio.") + reference_audio = gr.Audio(label="Reference audio used.") prompt_compute_btn.click( fn=preprocess_dataset, inputs=[ upload_file, lang, + out_path, ], outputs=[ progress_data, @@ -255,7 +269,6 @@ with gr.Blocks() as demo: eval_csv, ], ) - train_btn.click( @@ -266,19 +279,26 @@ with gr.Blocks() as demo: eval_csv, num_epochs, batch_size, + out_path, ], outputs=[progress_train, xtts_config, xtts_vocab, 
    )
+    load_btn.click(
+        fn=load_model,
+        inputs=[
+            xtts_checkpoint,
+            xtts_config,
+            xtts_vocab
+        ],
+        outputs=[progress_load],
+    )
 
     tts_btn.click(
         fn=run_tts,
         inputs=[
             tts_language,
             tts_text,
-            xtts_checkpoint,
-            xtts_config,
-            xtts_vocab,
             speaker_reference_audio,
         ],
         outputs=[tts_output_audio, reference_audio],
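
The net effect of the xtts_demo.py changes is that checkpoint loading and inference are decoupled: load_model() now populates the module-level XTTS_MODEL once (the new "Step 3 - Load" button), and run_tts() reuses it on every request (Step 4) instead of reloading the checkpoint per call, which is also why the three checkpoint textboxes disappear from run_tts's inputs. The same load-once / infer-many pattern can be sketched outside Gradio using only calls that appear in the diff; the file paths below are hypothetical placeholders for the artifacts train_model() returns, and saving the output with torchaudio mirrors what run_tts does with its temporary file:

import torch
import torchaudio

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Hypothetical paths; in the demo these come from the training step's outputs.
XTTS_CONFIG = "/tmp/xtts_ft/run/training/config.json"
XTTS_CHECKPOINT = "/tmp/xtts_ft/run/training/best_model.pth"
XTTS_VOCAB = "/tmp/xtts_ft/run/training/vocab.json"
SPEAKER_WAV = "/tmp/xtts_ft/dataset/wavs/reference.wav"

# Load once, as load_model() now does into the global XTTS_MODEL.
config = XttsConfig()
config.load_json(XTTS_CONFIG)
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_path=XTTS_CHECKPOINT, vocab_path=XTTS_VOCAB, use_deepspeed=False)
if torch.cuda.is_available():
    model.cuda()

# Conditioning latents depend only on the reference audio, so they can also be
# computed once per speaker and reused across requests.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=SPEAKER_WAV,
    gpt_cond_len=model.config.gpt_cond_len,
    max_ref_length=model.config.max_ref_len,
    sound_norm_refs=model.config.sound_norm_refs,
)

# Every subsequent request skips the expensive checkpoint load, which is what
# makes Step 4 fast once Step 3 has run.
for i, text in enumerate(["First test sentence.", "A second request reuses the loaded model."]):
    out = model.inference(
        text=text,
        language="en",
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=model.config.temperature,
    )
    wav = torch.tensor(out["wav"]).unsqueeze(0)  # inference() returns a dict with a "wav" array
    torchaudio.save(f"/tmp/xtts_out_{i}.wav", wav, 24000)  # XTTS generates 24 kHz audio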