From c4ceaabe2cc85654d47c902779766e8e985a010d Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Mon, 16 Oct 2023 15:32:00 -0300
Subject: [PATCH] Add test sentences during training

---
 TTS/tts/layers/xtts/trainer/gpt_trainer.py | 94 ++++++++++++----------
 TTS/tts/models/xtts.py                     |  2 +-
 recipes/multilingual/xtts_v1/train_xtts.py | 25 +++---
 3 files changed, 67 insertions(+), 54 deletions(-)

diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
index d884f12a..87b1228e 100644
--- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py
+++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -12,7 +12,8 @@ import sys

 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.gpt import GPT
-from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
+from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig, Xtts
+from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.base_tts import BaseTTS
 from coqpit import Coqpit

@@ -25,20 +26,21 @@ from TTS.tts.datasets.dataset import TTSDataset
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler
-
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
 from TTS.utils.io import load_fsspec
 from TTS.tts.layers.xtts.dvae import DiscreteVAE
+from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
+

 @dataclass
-class GPTConfig(TortoiseConfig):
+class GPTTrainerConfig(XttsConfig):
     lr: float = 5e-06
     training_seed: int = 1
     optimizer_wd_only_on_weights: bool = False
     weighted_loss_attrs: dict = field(default_factory=lambda: {})
     weighted_loss_multipliers: dict = field(default_factory=lambda: {})
-
+    test_sentences: List[dict] = field(default_factory=lambda: [])

 @dataclass
 class XttsAudioConfig(XttsAudioConfig):
@@ -58,7 +60,8 @@ class GPTArgs(XttsArgs):
     tokenizer_file: str = ""
     mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
     dvae_checkpoint: str = ""
-    gpt_checkpoint: str = ""
+    xtts_checkpoint: str = ""
+    gpt_checkpoint: str = ""  # if defined, it will replace the GPT weights in the XTTS model
     vocoder: str = ""  # override the vocoder key in the config to avoid JSON write issues
@@ -80,28 +83,18 @@ class GPTTrainer(BaseTTS):
         """
         super().__init__(config, ap=None, tokenizer=None)
         self.config = config
+        # init XTTS model
+        self.xtts = Xtts(self.config)
+        # create the tokenizer with the target vocabulary
+        self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
+        # init GPT encoder and HiFi-GAN decoder
+        self.xtts.init_models()
+        # set mel stats
+        if self.args.mel_norm_file:
+            self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)
-
-        self.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
-
-        self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
-        self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
-        self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
-
-        self.gpt = GPT(
-            layers=self.args.gpt_layers,
-            model_dim=self.args.gpt_n_model_channels,
-            start_text_token=self.args.gpt_start_text_token,
-            stop_text_token=self.args.gpt_stop_text_token,
-            heads=self.args.gpt_n_heads,
-            max_text_tokens=self.args.gpt_max_text_tokens,
-            max_mel_tokens=self.args.gpt_max_audio_tokens,
-            max_prompt_tokens=self.args.gpt_max_prompt_tokens,
-            number_text_tokens=self.args.gpt_number_text_tokens,
-            num_audio_tokens=self.args.gpt_num_audio_tokens,
-            start_audio_token=self.args.gpt_start_audio_token,
-            stop_audio_token=self.args.gpt_stop_audio_token,
-        ).cuda()
-
+        if self.args.xtts_checkpoint:
+            self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)

         # load GPT if available
         if self.args.gpt_checkpoint:
@@ -122,8 +115,8 @@ class GPTTrainer(BaseTTS):
                     del gpt_checkpoint[key]

             # edit the checkpoint if the number of tokens has changed, to ensure the best possible transfer learning
-            if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.gpt.text_embedding.weight.shape:
-                num_new_tokens = self.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
+            if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape:
+                num_new_tokens = self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
                 print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")

                 # add new tokens to a linear layer (text_head)
@@ -137,7 +130,7 @@ class GPTTrainer(BaseTTS):
                 # add new weights to the linear layer (text_head)
                 text_head_weight = gpt_checkpoint["text_head.weight"]
                 start_token_row = text_head_weight[-1, :]
-                new_entry = torch.randn(num_new_tokens, self.gpt.text_head.weight.shape[1])
+                new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1])
                 text_head_weight = torch.cat([text_head_weight, new_entry], axis=0)
                 text_head_weight[-1, :] = start_token_row
                 gpt_checkpoint["text_head.weight"] = text_head_weight
@@ -150,10 +143,8 @@ class GPTTrainer(BaseTTS):
                 text_head_bias[-1] = start_token_row
                 gpt_checkpoint["text_head.bias"] = text_head_bias

-            self.gpt.load_state_dict(gpt_checkpoint, strict=True)
+            self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
             print(">> GPT weights restored from:", self.args.gpt_checkpoint)
-        else:
-            print(">> GPT weights randomly initialized! If you want you can specify a checkpoint in config.model_args.gpt_checkpoint")

         # Mel spectrogram extractor for conditioning
         self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
@@ -195,6 +186,7 @@ class GPTTrainer(BaseTTS):
         # Mel spectrogram extractor for DVAE
         self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate)

+
     @property
     def device(self):
         return next(self.parameters()).device
@@ -211,12 +203,30 @@ class GPTTrainer(BaseTTS):
             cond_mels: MEL float tensor, (b, num_samples, 80, t_m)
             cond_idxs: cond start and end indices, (b, 2)
         """
-        losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
+        losses = self.xtts.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
         return losses

     @torch.no_grad()
     def test_run(self, assets) -> Tuple[Dict, Dict]:  # pylint: disable=W0613
-        return {}, {}
+        test_audios = {}
+        if self.config.test_sentences:
+            # init gpt for inference mode
+            self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
+            self.xtts.gpt.eval()
+            print(" | > Synthesizing test sentences.")
+            for idx, s_info in enumerate(self.config.test_sentences):
+                wav = self.xtts.synthesize(s_info["text"], self.config, s_info["speaker_wav"], s_info["language"])["wav"]
+                test_audios["{}-audio".format(idx)] = wav
+
+            # delete inference layers
+            del self.xtts.gpt.gpt_inference
+            del self.xtts.gpt.gpt.wte
+        return {"audios": test_audios}
+
+    def test_log(
+        self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
+    ) -> None:
+        logger.test_audios(steps, outputs["audios"], self.args.output_sample_rate)

     def format_batch(self, batch: Dict) -> Dict:
         return batch
@@ -323,7 +333,7 @@ class GPTTrainer(BaseTTS):
             loader = None
         else:
             # init dataloader
-            dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval)
+            dataset = XTTSDataset(self.config, samples, self.xtts.tokenizer, config.audio.sample_rate, is_eval)

             # wait for all the DDP processes to be ready
             if num_gpus > 1:
@@ -362,7 +372,7 @@ class GPTTrainer(BaseTTS):
         # TODO: deal with multi-GPU training
         if self.config.optimizer_wd_only_on_weights:
             # parameters of only the GPT model
-            net = self.gpt
+            net = self.xtts.gpt

             # normalizations
             norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
@@ -410,7 +420,7 @@ class GPTTrainer(BaseTTS):
                 self.config.optimizer_params,
                 self.config.lr,
                 # optimize only the GPT model
-                parameters=self.gpt.parameters(),
+                parameters=self.xtts.gpt.parameters(),
             )

     def get_scheduler(self, optimizer) -> List:
@@ -432,21 +442,21 @@ class GPTTrainer(BaseTTS):
         target_options={"anon": True},
     ):  # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
         """Load the model checkpoint and set up for training or inference."""
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))["model"]

         # load the model weights
-        self.gpt.load_state_dict(state, strict=strict)
+        self.xtts.load_state_dict(state, strict=strict)

         if eval:
+            self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
             self.eval()
-            self.set_inference()
             assert not self.training

     @staticmethod
-    def init_from_config(config: "GPTConfig", samples: Union[List[List], List[Dict]] = None):
+    def init_from_config(config: "GPTTrainerConfig", samples: Union[List[List], List[Dict]] = None):
         """Initialize the model from config.

         Args:
-            config (GPTConfig): Model config.
+            config (GPTTrainerConfig): Model config.
             samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids
                 for training. Defaults to None.
         """
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 40e8f946..3e609799 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -387,7 +387,7 @@ class Xtts(BaseTTS):
         audio = load_audio(audio_path)
         audio = audio[:, : 22050 * length]
         mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
-        cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
+        cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)

     @torch.inference_mode()
diff --git a/recipes/multilingual/xtts_v1/train_xtts.py b/recipes/multilingual/xtts_v1/train_xtts.py
index 429e4e3a..f36bf1ae 100644
--- a/recipes/multilingual/xtts_v1/train_xtts.py
+++ b/recipes/multilingual/xtts_v1/train_xtts.py
@@ -3,7 +3,7 @@ from trainer import Trainer, TrainerArgs

 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTConfig
+from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTTrainerConfig


 config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig(
@@ -265,21 +265,21 @@ def main():
         debug_loading_failures=False,
         max_wav_length=255995,  # ~11.6 seconds
         max_text_length=200,
-        tokenizer_file="/raid/datasets/xtts_models/vocab.json",
         mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
         dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
-        gpt_checkpoint="/raid/datasets/xtts_models/gpt.pth",
+        tokenizer_file="/raid/datasets/xtts_models/vocab.json",  # vocab path of the model that you want to fine-tune
+        xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",  # checkpoint path of the model that you want to fine-tune
         gpt_num_audio_tokens=8194,
         gpt_start_audio_token=8192,
         gpt_stop_audio_token=8193,
     )

     audio_config = XttsAudioConfig(
-        sample_rate=22050,  # autoregressive SR
+        sample_rate=22050,  # GPT SR
         dvae_sample_rate=22050,
         diffusion_sample_rate=24000,
         output_sample_rate=24000
     )
-    config = GPTConfig(
+    config = GPTTrainerConfig(
         output_path=OUT_PATH,
         model_args=model_args,
         run_name=RUN_NAME,
@@ -313,6 +313,10 @@ def main():
         lr_scheduler="MultiStepLR",  # it was adjusted accordingly for the new step scheme
         lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
+        test_sentences=[
+            {"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
+            {"text": "This cake is great. It's so delicious and moist.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
+        ]
     )

     # init the model from config
@@ -341,7 +345,7 @@ def main():

 if __name__ == "__main__":
     RUN_NAME = "GPT_XTTS"
-    PROJECT_NAME = "XTTS"
+    PROJECT_NAME = "XTTS_trainer"
     OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_style_emb/"
     DASHBOARD_LOGGER = "clearml"
     LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_style_emb/"
@@ -352,12 +356,11 @@ if __name__ == "__main__":
     GRAD_ACUMM_STEPS = 28

     # debug
-    DASHBOARD_LOGGER = "tensorboard"
-    LOGGER_URI = None
-    RESTORE_PATH = None
-    BATCH_SIZE = 10
+    # DASHBOARD_LOGGER = "tensorboard"
+    # LOGGER_URI = None
+    # RESTORE_PATH = None
+    BATCH_SIZE = 2
     GRAD_ACUMM_STEPS = 1
-    NUM_LOADERS = 1
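
A note on the two checkpoint fields that GPTArgs now exposes: xtts_checkpoint loads a
full XTTS checkpoint (GPT plus decoder, loaded with strict=False in the constructor
above), while gpt_checkpoint optionally overwrites only the GPT weights afterwards.
A minimal sketch with placeholder paths, all other fields keeping their defaults:

    from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs

    # placeholder paths, not the recipe's real ones
    model_args = GPTArgs(
        xtts_checkpoint="/path/to/xtts_model.pth",  # full XTTS checkpoint, loaded with strict=False
        gpt_checkpoint="",  # optional: if set, replaces only the GPT weights on the XTTS model
    )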
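
The checkpoint edit in GPTTrainer.__init__ handles a grown text vocabulary by
appending randomly initialized rows to text_embedding / text_head while keeping the
learned final row (the patch's start_token_row) in last position. A standalone sketch
of that row expansion; expand_vocab_rows is an illustrative helper, not a function in
the patch:

    import torch

    def expand_vocab_rows(weight: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
        last_row = weight[-1, :].clone()  # learned weights of the final token row
        new_rows = torch.randn(num_new_tokens, weight.shape[1])  # random init for the new tokens
        weight = torch.cat([weight, new_rows], dim=0)  # append the new rows
        weight[-1, :] = last_row  # keep the learned row at the end of the table
        return weight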
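
Finally, the shape of the new test hooks: test_run synthesizes each entry of
config.test_sentences and returns the wavs keyed "<idx>-audio", and test_log forwards
them to the dashboard logger. Roughly how the Trainer drives them at test time; this
is a simplified sketch of the calling contract, not the Trainer source, and model,
logger, and global_step are assumed to exist:

    # model: a GPTTrainer; logger: the dashboard logger (ClearML or Tensorboard)
    outputs = model.test_run(assets={})  # -> {"audios": {"0-audio": wav, ...}}
    model.test_log(outputs, logger, assets={}, steps=global_step)
    # test_log calls logger.test_audios(steps, outputs["audios"], output_sample_rate),
    # so the synthesized sentences appear on the dashboard at the current step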