diff --git a/TTS/tts/models/bark.py b/TTS/tts/models/bark.py index 94838e91..fa07ff35 100644 --- a/TTS/tts/models/bark.py +++ b/TTS/tts/models/bark.py @@ -63,7 +63,9 @@ class BarkDataset(Dataset): return { "raw_text": raw_text, + "text_len": len(raw_text), "wav": wav, + "wav_len": wav.shape[1], "wav_file": wav_filename, "speaker_name": item["speaker_name"], "language_name": item["language"], @@ -143,9 +145,9 @@ class Bark(BaseTTS): super().__init__(config=config, ap=None, tokenizer=None, speaker_manager=None, language_manager=None) self.config.num_chars = len(tokenizer) self.tokenizer = tokenizer - self.semantic_model = GPT(config.semantic_config) - self.coarse_model = GPT(config.coarse_config) - self.fine_model = FineGPT(config.fine_config) + self.semantic_model = GPT(config.semantic_gpt_config) + self.coarse_model = GPT(config.coarse_gpt_config) + self.fine_model = FineGPT(config.fine_gpt_config) self.encodec = EncodecModel.encodec_model_24khz() self.encodec.set_target_bandwidth(6.0) self.semantic_tokenizer = BarkHubertAudioTokenizer(self.config, lazy_load=self.config.training_mode) @@ -154,11 +156,19 @@ class Bark(BaseTTS): def device(self): return next(self.parameters()).device + @property + def pad_token(self): + if self.config.training_mode == "semantic": + return self.config.SEMANTIC_PAD_TOKEN + elif self.config.training_mode in ["coarse", "fine"]: + return self.config.COARSE_SEMANTIC_PAD_TOKEN + else: + raise ValueError("Invalid training mode: {}".format(self.config.training_mode)) + def load_bark_models(self): self.semantic_model, self.config = load_model( ckpt_path=self.config.LOCAL_MODEL_PATHS["text"], device=self.device, config=self.config, model_type="text" ) - self.coarse_model, self.config = load_model( ckpt_path=self.config.LOCAL_MODEL_PATHS["coarse"], device=self.device, @@ -169,17 +179,19 @@ class Bark(BaseTTS): ckpt_path=self.config.LOCAL_MODEL_PATHS["fine"], device=self.device, config=self.config, model_type="fine" ) - def 
generate_coarse_fine_tokens(self, audio): + def generate_coarse_fine_tokens( + self, + audio, + ): if isinstance(audio, str): audio, sr = torchaudio.load(audio) - audio = convert_audio_and_make_label(audio, sr, self.config.sample_rate, self.encodec.channels) + audio = convert_audio(audio, sr, self.config.sample_rate, self.encodec.channels) audio = audio.unsqueeze(0).to(self.device) # Coarse and fine tokens with torch.no_grad(): encoded_frames = self.encodec.encode(audio) codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T] - codes = codes.cpu().numpy() return codes, codes[:2, :] # fine, corse def generate_semantic_tokens(self, audio): @@ -354,14 +366,6 @@ class Bark(BaseTTS): Returns: formatted batch """ - PAD_TOKEN = None - if self.config.training_mode == "semantic": - PAD_TOKEN = self.config.SEMANTIC_PAD_TOKEN - elif self.config.training_mode in ["coarse", "fine"]: - PAD_TOKEN = self.config.COARSE_SEMANTIC_PAD_TOKEN - else: - raise ValueError("Invalid training mode: {}".format(self.config.training_mode)) - tokenss = [] max_len = 0 for i, text in enumerate(batch["raw_text"]): @@ -370,11 +374,12 @@ class Bark(BaseTTS): tokenss.append(tokens) max_len = max(max_len, len(tokens)) - # pad and collate into batch - for i, tokens in enumerate(tokenss): - tokenss[i] = torch.nn.functional.pad(tokens, (0, max_len - len(tokens)), value=PAD_TOKEN) - tokens = torch.stack(tokenss, dim=0) - batch["input_ids"] = tokens[:, : self.config.max_text_tokens_len] + if self.config.training_mode == "semantic": + # pad and collate into batch + for i, tokens in enumerate(tokenss): + tokenss[i] = torch.nn.functional.pad(tokens, (0, max_len - len(tokens)), value=self.pad_token) + tokens = torch.stack(tokenss, dim=0) + batch["input_ids"] = tokens[:, : self.config.train_semantic_data_settings["max_text_tokens_len"]] return batch def format_batch_on_device(self, batch): @@ -386,11 +391,36 @@ class Bark(BaseTTS): Returns: dict: Formatted batch. 
""" + # TODO: Make padding and truncation based on exact length of the waveforms if self.config.training_mode == "semantic": batch["semantic_tokens"] = self.generate_semantic_tokens(batch["waveform"][:, 0])[ :, : self.config.max_semantic_tokens_len ] + elif self.config.training_mode == "coarse": + semantic_to_coarse_ratio = ( + self.config.COARSE_RATE_HZ / self.config.SEMANTIC_RATE_HZ * self.config.N_COARSE_CODEBOOKS + ) + batch["semantic_tokens"] = self.generate_semantic_tokens(batch["waveform"][:, 0])[ + :, : self.config.train_coarse_data_settings["max_semantic_tokens_len"] + ] + batch["semantic_tokens"] = torch.nn.functional.pad( + batch["semantic_tokens"], (0, 1), value=self.config.COARSE_INFER_TOKEN + ) + + batch["coarse_tokens"] = self.generate_coarse_fine_tokens(batch["waveform"])[1] + batch["coarse_tokens"] = ( + batch["coarse_tokens"].flatten(start_dim=1) + + self.config.CODEBOOK_SIZE + + self.config.SEMANTIC_VOCAB_SIZE + ) + batch["coarse_tokens"] = batch["coarse_tokens"][ + :, : self.config.train_coarse_data_settings["max_coarse_tokens_len"] + ] + elif self.config.training_mode == "fine": + batch["coarse_tokens"], batch["fine_tokens"] = self.generate_coarse_fine_tokens(batch["waveform"])[ + :, : self.config.max_coarse_tokens_len + ] return batch def train_step_semantic(self, batch: dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]: @@ -402,16 +432,26 @@ class Bark(BaseTTS): inputs = torch.cat([batch["input_ids"], input_tokens], dim=1) logits = self.semantic_model(inputs) - semantic_logits = logits[:, batch["input_ids"].size(1) :].contiguous() + logits = logits[:, batch["input_ids"].size(1) :].contiguous() - loss = criterion( - semantic_logits.view(-1, self.config.semantic_config.output_vocab_size), target_tokens.view(-1) - ) + loss = criterion(logits.view(-1, self.config.semantic_gpt_config.output_vocab_size), target_tokens.view(-1)) loss_dict = {"loss": loss} return {}, loss_dict - def train_step_coarse(self): - ... 
+ def train_step_coarse(self, batch: dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]: + """Train coarse encoder""" + tokens = batch["coarse_tokens"] + target_tokens = tokens[:, 1:].contiguous() + input_tokens = tokens[:, :-1].contiguous() + + inputs = torch.cat([batch["semantic_tokens"], input_tokens], dim=1) + logits = self.coarse_model(inputs) + + logits = logits[:, batch["semantic_tokens"].size(1) :].contiguous() + + loss = criterion(logits.view(-1, self.config.coarse_gpt_config.output_vocab_size), target_tokens.view(-1)) + loss_dict = {"loss": loss} + return {}, loss_dict def train_step_fine(self): ... @@ -419,6 +459,10 @@ class Bark(BaseTTS): def train_step(self, *args, **kwargs): if self.config.training_mode == "semantic": return self.train_step_semantic(*args, **kwargs) + elif self.config.training_mode == "coarse": + return self.train_step_coarse(*args, **kwargs) + elif self.config.training_mode == "fine": + raise NotImplementedError() def eval_step(self, *args, **kwargs): self.train_step(*args, **kwargs) @@ -436,26 +480,23 @@ class Bark(BaseTTS): return None def get_criterion(self): - if self.config.training_mode in ["semantic", "coarse_encoder"]: - return torch.nn.CrossEntropyLoss(ignore_index=self.config.COARSE_SEMANTIC_PAD_TOKEN) - elif self.config.training_mode == "fine_encoder": - return torch.nn.CrossEntropyLoss(ignore_index=self.config.FINE_COARSE_PAD_TOKEN) - else: - raise ValueError(f" ❗ Invalid training mode {self.config.training_mode}") + return torch.nn.CrossEntropyLoss(ignore_index=self.pad_token) def get_optimizer(self): if self.config.training_mode == "semantic": optimizer = get_optimizer( self.config.optimizer, self.config.optimizer_params, self.config.lr, self.semantic_model ) - elif self.config.training_mode == "coarse_encoder": + elif self.config.training_mode == "coarse": optimizer = get_optimizer( self.config.optimizer, self.config.optimizer_params, self.config.lr, self.coarse_model ) - elif self.config.training_mode == 
"fine_encoder": + elif self.config.training_mode == "fine": optimizer = get_optimizer( self.config.optimizer, self.config.optimizer_params, self.config.lr, self.fine_model ) + else: + raise ValueError(" ❗ Invalid training mode: {}".format(self.config.training_mode)) return optimizer def get_scheduler(self, optimizer): @@ -546,12 +587,52 @@ class Bark(BaseTTS): if __name__ == "__main__": + # from TTS.tts.configs.bark_config import BarkConfig + + # bark_config = BarkConfig() + + # bark_config.training_mode = "semantic" + # bark_config.batch_size = 2 + + # bark = Bark.init_from_config(bark_config) + + # # batch = {"waveform": torch.randn(2, 48000), "raw_text": ["hello world", "how are you"]} + # # batch = bark.format_batch(batch) + # # batch = bark.format_batch_on_device(batch) + + # from trainer import Trainer, TrainerArgs + + # dataset_config = BaseDatasetConfig( + # formatter="ljspeech", meta_file_train="metadata.csv", path="/data/TTS-public/tests/data/ljspeech/" + # ) + + # train_samples, eval_samples = load_tts_samples( + # dataset_config, + # eval_split=True, + # eval_split_max_size=4, + # eval_split_size=4, + # ) + + # trainer = Trainer( + # model=bark, + # config=bark_config, + # output_path="./", + # args=TrainerArgs(), + # train_samples=train_samples, + # eval_samples=eval_samples, + # ) + # trainer.fit() + from TTS.tts.configs.bark_config import BarkConfig bark_config = BarkConfig() - bark_config.training_mode = "semantic" + bark_config.training_mode = "coarse" bark_config.batch_size = 2 + bark_config.run_eval = False + bark_config.save_checkpoints = False + bark_config.save_best_after = 100000 + bark_config.print_step = 1 bark = Bark.init_from_config(bark_config)