From 40a4e631ea9586f24a0f8758a626ffd40bfc2bc2 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 11 Oct 2023 19:04:18 -0300 Subject: [PATCH] Update mel spectrogram for the style encoder --- TTS/tts/layers/xtts/trainer/gpt_trainer.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index f73aeb05..6494f336 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -157,7 +157,17 @@ class GPTTrainer(BaseTTS): print(">> GPT weights randomly initialized! If you want you can specify a checkpoint in config.model_args.gpt_checkpoint") # Mel spectrogram extractor for conditioning - self.torch_mel_spectrogram = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.sample_rate) + self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram( + filter_length=4096, + hop_length=1024, + win_length=4096, + normalize=False, + sampling_rate=config.audio.sample_rate, + mel_fmin=0, + mel_fmax=8000, + n_mel_channels=80, + mel_norm_file=self.args.mel_norm_file + ) # Load DVAE self.dvae = DiscreteVAE( @@ -224,9 +234,9 @@ class GPTTrainer(BaseTTS): # transform waves from torch.Size([B, num_cond_samples, 1, T] to torch.Size([B * num_cond_samples, 1, T] because if is faster than iterate the tensor B, num_cond_samples, C, T = batch["conditioning"].size() conditioning_reshaped = batch["conditioning"].view(B*num_cond_samples, C, T) - paired_conditioning_mel = self.torch_mel_spectrogram(conditioning_reshaped) + paired_conditioning_mel = self.torch_mel_spectrogram_style_encoder(conditioning_reshaped) # transform torch.Size([B * num_cond_samples, n_mel, T_mel]) in torch.Size([B, num_cond_samples, n_mel, T_mel]) - n_mel = self.torch_mel_spectrogram.n_mel_channels # paired_conditioning_mel.size(1) + n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels # paired_conditioning_mel.size(1) T_mel = paired_conditioning_mel.size(2) paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel) # get the conditioning embeddings