Update mel spectrogram for the style encoder

This commit is contained in:
Edresson Casanova 2023-10-11 19:04:18 -03:00
parent a32961bcb4
commit 40a4e631ea
1 changed file with 13 additions and 3 deletions

View File

@ -157,7 +157,17 @@ class GPTTrainer(BaseTTS):
print(">> GPT weights randomly initialized! If you want you can specify a checkpoint in config.model_args.gpt_checkpoint")
# Mel spectrogram extractor for conditioning
self.torch_mel_spectrogram = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.sample_rate)
self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
filter_length=4096,
hop_length=1024,
win_length=4096,
normalize=False,
sampling_rate=config.audio.sample_rate,
mel_fmin=0,
mel_fmax=8000,
n_mel_channels=80,
mel_norm_file=self.args.mel_norm_file
)
# Load DVAE
self.dvae = DiscreteVAE(
@ -224,9 +234,9 @@ class GPTTrainer(BaseTTS):
# transform waves from torch.Size([B, num_cond_samples, 1, T]) to torch.Size([B * num_cond_samples, 1, T]) because it is faster than iterating over the tensor
B, num_cond_samples, C, T = batch["conditioning"].size()
conditioning_reshaped = batch["conditioning"].view(B*num_cond_samples, C, T)
paired_conditioning_mel = self.torch_mel_spectrogram(conditioning_reshaped)
paired_conditioning_mel = self.torch_mel_spectrogram_style_encoder(conditioning_reshaped)
# transform torch.Size([B * num_cond_samples, n_mel, T_mel]) into torch.Size([B, num_cond_samples, n_mel, T_mel])
n_mel = self.torch_mel_spectrogram.n_mel_channels # paired_conditioning_mel.size(1)
n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels # paired_conditioning_mel.size(1)
T_mel = paired_conditioning_mel.size(2)
paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel)
# get the conditioning embeddings