mirror of https://github.com/coqui-ai/TTS.git
Add parameters to be able to set then on colab demo
This commit is contained in:
parent
335b8c37b3
commit
eaa5355c91
|
@ -8,9 +8,9 @@ from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrai
|
|||
from TTS.utils.manage import ModelManager
|
||||
|
||||
|
||||
def train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path):
|
||||
def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path):
|
||||
# Logging parameters
|
||||
RUN_NAME = "GPT_XTTSv2.1_FT"
|
||||
RUN_NAME = "GPT_XTTS_FT"
|
||||
PROJECT_NAME = "XTTS_trainer"
|
||||
DASHBOARD_LOGGER = "tensorboard"
|
||||
LOGGER_URI = None
|
||||
|
@ -18,13 +18,11 @@ def train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path
|
|||
# Set here the path that the checkpoints will be saved. Default: ./run/training/
|
||||
OUT_PATH = os.path.join(output_path, "run", "training")
|
||||
|
||||
|
||||
# Training Parameters
|
||||
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False
|
||||
START_WITH_EVAL = True # if True it will star with evaluation
|
||||
START_WITH_EVAL = False # if True it will star with evaluation
|
||||
BATCH_SIZE = batch_size # set here the batch size
|
||||
GRAD_ACUMM_STEPS = 1 # set here the grad accumulation steps
|
||||
# Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS need to be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE but then set GRAD_ACUMM_STEPS accordingly.
|
||||
GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps
|
||||
|
||||
|
||||
# Define here the dataset that you want to use for the fine-tuning on.
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
@ -21,7 +22,6 @@ def clear_gpu_cache():
|
|||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
PORT = 5003
|
||||
|
||||
XTTS_MODEL = None
|
||||
def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
|
||||
|
@ -101,231 +101,279 @@ def read_logs():
|
|||
return f.read()
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Tab("Data processing"):
|
||||
out_path = gr.Textbox(
|
||||
label="Output path (where data and checkpoints will be saved):",
|
||||
value="/tmp/xtts_ft/"
|
||||
)
|
||||
# upload_file = gr.Audio(
|
||||
# sources="upload",
|
||||
# label="Select here the audio files that you want to use for XTTS trainining !",
|
||||
# type="filepath",
|
||||
# )
|
||||
upload_file = gr.File(
|
||||
file_count="multiple",
|
||||
label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)",
|
||||
)
|
||||
lang = gr.Dropdown(
|
||||
label="Dataset Language",
|
||||
value="en",
|
||||
choices=[
|
||||
"en",
|
||||
"es",
|
||||
"fr",
|
||||
"de",
|
||||
"it",
|
||||
"pt",
|
||||
"pl",
|
||||
"tr",
|
||||
"ru",
|
||||
"nl",
|
||||
"cs",
|
||||
"ar",
|
||||
"zh",
|
||||
"hu",
|
||||
"ko",
|
||||
"ja"
|
||||
],
|
||||
)
|
||||
progress_data = gr.Label(
|
||||
label="Progress:"
|
||||
)
|
||||
logs = gr.Textbox(
|
||||
label="Logs:",
|
||||
interactive=False,
|
||||
)
|
||||
demo.load(read_logs, None, logs, every=1)
|
||||
|
||||
prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")
|
||||
|
||||
def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
|
||||
clear_gpu_cache()
|
||||
out_path = os.path.join(out_path, "dataset")
|
||||
os.makedirs(out_path, exist_ok=True)
|
||||
if audio_path is None:
|
||||
# ToDo: raise an error
|
||||
pass
|
||||
else:
|
||||
train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
|
||||
|
||||
clear_gpu_cache()
|
||||
|
||||
# if audio total len is less than 2 minutes raise an error
|
||||
if audio_total_size < 120:
|
||||
message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
|
||||
print(message)
|
||||
return message, " ", " "
|
||||
|
||||
print("Dataset Processed!")
|
||||
return "Dataset Processed!", train_meta, eval_meta
|
||||
|
||||
with gr.Tab("Fine-tuning XTTS Encoder"):
|
||||
train_csv = gr.Textbox(
|
||||
label="Train CSV:",
|
||||
)
|
||||
eval_csv = gr.Textbox(
|
||||
label="Eval CSV:",
|
||||
)
|
||||
num_epochs = gr.Slider(
|
||||
label="num_epochs",
|
||||
minimum=1,
|
||||
maximum=100,
|
||||
step=1,
|
||||
value=10,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
label="batch_size",
|
||||
minimum=2,
|
||||
maximum=512,
|
||||
step=1,
|
||||
value=4,
|
||||
)
|
||||
progress_train = gr.Label(
|
||||
label="Progress:"
|
||||
)
|
||||
logs_tts_train = gr.Textbox(
|
||||
label="Logs:",
|
||||
interactive=False,
|
||||
)
|
||||
demo.load(read_logs, None, logs_tts_train, every=1)
|
||||
train_btn = gr.Button(value="Step 2 - Run the training")
|
||||
|
||||
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, output_path, progress=gr.Progress(track_tqdm=True)):
|
||||
clear_gpu_cache()
|
||||
|
||||
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path=output_path)
|
||||
# copy original files to avoid parameters changes issues
|
||||
os.system(f"cp {config_path} {exp_path}")
|
||||
os.system(f"cp {vocab_file} {exp_path}")
|
||||
|
||||
ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
|
||||
print("Model training done!")
|
||||
clear_gpu_cache()
|
||||
return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
|
||||
|
||||
|
||||
with gr.Tab("Inference"):
|
||||
with gr.Row():
|
||||
with gr.Column() as col1:
|
||||
xtts_checkpoint = gr.Textbox(
|
||||
label="XTTS checkpoint path:",
|
||||
value="",
|
||||
)
|
||||
xtts_config = gr.Textbox(
|
||||
label="XTTS config path:",
|
||||
value="",
|
||||
)
|
||||
xtts_vocab = gr.Textbox(
|
||||
label="XTTS config path:",
|
||||
value="",
|
||||
)
|
||||
progress_load = gr.Label(
|
||||
label="Progress:"
|
||||
)
|
||||
load_btn = gr.Button(value="Step 3 - Load Fine tuned XTTS model")
|
||||
|
||||
with gr.Column() as col2:
|
||||
speaker_reference_audio = gr.Textbox(
|
||||
label="Speaker reference audio:",
|
||||
value="",
|
||||
)
|
||||
tts_language = gr.Dropdown(
|
||||
label="Language",
|
||||
value="en",
|
||||
choices=[
|
||||
"en",
|
||||
"es",
|
||||
"fr",
|
||||
"de",
|
||||
"it",
|
||||
"pt",
|
||||
"pl",
|
||||
"tr",
|
||||
"ru",
|
||||
"nl",
|
||||
"cs",
|
||||
"ar",
|
||||
"zh",
|
||||
"hu",
|
||||
"ko",
|
||||
"ja",
|
||||
]
|
||||
)
|
||||
tts_text = gr.Textbox(
|
||||
label="Input Text.",
|
||||
value="This model sounds really good and above all, it's reasonably fast.",
|
||||
)
|
||||
tts_btn = gr.Button(value="Step 4 - Inference")
|
||||
|
||||
with gr.Column() as col3:
|
||||
tts_output_audio = gr.Audio(label="Generated Audio.")
|
||||
reference_audio = gr.Audio(label="Reference audio used.")
|
||||
|
||||
prompt_compute_btn.click(
|
||||
fn=preprocess_dataset,
|
||||
inputs=[
|
||||
upload_file,
|
||||
lang,
|
||||
out_path,
|
||||
],
|
||||
outputs=[
|
||||
progress_data,
|
||||
train_csv,
|
||||
eval_csv,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
train_btn.click(
|
||||
fn=train_model,
|
||||
inputs=[
|
||||
lang,
|
||||
train_csv,
|
||||
eval_csv,
|
||||
num_epochs,
|
||||
batch_size,
|
||||
out_path,
|
||||
],
|
||||
outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
|
||||
)
|
||||
|
||||
load_btn.click(
|
||||
fn=load_model,
|
||||
inputs=[
|
||||
xtts_checkpoint,
|
||||
xtts_config,
|
||||
xtts_vocab
|
||||
],
|
||||
outputs=[progress_load],
|
||||
)
|
||||
|
||||
tts_btn.click(
|
||||
fn=run_tts,
|
||||
inputs=[
|
||||
tts_language,
|
||||
tts_text,
|
||||
speaker_reference_audio,
|
||||
],
|
||||
outputs=[tts_output_audio, reference_audio],
|
||||
)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""XTTS fine-tuning demo\n\n"""
|
||||
"""
|
||||
Example runs:
|
||||
python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
|
||||
""",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
help="Port to run the gradio demo. Default: 5003",
|
||||
default=5003,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out_path",
|
||||
type=str,
|
||||
help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/",
|
||||
default="/tmp/xtts_ft/",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_epochs",
|
||||
type=int,
|
||||
help="Number of epochs to train. Default: 10",
|
||||
default=10,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
help="Batch size. Default: 4",
|
||||
default=4,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grad_acumm",
|
||||
type=int,
|
||||
help="Grad accumulation steps. Default: 1",
|
||||
default=1,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Tab("Data processing"):
|
||||
out_path = gr.Textbox(
|
||||
label="Output path (where data and checkpoints will be saved):",
|
||||
value=args.out_path,
|
||||
)
|
||||
# upload_file = gr.Audio(
|
||||
# sources="upload",
|
||||
# label="Select here the audio files that you want to use for XTTS trainining !",
|
||||
# type="filepath",
|
||||
# )
|
||||
upload_file = gr.File(
|
||||
file_count="multiple",
|
||||
label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)",
|
||||
)
|
||||
lang = gr.Dropdown(
|
||||
label="Dataset Language",
|
||||
value="en",
|
||||
choices=[
|
||||
"en",
|
||||
"es",
|
||||
"fr",
|
||||
"de",
|
||||
"it",
|
||||
"pt",
|
||||
"pl",
|
||||
"tr",
|
||||
"ru",
|
||||
"nl",
|
||||
"cs",
|
||||
"ar",
|
||||
"zh",
|
||||
"hu",
|
||||
"ko",
|
||||
"ja"
|
||||
],
|
||||
)
|
||||
progress_data = gr.Label(
|
||||
label="Progress:"
|
||||
)
|
||||
logs = gr.Textbox(
|
||||
label="Logs:",
|
||||
interactive=False,
|
||||
)
|
||||
demo.load(read_logs, None, logs, every=1)
|
||||
|
||||
prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")
|
||||
|
||||
def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
|
||||
clear_gpu_cache()
|
||||
out_path = os.path.join(out_path, "dataset")
|
||||
os.makedirs(out_path, exist_ok=True)
|
||||
if audio_path is None:
|
||||
# ToDo: raise an error
|
||||
pass
|
||||
else:
|
||||
train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
|
||||
|
||||
clear_gpu_cache()
|
||||
|
||||
# if audio total len is less than 2 minutes raise an error
|
||||
if audio_total_size < 120:
|
||||
message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
|
||||
print(message)
|
||||
return message, " ", " "
|
||||
|
||||
print("Dataset Processed!")
|
||||
return "Dataset Processed!", train_meta, eval_meta
|
||||
|
||||
with gr.Tab("Fine-tuning XTTS Encoder"):
|
||||
train_csv = gr.Textbox(
|
||||
label="Train CSV:",
|
||||
)
|
||||
eval_csv = gr.Textbox(
|
||||
label="Eval CSV:",
|
||||
)
|
||||
num_epochs = gr.Slider(
|
||||
label="Number of epochs:",
|
||||
minimum=1,
|
||||
maximum=100,
|
||||
step=1,
|
||||
value=args.num_epochs,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
label="Batch size:",
|
||||
minimum=2,
|
||||
maximum=512,
|
||||
step=1,
|
||||
value=args.batch_size,
|
||||
)
|
||||
grad_acumm = gr.Slider(
|
||||
label="Grad accumulation steps:",
|
||||
minimum=2,
|
||||
maximum=128,
|
||||
step=1,
|
||||
value=args.grad_acumm,
|
||||
)
|
||||
progress_train = gr.Label(
|
||||
label="Progress:"
|
||||
)
|
||||
logs_tts_train = gr.Textbox(
|
||||
label="Logs:",
|
||||
interactive=False,
|
||||
)
|
||||
demo.load(read_logs, None, logs_tts_train, every=1)
|
||||
train_btn = gr.Button(value="Step 2 - Run the training")
|
||||
|
||||
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path):
|
||||
clear_gpu_cache()
|
||||
|
||||
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path)
|
||||
# copy original files to avoid parameters changes issues
|
||||
os.system(f"cp {config_path} {exp_path}")
|
||||
os.system(f"cp {vocab_file} {exp_path}")
|
||||
|
||||
ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
|
||||
print("Model training done!")
|
||||
clear_gpu_cache()
|
||||
return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
|
||||
|
||||
|
||||
with gr.Tab("Inference"):
|
||||
with gr.Row():
|
||||
with gr.Column() as col1:
|
||||
xtts_checkpoint = gr.Textbox(
|
||||
label="XTTS checkpoint path:",
|
||||
value="",
|
||||
)
|
||||
xtts_config = gr.Textbox(
|
||||
label="XTTS config path:",
|
||||
value="",
|
||||
)
|
||||
xtts_vocab = gr.Textbox(
|
||||
label="XTTS config path:",
|
||||
value="",
|
||||
)
|
||||
progress_load = gr.Label(
|
||||
label="Progress:"
|
||||
)
|
||||
load_btn = gr.Button(value="Step 3 - Load Fine tuned XTTS model")
|
||||
|
||||
with gr.Column() as col2:
|
||||
speaker_reference_audio = gr.Textbox(
|
||||
label="Speaker reference audio:",
|
||||
value="",
|
||||
)
|
||||
tts_language = gr.Dropdown(
|
||||
label="Language",
|
||||
value="en",
|
||||
choices=[
|
||||
"en",
|
||||
"es",
|
||||
"fr",
|
||||
"de",
|
||||
"it",
|
||||
"pt",
|
||||
"pl",
|
||||
"tr",
|
||||
"ru",
|
||||
"nl",
|
||||
"cs",
|
||||
"ar",
|
||||
"zh",
|
||||
"hu",
|
||||
"ko",
|
||||
"ja",
|
||||
]
|
||||
)
|
||||
tts_text = gr.Textbox(
|
||||
label="Input Text.",
|
||||
value="This model sounds really good and above all, it's reasonably fast.",
|
||||
)
|
||||
tts_btn = gr.Button(value="Step 4 - Inference")
|
||||
|
||||
with gr.Column() as col3:
|
||||
tts_output_audio = gr.Audio(label="Generated Audio.")
|
||||
reference_audio = gr.Audio(label="Reference audio used.")
|
||||
|
||||
prompt_compute_btn.click(
|
||||
fn=preprocess_dataset,
|
||||
inputs=[
|
||||
upload_file,
|
||||
lang,
|
||||
out_path,
|
||||
],
|
||||
outputs=[
|
||||
progress_data,
|
||||
train_csv,
|
||||
eval_csv,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
train_btn.click(
|
||||
fn=train_model,
|
||||
inputs=[
|
||||
lang,
|
||||
train_csv,
|
||||
eval_csv,
|
||||
num_epochs,
|
||||
batch_size,
|
||||
grad_acumm,
|
||||
out_path,
|
||||
],
|
||||
outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
|
||||
)
|
||||
|
||||
load_btn.click(
|
||||
fn=load_model,
|
||||
inputs=[
|
||||
xtts_checkpoint,
|
||||
xtts_config,
|
||||
xtts_vocab
|
||||
],
|
||||
outputs=[progress_load],
|
||||
)
|
||||
|
||||
tts_btn.click(
|
||||
fn=run_tts,
|
||||
inputs=[
|
||||
tts_language,
|
||||
tts_text,
|
||||
speaker_reference_audio,
|
||||
],
|
||||
outputs=[tts_output_audio, reference_audio],
|
||||
)
|
||||
|
||||
demo.launch(
|
||||
share=True,
|
||||
debug=True,
|
||||
server_port=PORT,
|
||||
debug=False,
|
||||
server_port=args.port,
|
||||
server_name="0.0.0.0"
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue