diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
index 22577ad4..e93063fa 100644
--- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py
+++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -1,7 +1,5 @@
-import os
-import sys
 from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -12,13 +10,10 @@ from torch.utils.data import DataLoader
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler
 
-from TTS.tts.configs.tortoise_config import TortoiseConfig
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
 from TTS.tts.layers.xtts.dvae import DiscreteVAE
-from TTS.tts.layers.xtts.gpt import GPT
-from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
 from TTS.tts.models.base_tts import BaseTTS
@@ -456,7 +451,7 @@ class GPTTrainer(BaseTTS):
     ):  # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
 
-        state, _ = self.xtts.get_compatible_checkpoint_state(checkpoint_path)
+        state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path)
 
         # load the model weights
         self.xtts.load_state_dict(state, strict=strict)
diff --git a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
index 641d050c..6fb1c221 100644
--- a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
+++ b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
@@ -5,6 +5,8 @@ from trainer import Trainer, TrainerArgs
 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
+from TTS.utils.manage import ModelManager
+
 
 # Logging parameters
 RUN_NAME = "GPT_XTTS_LJSpeech_FT"
@@ -22,7 +24,7 @@ BATCH_SIZE = 3  # set here the batch size
 GRAD_ACUMM_STEPS = 84  # set here the grad accumulation steps
 # Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS need to be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE but then set GRAD_ACUMM_STEPS accordingly.
 
-# Define here the dataset that you want to use for the fine tuning
+# Define here the dataset that you want to use for the fine-tuning
 config_dataset = BaseDatasetConfig(
     formatter="ljspeech",
     dataset_name="ljspeech",
@@ -31,20 +33,34 @@ config_dataset = BaseDatasetConfig(
     language="en",
 )
 
+# Add here the configs of the datasets
 DATASETS_CONFIG_LIST = [config_dataset]
 
-# ToDo: update with the latest released checkpoints
+# Define the path where XTTS v1.1.1 files will be downloaded
+CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v1.1_original_model_files/")
+os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
 
-# DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model
-DVAE_CHECKPOINT = "/raid/datasets/xtts_models/dvae.pth"  # DVAE checkpoint
-MEL_NORM_FILE = (
-    "/raid/datasets/xtts_models/mel_stats.pth"  # Mel spectrogram norms, required for dvae mel spectrogram extraction
-)
+
+# DVAE files
+DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/denoising_dvae_v3_small.pth"
+MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth"
+# download DVAE files
+print(" > Downloading DVAE files!")
+ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
+
+# Set the path to the downloaded files
+DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, DVAE_CHECKPOINT_LINK.split("/")[-1])
+MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, MEL_NORM_LINK.split("/")[-1])
+
+# Download XTTS v1.1 checkpoint
+TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json"
+XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth"
+print(" > Downloading XTTS v1.1 files!")
+ModelManager._download_model_files([TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
 
 # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
-TOKENIZER_FILE = "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/tokenizer_merged_5.json"  # vocab.json file
-XTTS_CHECKPOINT = "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/132500_gpt_ema_coqui_tts_with_enhanced_hifigan.pth"  # model.pth file
-
+TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1])  # vocab.json file
+XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1])  # model.pth file
 
 # Training sentences generations
 SPEAKER_REFERENCE = (
@@ -71,9 +87,11 @@ def main():
         gpt_start_audio_token=8192,
         gpt_stop_audio_token=8193,
     )
+    # define audio config
     audio_config = XttsAudioConfig(
         sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
     )
+    # training parameters config
     config = GPTTrainerConfig(
         output_path=OUT_PATH,
         model_args=model_args,
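
Note on how the pieces above connect: the recipe change only downloads the XTTS v1.1.1 files and resolves their local paths; those paths are then handed to GPTArgs further down in main(), in a part of the file the hunks above do not show. The following snippet is a hedged sketch of that wiring, not part of the diff. The GPTArgs field names (dvae_checkpoint, mel_norm_file, xtts_checkpoint, tokenizer_file) and the output directory are assumptions taken from the rest of the recipe, not confirmed by these hunks.

# Hedged sketch (not part of the diff): wiring the downloaded XTTS v1.1.1 files
# into the trainer arguments. Field names below are assumed, not shown in the hunks.
import os

from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs

CHECKPOINTS_OUT_PATH = "run/training/XTTS_v1.1_original_model_files/"  # assumed output dir


def local_path(link):
    # mirror the os.path.join(CHECKPOINTS_OUT_PATH, LINK.split("/")[-1]) pattern used above
    return os.path.join(CHECKPOINTS_OUT_PATH, link.split("/")[-1])


model_args = GPTArgs(
    dvae_checkpoint=local_path("https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/denoising_dvae_v3_small.pth"),
    mel_norm_file=local_path("https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth"),
    xtts_checkpoint=local_path("https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth"),
    tokenizer_file=local_path("https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json"),
    gpt_start_audio_token=8192,  # values shown in the hunk above
    gpt_stop_audio_token=8193,
)

The files themselves are fetched beforehand with ModelManager._download_model_files(...), exactly as the recipe does above; the sketch only illustrates where the resolved paths end up.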