From 774c4c1743506186333c9b8b40a3cbb362c36b53 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 22 Nov 2023 18:11:52 -0300 Subject: [PATCH] Add XTTS FT demo data processing pipeline --- TTS/demos/xtts_ft_demo/utils/formatter.py | 151 ++++++++++++++++++++ TTS/demos/xtts_ft_demo/xtts_demo.py | 161 ++++++++++++++++++++++ 2 files changed, 312 insertions(+) create mode 100644 TTS/demos/xtts_ft_demo/utils/formatter.py create mode 100644 TTS/demos/xtts_ft_demo/xtts_demo.py diff --git a/TTS/demos/xtts_ft_demo/utils/formatter.py b/TTS/demos/xtts_ft_demo/utils/formatter.py new file mode 100644 index 00000000..95bb0f1b --- /dev/null +++ b/TTS/demos/xtts_ft_demo/utils/formatter.py @@ -0,0 +1,151 @@ +import os +import torchaudio +import pandas +from faster_whisper import WhisperModel +from glob import glob + +from tqdm import tqdm + +import torch +import torchaudio +from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load +from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load +# torch.set_num_threads(1) + +from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners + +torch.set_num_threads(16) + + +import os + +audio_types = (".wav", ".mp3", ".flac") + + +def list_audios(basePath, contains=None): + # return the set of files that are valid + return list_files(basePath, validExts=audio_types, contains=contains) + +def list_files(basePath, validExts=None, contains=None): + # loop over the directory structure + for (rootDir, dirNames, filenames) in os.walk(basePath): + # loop over the filenames in the current directory + for filename in filenames: + # if the contains string is not none and the filename does not contain + # the supplied string, then ignore the file + if contains is not None and filename.find(contains) == -1: + continue + + # determine the file extension of the current file + ext = filename[filename.rfind("."):].lower() + + # check to see if the file is an audio and should be processed + if validExts is None or ext.endswith(validExts): + # construct the path to the audio and yield it + audioPath = os.path.join(rootDir, filename) + yield audioPath + +def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.5, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None): + # make sure that ooutput file exists + os.makedirs(out_path, exist_ok=True) + + # Loading Whisper + device = "cuda" if torch.cuda.is_available() else "cpu" + + print("Loading Whisper Model!") + asr_model = WhisperModel("large-v2", device=device, compute_type="float16") + + metadata = {"audio_file": [], "text": [], "speaker_name": []} + + if gradio_progress is not None: + tqdm_object = gradio_progress.tqdm(audio_files, desc="Formatting...") + else: + tqdm_object = tqdm(audio_files) + + for audio_path in tqdm_object: + wav, sr = torchaudio.load(audio_path) + wav = wav.squeeze() + segments, info = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language) + segments = list(segments) + i = 0 + sentence = "" + sentence_start = None + first_word = True + # added all segments words in a unique list + words_list = [] + for _, segment in enumerate(segments): + words = list(segment.words) + words_list.extend(words) + + # process each word + for word_idx, word in enumerate(words_list): + if first_word: + sentence_start = word.start + # If it is the first sentence, add buffer or get the begining of the file + if word_idx == 0: + sentence_start = max(sentence_start - buffer, 0) # Add buffer to the sentence start + else: + # get 
+                    previous_word_end = words_list[word_idx - 1].end
+                    # add the buffer or use the middle of the silence between the previous sentence and the current one
+                    sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2)
+
+                sentence = word.word
+                first_word = False
+            else:
+                sentence += word.word
+
+            if word.word[-1] in ["!", ".", "?"]:
+                sentence = sentence[1:]
+                # Expand numbers and abbreviations plus normalization
+                sentence = multilingual_cleaners(sentence, target_language)
+                audio_file_name, ext = os.path.splitext(os.path.basename(audio_path))
+
+                audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}{ext}"
+
+                # Check for the next word's existence
+                if word_idx + 1 < len(words_list):
+                    next_word_start = words_list[word_idx + 1].start
+                else:
+                    # If there are no more words, this is the last sentence, so use the audio length as the next word start
+                    next_word_start = (wav.shape[0] - 1) / sr
+
+                # Average the current word end and next word start
+                word_end = min((word.end + next_word_start) / 2, word.end + buffer)
+
+                absolute_path = os.path.join(out_path, audio_file)
+                os.makedirs(os.path.dirname(absolute_path), exist_ok=True)
+                i += 1
+                first_word = True
+
+                audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
+                # if the audio is too short, ignore it (i.e. < 0.33 seconds)
+                if audio.size(-1) >= sr/3:
+                    torchaudio.backend.sox_io_backend.save(
+                        absolute_path,
+                        audio,
+                        sr
+                    )
+                else:
+                    continue
+
+                metadata["audio_file"].append(audio_file)
+                metadata["text"].append(sentence)
+                metadata["speaker_name"].append(speaker_name)
+
+    df = pandas.DataFrame(metadata)
+    df = df.sample(frac=1)
+    num_val_samples = int(len(df)*eval_percentage)
+
+    df_eval = df[:num_val_samples]
+    df_train = df[num_val_samples:]
+
+    df_train = df_train.sort_values('audio_file')
+    train_metadata_path = os.path.join(out_path, "metadata_train.csv")
+    df_train.to_csv(train_metadata_path, sep="|", index=False)
+
+    eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
+    df_eval = df_eval.sort_values('audio_file')
+    df_eval.to_csv(eval_metadata_path, sep="|", index=False)
+
+    return train_metadata_path, eval_metadata_path
\ No newline at end of file
diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py
new file mode 100644
index 00000000..99b64792
--- /dev/null
+++ b/TTS/demos/xtts_ft_demo/xtts_demo.py
@@ -0,0 +1,161 @@
+import os
+import sys
+import tempfile
+
+import gradio as gr
+import librosa.display
+import numpy as np
+
+import os
+import torch
+import torchaudio
+from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list, list_audios
+
+import logging
+
+PORT = 5003
+
+
+
+def run_tts(lang, tts_text, state_vars, temperature, rms_norm_output=False):
+    return None
+
+# define a logger that mirrors stdout/stderr to a log file
+class Logger:
+    def __init__(self, filename="log.out"):
+        self.log_file = filename
+        self.terminal = sys.stdout
+        self.log = open(self.log_file, "w")
+
+    def write(self, message):
+        self.terminal.write(message)
+        self.log.write(message)
+
+    def flush(self):
+        self.terminal.flush()
+        self.log.flush()
+
+    def isatty(self):
+        return False
+
+# redirect stdout and stderr to a file
+sys.stdout = Logger()
+sys.stderr = sys.stdout
+
+
+def read_logs():
+    sys.stdout.flush()
+    with open(sys.stdout.log_file, "r") as f:
+        return f.read()
+
+
+with gr.Blocks() as demo:
+    with gr.Tab("XTTS"):
+        state_vars = gr.State(
+        )
+        with gr.Row():
+            with gr.Column() as col1:
+                upload_file = gr.Audio(
+                    sources="upload",
+                    label="Select here the audio files that you want to use for XTTS training!",
+                    type="filepath",
+                )
+                lang = gr.Dropdown(
+                    label="Dataset Language",
+                    value="en",
+                    choices=[
+                        "en",
+                        "es",
+                        "fr",
+                        "de",
+                        "it",
+                        "pt",
+                        "pl",
+                        "tr",
+                        "ru",
+                        "nl",
+                        "cs",
+                        "ar",
+                        "zh",
+                        "hu",
+                        "ko",
+                        "ja"
+                    ],
+                )
+                voice_ready = gr.Label(
+                    label="Progress."
+                )
+                logs = gr.Textbox(
+                    label="Logs:",
+                    interactive=False,
+                )
+                demo.load(read_logs, None, logs, every=1)
+
+                prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")
+
+            with gr.Column() as col2:
+
+                tts_text = gr.Textbox(
+                    label="Input Text.",
+                    value="This model sounds really good and above all, it's reasonably fast.",
+                )
+                temperature = gr.Slider(
+                    label="temperature", minimum=0.00001, maximum=1.0, step=0.05, value=0.75
+                )
+                rms_norm_output = gr.Checkbox(
+                    label="RMS norm output.", value=True, interactive=True
+                )
+                tts_btn = gr.Button(value="Step 2 - TTS")
+
+            with gr.Column() as col3:
+                tts_output_audio_no_enhanced = gr.Audio(label="HiFi-GAN.")
+                tts_output_audio_no_enhanced_ft = gr.Audio(label="HiFi-GAN new.")
+                reference_audio = gr.Audio(label="Reference Speech used.")
+
+        def preprocess_dataset(audio_path, language, state_vars, progress=gr.Progress(track_tqdm=True)):
+            # create a temp directory to save the dataset
+            out_path = tempfile.TemporaryDirectory().name
+            if audio_path is None:
+                # ToDo: raise an error
+                pass
+            else:
+
+                train_meta, eval_meta = format_audio_list([audio_path], target_language=language, out_path=out_path, gradio_progress=progress)
+
+            state_vars = {}
+            state_vars["train_csv"] = train_meta
+            state_vars["eval_csv"] = eval_meta
+            return "Dataset Processed!", state_vars
+
+        prompt_compute_btn.click(
+            fn=preprocess_dataset,
+            inputs=[
+                upload_file,
+                lang,
+                state_vars,
+            ],
+            outputs=[
+                voice_ready,
+                state_vars,
+            ],
+        )
+
+        tts_btn.click(
+            fn=run_tts,
+            inputs=[
+                lang,
+                tts_text,
+                state_vars,
+                temperature,
+                rms_norm_output,
+            ],
+            outputs=[tts_output_audio_no_enhanced, tts_output_audio_no_enhanced_ft],
+        )
+
+if __name__ == "__main__":
+    demo.launch(
+        share=True,
+        debug=True,
+        server_port=PORT,
+        server_name="0.0.0.0"
+    )
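
Usage sketch (not part of the patch): a minimal example of how the new formatter utilities could be driven directly, without the Gradio UI. The input folder and output directory below are hypothetical placeholders; the function names and arguments come from formatter.py above, which assumes faster-whisper and torchaudio are installed and a CUDA device is available for the Whisper model.

    # Minimal sketch; paths are placeholders chosen for illustration.
    from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list, list_audios

    # Gather .wav/.mp3/.flac files from a folder of recordings (hypothetical path).
    audio_files = list(list_audios("/data/my_speaker_recordings"))

    # Transcribe with Whisper, cut sentence-level clips into out_path/wavs/,
    # and write pipe-separated metadata_train.csv / metadata_eval.csv.
    train_csv, eval_csv = format_audio_list(
        audio_files,
        target_language="en",
        out_path="/data/xtts_ft_dataset",  # hypothetical output directory
        eval_percentage=0.15,
        speaker_name="coqui",
    )
    print("train metadata:", train_csv)
    print("eval metadata:", eval_csv)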