mirror of https://github.com/coqui-ai/TTS.git
Add XTTS FT demo data processing pipeline
This commit is contained in:
parent 29dede20d3
commit 774c4c1743
@@ -0,0 +1,151 @@
import os
from glob import glob

import pandas
import torch
import torchaudio
from faster_whisper import WhisperModel
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
from tqdm import tqdm

from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners

# torch.set_num_threads(1)
torch.set_num_threads(16)

audio_types = (".wav", ".mp3", ".flac")
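
# Note (context, not part of the original pipeline): these helpers rely on faster-whisper
# (pip install faster-whisper) for word-level timestamps and on a torchaudio build with a
# working sox/soundfile backend for loading and saving the clips.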


def list_audios(basePath, contains=None):
    # return the set of audio files that are valid
    return list_files(basePath, validExts=audio_types, contains=contains)


def list_files(basePath, validExts=None, contains=None):
    # loop over the directory structure
    for (rootDir, dirNames, filenames) in os.walk(basePath):
        # loop over the filenames in the current directory
        for filename in filenames:
            # if the contains string is not None and the filename does not contain
            # the supplied string, then ignore the file
            if contains is not None and filename.find(contains) == -1:
                continue

            # determine the file extension of the current file
            ext = filename[filename.rfind("."):].lower()

            # check to see if the file is an audio file that should be processed
            if validExts is None or ext.endswith(validExts):
                # construct the path to the audio file and yield it
                audioPath = os.path.join(rootDir, filename)
                yield audioPath
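
# Illustrative usage (not part of this file): list_audios() lazily yields every
# .wav/.mp3/.flac file under a folder; the dataset path below is a hypothetical example.
#
#   audio_files = list(list_audios("/path/to/raw_recordings"))
#   print(f"found {len(audio_files)} audio files")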


def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.5, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
    # make sure that the output directory exists
    os.makedirs(out_path, exist_ok=True)

    # Loading Whisper
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Loading Whisper Model!")
    # note: compute_type="float16" generally requires a GPU and may fail on CPU-only machines
    asr_model = WhisperModel("large-v2", device=device, compute_type="float16")

    metadata = {"audio_file": [], "text": [], "speaker_name": []}

    if gradio_progress is not None:
        tqdm_object = gradio_progress.tqdm(audio_files, desc="Formatting...")
    else:
        tqdm_object = tqdm(audio_files)

    for audio_path in tqdm_object:
        wav, sr = torchaudio.load(audio_path)
        wav = wav.squeeze()
        segments, info = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
        segments = list(segments)
        i = 0
        sentence = ""
        sentence_start = None
        first_word = True
        # collect the words of all segments into a single list
        words_list = []
        for _, segment in enumerate(segments):
            words = list(segment.words)
            words_list.extend(words)

        # process each word
        for word_idx, word in enumerate(words_list):
            if first_word:
                sentence_start = word.start
                # If it is the first sentence, add a buffer or start at the beginning of the file
                if word_idx == 0:
                    sentence_start = max(sentence_start - buffer, 0)  # Add buffer to the sentence start
                else:
                    # get the previous sentence end
                    previous_word_end = words_list[word_idx - 1].end
                    # add a buffer or use the middle of the silence between the previous sentence and the current one
                    sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)

                sentence = word.word
                first_word = False
            else:
                sentence += word.word

            if word.word[-1] in ["!", ".", "?"]:
                sentence = sentence[1:]
                # Expand numbers and abbreviations, plus normalization
                sentence = multilingual_cleaners(sentence, target_language)
                audio_file_name, ext = os.path.splitext(os.path.basename(audio_path))

                audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}{ext}"

                # Check for the next word's existence
                if word_idx + 1 < len(words_list):
                    next_word_start = words_list[word_idx + 1].start
                else:
                    # If there are no more words, this is the last sentence, so use the audio length as the next word start
                    next_word_start = (wav.shape[0] - 1) / sr

                # Average the current word end and the next word start
                word_end = min((word.end + next_word_start) / 2, word.end + buffer)

                absolute_path = os.path.join(out_path, audio_file)
                os.makedirs(os.path.dirname(absolute_path), exist_ok=True)
                i += 1
                first_word = True

                audio = wav[int(sr * sentence_start):int(sr * word_end)].unsqueeze(0)
                # if the audio is too short, ignore it (i.e. < 0.33 seconds)
                if audio.size(-1) >= sr / 3:
                    torchaudio.backend.sox_io_backend.save(
                        absolute_path,
                        audio,
                        sr
                    )
                else:
                    continue

                metadata["audio_file"].append(audio_file)
                metadata["text"].append(sentence)
                metadata["speaker_name"].append(speaker_name)

    df = pandas.DataFrame(metadata)
    df = df.sample(frac=1)
    num_val_samples = int(len(df) * eval_percentage)

    df_eval = df[:num_val_samples]
    df_train = df[num_val_samples:]

    df_train = df_train.sort_values('audio_file')
    train_metadata_path = os.path.join(out_path, "metadata_train.csv")
    df_train.to_csv(train_metadata_path, sep="|", index=False)

    eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
    df_eval = df_eval.sort_values('audio_file')
    df_eval.to_csv(eval_metadata_path, sep="|", index=False)

    return train_metadata_path, eval_metadata_path
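
# Illustrative usage (not part of this file): a minimal sketch of calling the formatter
# directly, outside the Gradio demo. The paths and speaker name below are hypothetical.
# format_audio_list() writes the trimmed clips under <out_path>/wavs/ and returns the
# paths of the "|"-separated metadata_train.csv and metadata_eval.csv files.
#
#   files = list(list_audios("/data/my_speaker/raw"))
#   train_csv, eval_csv = format_audio_list(
#       files,
#       target_language="en",
#       out_path="/data/my_speaker/dataset",
#       speaker_name="my_speaker",
#   )
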
@@ -0,0 +1,161 @@
import logging
import os
import sys
import tempfile

import gradio as gr
import librosa.display
import numpy as np
import torch
import torchaudio

from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list, list_audios

PORT = 5003


def run_tts(lang, tts_text, state_vars, temperature, rms_norm_output=False):
    # placeholder: TTS inference is wired into the UI below but is not implemented in this commit
    return None


# define a logger that duplicates stdout/stderr to both the terminal and a log file
class Logger:
    def __init__(self, filename="log.out"):
        self.log_file = filename
        self.terminal = sys.stdout
        self.log = open(self.log_file, "w")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        return False


# redirect stdout and stderr to a file
sys.stdout = Logger()
sys.stderr = sys.stdout


def read_logs():
    sys.stdout.flush()
    with open(sys.stdout.log_file, "r") as f:
        return f.read()
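
# The Logger above tees everything written to stdout/stderr into "log.out", and
# read_logs() returns the current contents of that file; the UI below polls it
# (demo.load(read_logs, None, logs, every=1)) so formatting output shows up live
# in the "Logs:" textbox.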


with gr.Blocks() as demo:
    with gr.Tab("XTTS"):
        state_vars = gr.State()
        with gr.Row():
            with gr.Column() as col1:
                upload_file = gr.Audio(
                    sources="upload",
                    label="Select here the audio file that you want to use for XTTS training!",
                    type="filepath",
                )
                lang = gr.Dropdown(
                    label="Dataset Language",
                    value="en",
                    choices=[
                        "en",
                        "es",
                        "fr",
                        "de",
                        "it",
                        "pt",
                        "pl",
                        "tr",
                        "ru",
                        "nl",
                        "cs",
                        "ar",
                        "zh",
                        "hu",
                        "ko",
                        "ja",
                    ],
                )
                voice_ready = gr.Label(
                    label="Progress."
                )
                logs = gr.Textbox(
                    label="Logs:",
                    interactive=False,
                )
                demo.load(read_logs, None, logs, every=1)

                prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")

            with gr.Column() as col2:
                tts_text = gr.Textbox(
                    label="Input Text.",
                    value="This model sounds really good and above all, it's reasonably fast.",
                )
                temperature = gr.Slider(
                    label="temperature", minimum=0.00001, maximum=1.0, step=0.05, value=0.75
                )
                rms_norm_output = gr.Checkbox(
                    label="RMS norm output.", value=True, interactive=True
                )
                tts_btn = gr.Button(value="Step 2 - TTS")

            with gr.Column() as col3:
                tts_output_audio_no_enhanced = gr.Audio(label="HiFi-GAN.")
                tts_output_audio_no_enhanced_ft = gr.Audio(label="HiFi-GAN new.")
                reference_audio = gr.Audio(label="Reference Speech used.")

        def preprocess_dataset(audio_path, language, state_vars, progress=gr.Progress(track_tqdm=True)):
            # create a temp directory to save the dataset
            # (only the generated path is kept here; format_audio_list recreates it with os.makedirs)
            out_path = tempfile.TemporaryDirectory().name
            if audio_path is None:
                # ToDo: raise an error
                pass
            else:
                train_meta, eval_meta = format_audio_list([audio_path], target_language=language, out_path=out_path, gradio_progress=progress)

            state_vars = {}
            state_vars["train_csv"] = train_meta
            state_vars["eval_csv"] = eval_meta
            return "Dataset Processed!", state_vars
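
        # Button wiring: "Step 1 - Create dataset." runs preprocess_dataset on the uploaded
        # file and fills voice_ready / state_vars; "Step 2 - TTS" calls the run_tts stub,
        # whose outputs feed the audio components in the third column.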

        prompt_compute_btn.click(
            fn=preprocess_dataset,
            inputs=[
                upload_file,
                lang,
                state_vars,
            ],
            outputs=[
                voice_ready,
                state_vars,
            ],
        )

        tts_btn.click(
            fn=run_tts,
            inputs=[
                lang,
                tts_text,
                state_vars,
                temperature,
                rms_norm_output,
            ],
            outputs=[tts_output_audio_no_enhanced, tts_output_audio_no_enhanced_ft],
        )


if __name__ == "__main__":
    demo.launch(
        share=True,
        debug=True,
        server_port=PORT,
        server_name="0.0.0.0",
    )
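
# To try the demo (assuming this file is saved as TTS/demos/xtts_ft_demo/xtts_demo.py):
#
#   python TTS/demos/xtts_ft_demo/xtts_demo.py
#
# It serves on 0.0.0.0:5003 (PORT above) and, with share=True, also prints a public
# Gradio link.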