From cc4f37e1b0bb270b1cd5883b2de0cf3bee279f62 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Thu, 23 Nov 2023 16:30:49 -0300
Subject: [PATCH] Add training and inference columns

---
 TTS/demos/xtts_ft_demo/utils/gpt_train.py  | 163 +++++++++++++++++++
 TTS/demos/xtts_ft_demo/xtts_demo.py        | 178 ++++++++++++++++++---
 TTS/tts/layers/xtts/trainer/gpt_trainer.py |   2 +-
 3 files changed, 319 insertions(+), 24 deletions(-)
 create mode 100644 TTS/demos/xtts_ft_demo/utils/gpt_train.py

diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py
new file mode 100644
index 00000000..a4f5cb9a
--- /dev/null
+++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py
@@ -0,0 +1,163 @@
+import os
+
+from trainer import Trainer, TrainerArgs
+
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
+from TTS.utils.manage import ModelManager
+
+
+def train_gpt(language, num_epochs, batch_size, train_csv, eval_csv, output_path):
+    # Logging parameters
+    RUN_NAME = "GPT_XTTSv2.1_FT"
+    PROJECT_NAME = "XTTS_trainer"
+    DASHBOARD_LOGGER = "tensorboard"
+    LOGGER_URI = None
+
+    # Set the path where the checkpoints will be saved. Default: ./run/training/
+    OUT_PATH = os.path.join(output_path, "run", "training")
+
+
+    # Training parameters
+    OPTIMIZER_WD_ONLY_ON_WEIGHTS = True  # set to False for multi-GPU training
+    START_WITH_EVAL = True  # if True, training starts with an evaluation pass
+    BATCH_SIZE = batch_size  # set the batch size here
+    GRAD_ACUMM_STEPS = 1  # set the gradient accumulation steps here
+    # Note: we recommend keeping BATCH_SIZE * GRAD_ACUMM_STEPS at 252 or above for more efficient training; you can increase or decrease BATCH_SIZE, but adjust GRAD_ACUMM_STEPS accordingly.
+
+
+    # Define the dataset to use for fine-tuning.
+    config_dataset = BaseDatasetConfig(
+        formatter="coqui",
+        dataset_name="ft_dataset",
+        path=os.path.dirname(train_csv),
+        meta_file_train=train_csv,
+        meta_file_val=eval_csv,
+        language=language,
+    )
+
+    # Add the dataset configs here
+    DATASETS_CONFIG_LIST = [config_dataset]
+
+    # Define the path where the XTTS v2.0 files will be downloaded
+    CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
+    os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
+
+
+    # DVAE files
+    DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
+    MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
+
+    # Set the paths to the downloaded files
+    DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK))
+    MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK))
+
+    # download DVAE files if needed
+    if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
+        print(" > Downloading DVAE files!")
+        ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
+
+
+    # Download XTTS v2.0 checkpoint if needed
+    TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
+    XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"
+    XTTS_CONFIG_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json"
+
+    # XTTS transfer learning parameters: provide the paths of the XTTS checkpoint that you want to fine-tune.
+    TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK))  # vocab.json file
+    XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK))  # model.pth file
+    XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK))  # config.json file
+
+    # download XTTS v2.0 files if needed
+    if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
+        print(" > Downloading XTTS v2.0 files!")
+        ModelManager._download_model_files(
+            [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK, XTTS_CONFIG_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+        )
+
+    # init args and config
+    model_args = GPTArgs(
+        max_conditioning_length=132300,  # 6 secs
+        min_conditioning_length=66150,  # 3 secs
+        debug_loading_failures=False,
+        max_wav_length=255995,  # ~11.6 seconds
+        max_text_length=200,
+        mel_norm_file=MEL_NORM_FILE,
+        dvae_checkpoint=DVAE_CHECKPOINT,
+        xtts_checkpoint=XTTS_CHECKPOINT,  # checkpoint path of the model that you want to fine-tune
+        tokenizer_file=TOKENIZER_FILE,
+        gpt_num_audio_tokens=1026,
+        gpt_start_audio_token=1024,
+        gpt_stop_audio_token=1025,
+        gpt_use_masking_gt_prompt_approach=True,
+        gpt_use_perceiver_resampler=True,
+    )
+    # define audio config
+    audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
+    # training parameters config
+    config = GPTTrainerConfig(
+        epochs=num_epochs,
+        output_path=OUT_PATH,
+        model_args=model_args,
+        run_name=RUN_NAME,
+        project_name=PROJECT_NAME,
+        run_description="""
+            GPT XTTS training
+            """,
+        dashboard_logger=DASHBOARD_LOGGER,
+        logger_uri=LOGGER_URI,
+        audio=audio_config,
+        batch_size=BATCH_SIZE,
+        batch_group_size=48,
+        eval_batch_size=BATCH_SIZE,
+        num_loader_workers=8,
+        eval_split_max_size=256,
+        print_step=50,
+        plot_step=100,
+        log_model_step=100,
+        save_step=1000,
+        save_n_checkpoints=1,
+        save_checkpoints=True,
+        # target_loss="loss",
+        print_eval=False,
+        # Optimizer values like Tortoise; PyTorch implementation with modifications so that weight decay is not applied to non-weight parameters.
+        optimizer="AdamW",
+        optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
+        optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
+        lr=5e-06,  # learning rate
+        lr_scheduler="MultiStepLR",
+        # adjusted accordingly for the new step scheme
+        lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
+        test_sentences=[],
+    )
+
+    # init the model from config
+    model = GPTTrainer.init_from_config(config)
+
+    # load training samples
+    train_samples, eval_samples = load_tts_samples(
+        DATASETS_CONFIG_LIST,
+        eval_split=True,
+        eval_split_max_size=config.eval_split_max_size,
+        eval_split_size=config.eval_split_size,
+    )
+
+    # init the trainer and 🚀
+    trainer = Trainer(
+        TrainerArgs(
+            restore_path=None,  # the XTTS checkpoint is restored via the xtts_checkpoint key, so there is no need to restore it with the Trainer restore_path parameter
+            skip_train_epoch=False,
+            start_with_eval=START_WITH_EVAL,
+            grad_accum_steps=GRAD_ACUMM_STEPS,
+        ),
+        config,
+        output_path=OUT_PATH,
+        model=model,
+        train_samples=train_samples,
+        eval_samples=eval_samples,
+    )
+    trainer.fit()
+
+
+    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer.output_path, train_samples[0]["audio_file"]
diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py
index 99b64792..7e6e1c09 100644
--- a/TTS/demos/xtts_ft_demo/xtts_demo.py
+++ b/TTS/demos/xtts_ft_demo/xtts_demo.py
@@ -10,15 +10,50 @@ import os
 import torch
 import torchaudio
 from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list, list_audios
+from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
+
+from TTS.tts.configs.xtts_config import XttsConfig
+from TTS.tts.models.xtts import Xtts
-import logging

 PORT = 5003

+def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
+    config = XttsConfig()
+    config.load_json(xtts_config)
+    model = Xtts.init_from_config(config)
+    print("Loading XTTS model!")
+    model.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
+    if torch.cuda.is_available():
+        model.cuda()
+    return model
+
+def run_tts(lang, tts_text, xtts_checkpoint, xtts_config, xtts_vocab, speaker_audio_file, state_vars):
+    # TODO: move model loading into a separate step for faster inference
+    model = load_model(xtts_checkpoint, xtts_config, xtts_vocab)
+    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=model.config.gpt_cond_len, max_ref_length=model.config.max_ref_len, sound_norm_refs=model.config.sound_norm_refs)
+
+    out = model.inference(
+        text=tts_text,
+        language=lang,
+        gpt_cond_latent=gpt_cond_latent,
+        speaker_embedding=speaker_embedding,
+        temperature=model.config.temperature,  # add custom parameters here
+        length_penalty=model.config.length_penalty,
+        repetition_penalty=model.config.repetition_penalty,
+        top_k=model.config.top_k,
+        top_p=model.config.top_p,
+    )
+
+    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
+        out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
+        out_path = fp.name
+        torchaudio.save(out_path, out["wav"], 24000)
+
+    return out_path, speaker_audio_file
+

-def run_tts(lang, tts_text, state_vars, temperature, rms_norm_output=False):
-    return None

 # define a logger to redirect
 class Logger:
@@ -43,6 +78,16 @@ sys.stdout = Logger()
 sys.stderr = sys.stdout

+# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+import logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[
+        logging.StreamHandler(sys.stdout)
+    ]
+)
+
 def read_logs():
     sys.stdout.flush()
     with open(sys.stdout.log_file, "r") as f:
@@ -82,8 +127,8 @@ with gr.Blocks() as demo:
                 "ja"
             ],
         )
-        voice_ready = gr.Label(
-            label="Progress."
+        progress_data = gr.Label(
+            label="Progress:"
         )
         logs = gr.Textbox(
             label="Logs:",
@@ -94,23 +139,78 @@
         prompt_compute_btn = gr.Button(value="Step 1 - Create dataset.")

     with gr.Column() as col2:
-
+        num_epochs = gr.Slider(
+            label="num_epochs",
+            minimum=1,
+            maximum=100,
+            step=1,
+            value=2,  # 15
+        )
+        batch_size = gr.Slider(
+            label="batch_size",
+            minimum=2,
+            maximum=512,
+            step=1,
+            value=15,
+        )
+        progress_train = gr.Label(
+            label="Progress:"
+        )
+        logs_tts_train = gr.Textbox(
+            label="Logs:",
+            interactive=False,
+        )
+        demo.load(read_logs, None, logs_tts_train, every=1)
+        train_btn = gr.Button(value="Step 2 - Run the training")
+
+    with gr.Column() as col3:
+        xtts_checkpoint = gr.Textbox(
+            label="XTTS checkpoint path:",
+            value="",
+        )
+        xtts_config = gr.Textbox(
+            label="XTTS config path:",
+            value="",
+        )
+        xtts_vocab = gr.Textbox(
+            label="XTTS vocab path:",
+            value="",
+        )
+        speaker_reference_audio = gr.Textbox(
+            label="Speaker reference audio:",
+            value="",
+        )
+        tts_language = gr.Dropdown(
+            label="Language",
+            value="en",
+            choices=[
+                "en",
+                "es",
+                "fr",
+                "de",
+                "it",
+                "pt",
+                "pl",
+                "tr",
+                "ru",
+                "nl",
+                "cs",
+                "ar",
+                "zh",
+                "hu",
+                "ko",
+                "ja",
+            ]
+        )
         tts_text = gr.Textbox(
             label="Input Text.",
             value="This model sounds really good and above all, it's reasonably fast.",
         )
-        temperature = gr.Slider(
-            label="temperature", minimum=0.00001, maximum=1.0, step=0.05, value=0.75
-        )
-        rms_norm_output = gr.Checkbox(
-            label="RMS norm output.", value=True, interactive=True
-        )
-        tts_btn = gr.Button(value="Step 2 - TTS")
+        tts_btn = gr.Button(value="Step 3 - Run inference with the XTTS model")
+
+        tts_output_audio = gr.Audio(label="Generated Audio.")
+        reference_audio = gr.Audio(label="Reference audio used.")

-    with gr.Column() as col3:
-        tts_output_audio_no_enhanced = gr.Audio(label="HiFi-GAN.")
-        tts_output_audio_no_enhanced_ft = gr.Audio(label="HiFi-GAN new.")
-        reference_audio = gr.Audio(label="Reference Speech used.")

     def preprocess_dataset(audio_path, language, state_vars, progress=gr.Progress(track_tqdm=True)):
         # create a temp directory to save the dataset
@@ -119,12 +219,12 @@
             # ToDo: raise an error
             pass
         else:
-
             train_meta, eval_meta = format_audio_list([audio_path], target_language=language, out_path=out_path, gradio_progress=progress)

             state_vars = {}
             state_vars["train_csv"] = train_meta
             state_vars["eval_csv"] = eval_meta
+            print(state_vars)
             return "Dataset Processed!", state_vars
@@ -135,23 +235,55 @@
     prompt_compute_btn.click(
         fn=preprocess_dataset,
         inputs=[
             state_vars,
         ],
         outputs=[
-            voice_ready,
+            progress_data,
             state_vars,
         ],
     )
+
+    def train_model(language, num_epochs, batch_size, state_vars, output_path="./", progress=gr.Progress(track_tqdm=True)):
+        # e.g. state_vars = {'train_csv': '/tmp/tmprh4k_vou/metadata_train.csv', 'eval_csv': '/tmp/tmprh4k_vou/metadata_eval.csv'}
+        config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, state_vars["train_csv"], state_vars["eval_csv"], output_path=output_path)
+        # copy the original config and vocab files into the experiment dir to avoid issues with later parameter changes
+        os.system(f"cp {config_path} {exp_path}")
+        os.system(f"cp {vocab_file} {exp_path}")
+
+        ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
+        state_vars["config_path"] = config_path
+        state_vars["original_xtts_checkpoint"] = original_xtts_checkpoint
+        state_vars["vocab_file"] = vocab_file
+        state_vars["ft_xtts_checkpoint"] = ft_xtts_checkpoint
+        state_vars["speaker_audio_file"] = speaker_wav
+        return "Model training done!", state_vars, config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
+
+
+    train_btn.click(
+        fn=train_model,
+        inputs=[
+            lang,
+            num_epochs,
+            batch_size,
+            state_vars,
+        ],
+        outputs=[progress_train, state_vars, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
+    )
+
     tts_btn.click(
         fn=run_tts,
         inputs=[
-            lang,
+            tts_language,
             tts_text,
+            xtts_checkpoint,
+            xtts_config,
+            xtts_vocab,
+            speaker_reference_audio,
             state_vars,
-            temperature,
-            rms_norm_output,
         ],
-        outputs=[tts_output_audio_no_enhanced, tts_output_audio_no_enhanced_ft],
+        outputs=[tts_output_audio, reference_audio],
     )
+
+
 if __name__ == "__main__":
     demo.launch(
         share=True,
diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
index 4789e1f4..671be8eb 100644
--- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py
+++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -225,11 +225,11 @@ class GPTTrainer(BaseTTS):

     @torch.no_grad()
     def test_run(self, assets) -> Tuple[Dict, Dict]:  # pylint: disable=W0613
+        test_audios = {}
         if self.config.test_sentences:
             # init gpt for inference mode
             self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
             self.xtts.gpt.eval()
-            test_audios = {}
             print(" | > Synthesizing test sentences.")
             for idx, s_info in enumerate(self.config.test_sentences):
                 wav = self.xtts.synthesize(
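Usage sketch (illustrative; not part of the patch): the demo drives train_gpt and the inference helpers through Gradio callbacks, but the same flow can run headlessly with the functions this patch adds. This is a minimal sketch under stated assumptions: the CSV paths, epoch/batch values, and output filename are placeholders, and the inference call relies on the model config's default sampling parameters.

    import os

    import torch
    import torchaudio

    from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

    # Fine-tune; train_gpt returns the base config/vocab paths, the experiment
    # output dir, and one training clip usable as a speaker reference.
    config_path, base_ckpt, vocab_file, exp_path, speaker_wav = train_gpt(
        "en", 2, 2,
        "/tmp/xtts_ft/metadata_train.csv",  # placeholder: Step 1 train CSV
        "/tmp/xtts_ft/metadata_eval.csv",   # placeholder: Step 1 eval CSV
        output_path="./",
    )

    # Load the fine-tuned weights the same way the demo's load_model() does.
    config = XttsConfig()
    config.load_json(config_path)
    model = Xtts.init_from_config(config)
    model.load_checkpoint(
        config,
        checkpoint_path=os.path.join(exp_path, "best_model.pth"),
        vocab_path=vocab_file,
        use_deepspeed=False,
    )
    if torch.cuda.is_available():
        model.cuda()

    # Condition on the reference clip and synthesize, mirroring run_tts().
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
    out = model.inference(
        text="Hello from the fine-tuned model.",
        language="en",
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
    )
    torchaudio.save("ft_sample.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)  # placeholder output path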