mirror of https://github.com/coqui-ai/TTS.git
Rebase bug fix and update recipe
parent affaf11148
commit ec7f54768a
TTS/tts/layers/xtts/trainer/gpt_trainer.py

@@ -1,7 +1,5 @@
-import os
-import sys
 from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -12,13 +10,10 @@ from torch.utils.data import DataLoader
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler
 
-from TTS.tts.configs.tortoise_config import TortoiseConfig
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
 from TTS.tts.layers.xtts.dvae import DiscreteVAE
-from TTS.tts.layers.xtts.gpt import GPT
-from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
 from TTS.tts.models.base_tts import BaseTTS
@@ -456,7 +451,7 @@ class GPTTrainer(BaseTTS):
     ):  # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
 
-        state, _ = self.xtts.get_compatible_checkpoint_state(checkpoint_path)
+        state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path)
 
         # load the model weights
         self.xtts.load_state_dict(state, strict=strict)
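For context on the fix above: the old helper `get_compatible_checkpoint_state` was unpacked as a tuple, while the renamed `get_compatible_checkpoint_state_dict` returns just the state dict that `load_state_dict` consumes. A compatibility loader of this kind usually amounts to the minimal sketch below; the function body and the "xtts." key prefix are illustrative assumptions, not the actual XTTS implementation:

import torch

def get_compatible_state_dict(checkpoint_path: str) -> dict:
    # load on CPU so the call also works on GPU-less machines
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    # trainer checkpoints commonly nest the weights under a "model" key
    state_dict = checkpoint.get("model", checkpoint)
    # remap key prefixes so older checkpoints match the current module names
    # (the "xtts." prefix is a hypothetical example, not the real mapping)
    return {key.replace("xtts.", ""): value for key, value in state_dict.items()}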
recipes/ljspeech/xtts_v1/train_gpt_xtts.py

@@ -5,6 +5,8 @@ from trainer import Trainer, TrainerArgs
 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
+from TTS.utils.manage import ModelManager
+
 
 # Logging parameters
 RUN_NAME = "GPT_XTTS_LJSpeech_FT"
@@ -22,7 +24,7 @@ BATCH_SIZE = 3  # set here the batch size
 GRAD_ACUMM_STEPS = 84  # set here the grad accumulation steps
 # Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS need to be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE but then set GRAD_ACUMM_STEPS accordingly.
 
-# Define here the dataset that you want to use for the fine tuning
+# Define here the dataset that you want to use for the fine-tuning on.
 config_dataset = BaseDatasetConfig(
     formatter="ljspeech",
     dataset_name="ljspeech",
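With the defaults above, BATCH_SIZE * GRAD_ACUMM_STEPS = 3 * 84 = 252, exactly the recommended minimum effective batch size; doubling BATCH_SIZE to 6 would let GRAD_ACUMM_STEPS drop to 42 while keeping 6 * 42 = 252.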
@@ -31,20 +33,34 @@ config_dataset = BaseDatasetConfig(
     language="en",
 )
 
+# Add here the configs of the datasets
 DATASETS_CONFIG_LIST = [config_dataset]
 
-# ToDo: update with the latest released checkpoints
+# Define the path where XTTS v1.1.1 files will be downloaded
+CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v1.1_original_model_files/")
+os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
 
-# DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model
-DVAE_CHECKPOINT = "/raid/datasets/xtts_models/dvae.pth"  # DVAE checkpoint
-MEL_NORM_FILE = (
-    "/raid/datasets/xtts_models/mel_stats.pth"  # Mel spectrogram norms, required for dvae mel spectrogram extraction
-)
+# DVAE files
+DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/denoising_dvae_v3_small.pth"
+MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth"
+
+# download DVAE files
+print(" > Downloading DVAE files!")
+ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
 
+# Set the path to the downloaded files
+DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, DVAE_CHECKPOINT_LINK.split("/")[-1])
+MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, MEL_NORM_LINK.split("/")[-1])
+
+# Download XTTS v1.1 checkpoint
+TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json"
+XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth"
+print(" > Downloading XTTS v1.1 files!")
+ModelManager._download_model_files([TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
 
 # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
-TOKENIZER_FILE = "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/tokenizer_merged_5.json"  # vocab.json file
-XTTS_CHECKPOINT = "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/132500_gpt_ema_coqui_tts_with_enhanced_hifigan.pth"  # model.pth file
+TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1])  # vocab.json file
+XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1])  # model.pth file
 
 
 # Training sentences generations
 SPEAKER_REFERENCE = (
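Each local path above is derived from its URL with `LINK.split("/")[-1]`, so the on-disk file name is just the last URL segment. `ModelManager._download_model_files` is a private Coqui helper; outside the recipe, the same fetch-once-and-resolve pattern can be sketched with the standard library alone (the helper name `fetch_if_missing` is hypothetical):

import os
import urllib.request

def fetch_if_missing(url: str, out_dir: str) -> str:
    # download url into out_dir unless already present; return the local path
    os.makedirs(out_dir, exist_ok=True)
    local_path = os.path.join(out_dir, url.split("/")[-1])
    if not os.path.isfile(local_path):
        print(f" > Downloading {url}")
        urllib.request.urlretrieve(url, local_path)
    return local_path

# usage mirroring the recipe, e.g.:
# MEL_NORM_FILE = fetch_if_missing(MEL_NORM_LINK, CHECKPOINTS_OUT_PATH)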
@@ -71,9 +87,11 @@ def main():
         gpt_start_audio_token=8192,
         gpt_stop_audio_token=8193,
     )
+    # define audio config
     audio_config = XttsAudioConfig(
         sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
     )
+    # training parameters config
     config = GPTTrainerConfig(
         output_path=OUT_PATH,
         model_args=model_args,
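The `audio_config` and `config` built here are consumed later in `main()`, outside this diff. Based on the imports at the top of the recipe, the usual continuation looks roughly like the sketch below; treat it as an outline under that assumption, not the exact file contents:

from trainer import Trainer, TrainerArgs

from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer

# build the model from the training config
model = GPTTrainer.init_from_config(config)
# load train/eval samples for the datasets declared in DATASETS_CONFIG_LIST
train_samples, eval_samples = load_tts_samples(DATASETS_CONFIG_LIST, eval_split=True)
# hand everything to the trainer and start the fine-tuning run
trainer = Trainer(
    TrainerArgs(),
    config,
    output_path=OUT_PATH,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
)
trainer.fit()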