mirror of https://github.com/coqui-ai/TTS.git
Rebase bug fix and update recipe
commit ec7f54768a
parent affaf11148
TTS/tts/layers/xtts/trainer/gpt_trainer.py

@@ -1,7 +1,5 @@
 import os
-import sys
-from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -12,13 +10,10 @@ from torch.utils.data import DataLoader
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler
 
-from TTS.tts.configs.tortoise_config import TortoiseConfig
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
 from TTS.tts.layers.xtts.dvae import DiscreteVAE
-from TTS.tts.layers.xtts.gpt import GPT
-from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
 from TTS.tts.models.base_tts import BaseTTS
@@ -456,7 +451,7 @@ class GPTTrainer(BaseTTS):
     ):  # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
 
-        state, _ = self.xtts.get_compatible_checkpoint_state(checkpoint_path)
+        state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path)
 
         # load the model weights
         self.xtts.load_state_dict(state, strict=strict)
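
The fix above resolves a rebase leftover: the trainer still called get_compatible_checkpoint_state(), which returned a tuple, while the Xtts model now exposes get_compatible_checkpoint_state_dict(), which returns only the filtered state dict. A minimal sketch of the pattern behind such a helper, assuming a torch.save checkpoint that nests weights under a "model" key; the key filtering is illustrative, not the library's exact code:

    import torch

    def get_compatible_state_dict(model, checkpoint_path):
        # Load the checkpoint on CPU and unwrap the trainer's "model" entry
        # if present (plain state dicts pass through unchanged).
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        state = checkpoint.get("model", checkpoint)
        # Keep only keys the current model defines, so modules renamed or
        # removed during the rebase don't abort load_state_dict().
        model_keys = set(model.state_dict().keys())
        return {k: v for k, v in state.items() if k in model_keys}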

recipes/ljspeech/xtts_v1/train_gpt_xtts.py

@@ -5,6 +5,8 @@ from trainer import Trainer, TrainerArgs
 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
+from TTS.utils.manage import ModelManager
+
 
 # Logging parameters
 RUN_NAME = "GPT_XTTS_LJSpeech_FT"
@@ -22,7 +24,7 @@ BATCH_SIZE = 3  # set here the batch size
 GRAD_ACUMM_STEPS = 84  # set here the grad accumulation steps
 # Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS need to be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE but then set GRAD_ACUMM_STEPS accordingly.
 
-# Define here the dataset that you want to use for the fine tuning
+# Define here the dataset that you want to use for the fine-tuning on.
 config_dataset = BaseDatasetConfig(
     formatter="ljspeech",
     dataset_name="ljspeech",
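
The note in the hunk above is simple arithmetic: the effective batch per optimizer step is BATCH_SIZE * GRAD_ACUMM_STEPS, and the defaults give 3 * 84 = 252. A small helper, hypothetical and not part of the recipe, that keeps the product at or above the recommended 252 when the per-GPU batch size changes:

    import math

    def grad_accum_steps(batch_size, target_effective_batch=252):
        # Round up so batch_size * steps never drops below the target.
        return math.ceil(target_effective_batch / batch_size)

    assert 3 * grad_accum_steps(3) == 252  # the recipe defaults: 3 * 84
    assert 7 * grad_accum_steps(7) == 252  # 7 * 36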
@@ -31,20 +33,34 @@ config_dataset = BaseDatasetConfig(
     language="en",
 )
 
 # Add here the configs of the datasets
 DATASETS_CONFIG_LIST = [config_dataset]
 
-# ToDo: update with the latest released checkpoints
+# Define the path where XTTS v1.1.1 files will be downloaded
+CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v1.1_original_model_files/")
+os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
+
-# DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model
-DVAE_CHECKPOINT = "/raid/datasets/xtts_models/dvae.pth"  # DVAE checkpoint
-MEL_NORM_FILE = (
-    "/raid/datasets/xtts_models/mel_stats.pth"  # Mel spectrogram norms, required for dvae mel spectrogram extraction
-)
-
+# DVAE files
+DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/denoising_dvae_v3_small.pth"
+MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth"
+# download DVAE files
+print(" > Downloading DVAE files!")
+ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
+
+# Set the path to the downloaded files
+DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, DVAE_CHECKPOINT_LINK.split("/")[-1])
+MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, MEL_NORM_LINK.split("/")[-1])
+
+# Download XTTS v1.1 checkpoint
+TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json"
+XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth"
+print(" > Downloading XTTS v1.1 files!")
+ModelManager._download_model_files([TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
 
 # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
-TOKENIZER_FILE = "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/tokenizer_merged_5.json"  # vocab.json file
-XTTS_CHECKPOINT = "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/132500_gpt_ema_coqui_tts_with_enhanced_hifigan.pth"  # model.pth file
+
+TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1])  # vocab.json file
+XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1])  # model.pth file
 
 # Training sentences generations
 SPEAKER_REFERENCE = (
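
This is the core recipe change: hard-coded /raid/... paths are replaced by downloads of the released v1.1.1 files, with each local path rebuilt from the last URL segment (link.split("/")[-1], i.e. the URL basename). A standalone sketch of the same download-and-resolve pattern using only the standard library; the function name and the skip-if-already-present check are assumptions, not the recipe's code:

    import os
    import urllib.request

    def fetch_checkpoints(links, out_dir):
        # Store each file under its URL basename and skip files that already
        # exist, mirroring how the recipe derives DVAE_CHECKPOINT, MEL_NORM_FILE,
        # TOKENIZER_FILE, and XTTS_CHECKPOINT from the download links.
        os.makedirs(out_dir, exist_ok=True)
        paths = []
        for link in links:
            path = os.path.join(out_dir, link.split("/")[-1])
            if not os.path.isfile(path):
                urllib.request.urlretrieve(link, path)
            paths.append(path)
        return paths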
@@ -71,9 +87,11 @@ def main():
         gpt_start_audio_token=8192,
         gpt_stop_audio_token=8193,
     )
+    # define audio config
     audio_config = XttsAudioConfig(
         sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
     )
+    # training parameters config
     config = GPTTrainerConfig(
         output_path=OUT_PATH,
         model_args=model_args,
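
The hunk ends inside the GPTTrainerConfig call; downstream, recipes in this family hand the config to the model and the Trainer roughly as sketched below, based on other coqui-ai/TTS recipes rather than this diff. Note that GRAD_ACUMM_STEPS goes to TrainerArgs, so gradient accumulation is handled by the Trainer, not the model:

    # Sketch of the rest of main(); assumed, not shown in this commit.
    model = GPTTrainer.init_from_config(config)
    train_samples, eval_samples = load_tts_samples(
        DATASETS_CONFIG_LIST,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )
    trainer = Trainer(
        TrainerArgs(grad_accum_steps=GRAD_ACUMM_STEPS),
        config,
        output_path=OUT_PATH,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()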