Add test sentences during the training

Edresson Casanova 2023-10-16 15:32:00 -03:00
parent 2f868dd5c2
commit c4ceaabe2c
3 changed files with 67 additions and 54 deletions

View File: TTS/tts/layers/xtts/trainer/gpt_trainer.py

@@ -12,7 +12,8 @@ import sys
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.gpt import GPT
-from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
+from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig, Xtts
+from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.base_tts import BaseTTS
 from coqpit import Coqpit
@@ -25,20 +26,21 @@ from TTS.tts.datasets.dataset import TTSDataset
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
 from TTS.utils.io import load_fsspec
 from TTS.tts.layers.xtts.dvae import DiscreteVAE
+from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder


 @dataclass
-class GPTConfig(TortoiseConfig):
+class GPTTrainerConfig(XttsConfig):
     lr: float = 5e-06
     training_seed: int = 1
     optimizer_wd_only_on_weights: bool = False
     weighted_loss_attrs: dict = field(default_factory=lambda: {})
     weighted_loss_multipliers: dict = field(default_factory=lambda: {})
+    test_sentences: List[dict] = field(default_factory=lambda: [])


 @dataclass
 class XttsAudioConfig(XttsAudioConfig):
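The training config class is renamed from GPTConfig to GPTTrainerConfig, rebased from TortoiseConfig onto XttsConfig, and gains the new test_sentences field. A minimal standalone sketch of how such a field is declared and filled (the class and values here are illustrative, not from the repo):

from dataclasses import dataclass, field
from typing import List

@dataclass
class DemoTrainerConfig:
    lr: float = 5e-06
    # each entry holds the text to synthesize plus a reference wav and a language code
    test_sentences: List[dict] = field(default_factory=list)

cfg = DemoTrainerConfig(
    test_sentences=[{"text": "Hello.", "speaker_wav": "ref.wav", "language": "en"}]
)
print(cfg.test_sentences[0]["language"])  # -> en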
@@ -58,7 +60,8 @@ class GPTArgs(XttsArgs):
     tokenizer_file: str = ""
     mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
     dvae_checkpoint: str = ""
-    gpt_checkpoint: str = ""
+    xtts_checkpoint: str = ""
+    gpt_checkpoint: str = ""  # if defined it will replace the gpt weights on xtts model
     vocoder: str = ""  # overide vocoder key on the config to avoid json write issues
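This hunk splits checkpoint handling in two: xtts_checkpoint restores the full XTTS model at construction time, while gpt_checkpoint, if set, afterwards replaces only the GPT weights, as its new comment says. A hedged illustration of that precedence using a plain dict (values hypothetical):

model_args = {
    "xtts_checkpoint": "model.pth",  # full model to fine-tune from
    "gpt_checkpoint": "",            # empty: keep the GPT weights that came with xtts_checkpoint
}
if model_args["xtts_checkpoint"]:
    print("restore the full XTTS model")
if model_args["gpt_checkpoint"]:
    print("then overwrite the GPT weights only")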
@@ -80,28 +83,18 @@ class GPTTrainer(BaseTTS):
         """
         super().__init__(config, ap=None, tokenizer=None)
         self.config = config
-        self.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
-        self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
-        self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
-        self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
-        self.gpt = GPT(
-            layers=self.args.gpt_layers,
-            model_dim=self.args.gpt_n_model_channels,
-            start_text_token=self.args.gpt_start_text_token,
-            stop_text_token=self.args.gpt_stop_text_token,
-            heads=self.args.gpt_n_heads,
-            max_text_tokens=self.args.gpt_max_text_tokens,
-            max_mel_tokens=self.args.gpt_max_audio_tokens,
-            max_prompt_tokens=self.args.gpt_max_prompt_tokens,
-            number_text_tokens=self.args.gpt_number_text_tokens,
-            num_audio_tokens=self.args.gpt_num_audio_tokens,
-            start_audio_token=self.args.gpt_start_audio_token,
-            stop_audio_token=self.args.gpt_stop_audio_token,
-        ).cuda()
+        # init XTTS model
+        self.xtts = Xtts(self.config)
+        # create the tokenizer with the target vocabulary
+        self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
+        # init gpt encoder and hifigan decoder
+        self.xtts.init_models()
+        # set mel stats
+        if self.args.mel_norm_file:
+            self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)
+        if self.args.xtts_checkpoint:
+            self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)

         # load GPT if available
         if self.args.gpt_checkpoint:
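The constructor now builds the complete XTTS model instead of a bare GPT stack; the GPT hyperparameters move out of the trainer and into the model. A condensed sketch of the new construction order (the helper name build_xtts_for_training is hypothetical; the calls mirror the diff):

from TTS.tts.models.xtts import Xtts
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.utils.io import load_fsspec

def build_xtts_for_training(config, args):
    xtts = Xtts(config)                                      # full model, not a bare GPT
    xtts.tokenizer = VoiceBpeTokenizer(args.tokenizer_file)  # tokenizer with the target vocabulary
    xtts.init_models()                                       # GPT encoder + HiFi-GAN decoder
    if args.mel_norm_file:
        xtts.mel_stats = load_fsspec(args.mel_norm_file)     # mel normalization stats
    return xtts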
@@ -122,8 +115,8 @@ class GPTTrainer(BaseTTS):
                     del gpt_checkpoint[key]

             # edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible
-            if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.gpt.text_embedding.weight.shape:
-                num_new_tokens = self.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
+            if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape:
+                num_new_tokens = self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
                 print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")

                 # add new tokens to a linear layer (text_head)
@@ -137,7 +130,7 @@ class GPTTrainer(BaseTTS):
                 # add new weights to the linear layer (text_head)
                 text_head_weight = gpt_checkpoint["text_head.weight"]
                 start_token_row = text_head_weight[-1, :]
-                new_entry = torch.randn(num_new_tokens, self.gpt.text_head.weight.shape[1])
+                new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1])
                 text_head_weight = torch.cat([text_head_weight, new_entry], axis=0)
                 text_head_weight[-1, :] = start_token_row
                 gpt_checkpoint["text_head.weight"] = text_head_weight
@@ -150,10 +143,8 @@ class GPTTrainer(BaseTTS):
                 text_head_bias[-1] = start_token_row
                 gpt_checkpoint["text_head.bias"] = text_head_bias

-            self.gpt.load_state_dict(gpt_checkpoint, strict=True)
+            self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
             print(">> GPT weights restored from:", self.args.gpt_checkpoint)
-        else:
-            print(">> GPT weights randomly initialized! If you want you can specify a checkpoint in config.model_args.gpt_checkpoint")

         # Mel spectrogram extractor for conditioning
         self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
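The token-resizing logic above (and in the embedding hunk before it) follows one pattern: append randomly initialized rows for the new tokens, then copy the old final row, which holds a special token, into the new final slot so that token keeps the last index. A standalone sketch of the mechanic with toy sizes:

import torch

old_vocab, new_vocab, dim = 8, 10, 4
weight = torch.randn(old_vocab, dim)  # stands in for a pretrained text_head.weight
num_new_tokens = new_vocab - old_vocab

last_row = weight[-1, :].clone()      # special-token row to keep at the end
weight = torch.cat([weight, torch.randn(num_new_tokens, dim)], dim=0)
weight[-1, :] = last_row              # special token stays the last index
assert weight.shape == (new_vocab, dim)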
@@ -195,6 +186,7 @@ class GPTTrainer(BaseTTS):
         # Mel spectrogram extractor for DVAE
         self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate)

     @property
     def device(self):
         return next(self.parameters()).device
@@ -211,12 +203,30 @@ class GPTTrainer(BaseTTS):
             cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
             cond_idxs: cond start and end indexs, (b, 2)
         """
-        losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
+        losses = self.xtts.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
         return losses

     @torch.no_grad()
     def test_run(self, assets) -> Tuple[Dict, Dict]:  # pylint: disable=W0613
-        return {}, {}
+        if self.config.test_sentences:
+            # init gpt for inference mode
+            self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
+            self.xtts.gpt.eval()
+            test_audios = {}
+            print(" | > Synthesizing test sentences.")
+            for idx, s_info in enumerate(self.config.test_sentences):
+                wav = self.xtts.synthesize(s_info["text"], self.config, s_info["speaker_wav"], s_info["language"])["wav"]
+                test_audios["{}-audio".format(idx)] = wav
+            # delete inference layers
+            del self.xtts.gpt.gpt_inference
+            del self.xtts.gpt.gpt.wte
+        return {"audios": test_audios}
+
+    def test_log(
+        self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
+    ) -> None:
+        logger.test_audios(steps, outputs["audios"], self.args.output_sample_rate)

     def format_batch(self, batch: Dict) -> Dict:
         return batch
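test_run keys each synthesized clip as "{idx}-audio", which is the layout test_log hands to logger.test_audios at args.output_sample_rate. Note that as written, test_audios is only bound inside the if branch, so an empty test_sentences list would raise a NameError at the return. A self-contained sketch of the loop shape, with a stub standing in for Xtts.synthesize:

def synthesize(text, speaker_wav, language):  # stub for Xtts.synthesize
    return {"wav": [0.0] * 10}                # placeholder waveform

test_sentences = [
    {"text": "Hello.", "speaker_wav": "ref.wav", "language": "en"},
    {"text": "Goodbye.", "speaker_wav": "ref.wav", "language": "en"},
]

test_audios = {}
for idx, s_info in enumerate(test_sentences):
    wav = synthesize(s_info["text"], s_info["speaker_wav"], s_info["language"])["wav"]
    test_audios["{}-audio".format(idx)] = wav

print(sorted(test_audios))  # ['0-audio', '1-audio']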
@@ -323,7 +333,7 @@ class GPTTrainer(BaseTTS):
             loader = None
         else:
             # init dataloader
-            dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval)
+            dataset = XTTSDataset(self.config, samples, self.xtts.tokenizer, config.audio.sample_rate, is_eval)

             # wait all the DDP process to be ready
             if num_gpus > 1:
@@ -362,7 +372,7 @@ class GPTTrainer(BaseTTS):
         # ToDo: deal with multi GPU training
         if self.config.optimizer_wd_only_on_weights:
             # parameters to only GPT model
-            net = self.gpt
+            net = self.xtts.gpt

             # normalizations
             norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
@@ -410,7 +420,7 @@ class GPTTrainer(BaseTTS):
             self.config.optimizer_params,
             self.config.lr,
             # optimize only for the GPT model
-            parameters=self.gpt.parameters(),
+            parameters=self.xtts.gpt.parameters(),
         )

     def get_scheduler(self, optimizer) -> List:
@@ -432,21 +442,21 @@ class GPTTrainer(BaseTTS):
         target_options={"anon": True},
     ):  # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))["model"]
         # load the model weights
-        self.gpt.load_state_dict(state, strict=strict)
+        self.xtts.load_state_dict(state, strict=strict)

         if eval:
+            self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
             self.eval()
-            self.set_inference()
             assert not self.training
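Two load-time details change here: the checkpoint dict now nests the weights under a "model" key, and the whole XTTS model, not just the GPT, is restored, with strict=False tolerating keys that are missing on either side (for instance a GPT-only checkpoint loaded into the full model). A standalone torch illustration of both details (checkpoint layout hypothetical):

import torch
import torch.nn as nn

net = nn.Linear(4, 4)
ckpt = {"model": {"weight": torch.zeros(4, 4)}}  # weights nested under "model"
missing, unexpected = net.load_state_dict(ckpt["model"], strict=False)
print(missing)     # ['bias'] is tolerated because strict=False
print(unexpected)  # []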
     @staticmethod
-    def init_from_config(config: "GPTConfig", samples: Union[List[List], List[Dict]] = None):
+    def init_from_config(config: "GPTTrainerConfig", samples: Union[List[List], List[Dict]] = None):
         """Initiate model from config
         Args:
-            config (GPTConfig): Model config.
+            config (GPTTrainerConfig): Model config.
             samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                 Defaults to None.
         """

View File: TTS/tts/models/xtts.py

@@ -387,7 +387,7 @@ class Xtts(BaseTTS):
         audio = load_audio(audio_path)
         audio = audio[:, : 22050 * length]
         mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
-        cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
+        cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)

     @torch.inference_mode()

View File: XTTS GPT training recipe

@@ -3,7 +3,7 @@ from trainer import Trainer, TrainerArgs
 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTConfig
+from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTTrainerConfig

 config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig(
@@ -265,21 +265,21 @@ def main():
         debug_loading_failures=False,
         max_wav_length=255995,  # ~11.6 seconds
         max_text_length=200,
-        tokenizer_file="/raid/datasets/xtts_models/vocab.json",
         mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
         dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
-        gpt_checkpoint="/raid/datasets/xtts_models/gpt.pth",
+        tokenizer_file="/raid/datasets/xtts_models/vocab.json",  # vocab path of the model that you want to fine-tune
+        xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",  # checkpoint path of the model that you want to fine-tune
         gpt_num_audio_tokens=8194,
         gpt_start_audio_token=8192,
         gpt_stop_audio_token=8193,
     )
     audio_config = XttsAudioConfig(
-        sample_rate=22050,  # autoregressive SR
+        sample_rate=22050,  # GPT SR
         dvae_sample_rate=22050,
         diffusion_sample_rate=24000,
         output_sample_rate=24000
     )
-    config = GPTConfig(
+    config = GPTTrainerConfig(
         output_path=OUT_PATH,
         model_args=model_args,
         run_name=RUN_NAME,
@@ -313,6 +313,10 @@ def main():
         lr_scheduler="MultiStepLR",
         # it was adjusted accordly for the new step scheme
         lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
+        test_sentences=[
+            {"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
+            {"text": "This cake is great. It's so delicious and moist.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
+        ]
     )

     # init the model from config
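Each test_sentences entry must carry the three keys that test_run consumes: "text", "speaker_wav", and "language". A hypothetical pre-flight check one could add to a recipe like this (not part of the repo):

REQUIRED_KEYS = ("text", "speaker_wav", "language")

def check_test_sentences(sentences):
    for i, entry in enumerate(sentences):
        missing = [k for k in REQUIRED_KEYS if k not in entry]
        if missing:
            raise ValueError(f"test_sentences[{i}] is missing keys: {missing}")

check_test_sentences([
    {"text": "This cake is great.", "speaker_wav": "ref.wav", "language": "en"},
])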
@@ -341,7 +345,7 @@ def main():
 if __name__ == "__main__":
     RUN_NAME = "GPT_XTTS"
-    PROJECT_NAME = "XTTS"
+    PROJECT_NAME = "XTTS_trainer"
     OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_style_emb/"
     DASHBOARD_LOGGER = "clearml"
     LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_style_emb/"
@@ -352,12 +356,11 @@ if __name__ == "__main__":
     GRAD_ACUMM_STEPS = 28

     # debug
-    DASHBOARD_LOGGER = "tensorboard"
-    LOGGER_URI = None
-    RESTORE_PATH = None
-    BATCH_SIZE = 10
+    # DASHBOARD_LOGGER = "tensorboard"
+    # LOGGER_URI = None
+    # RESTORE_PATH = None
+    BATCH_SIZE = 2
     GRAD_ACUMM_STEPS = 1
-    NUM_LOADERS = 1