From ec8c8dc5a29bf51df78a595204b582cb8b4c070f Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 27 May 2022 16:41:13 -0300 Subject: [PATCH] Recreate the prior distribution of Capacitron VAE on the right device --- TTS/tts/layers/tacotron/capacitron_layers.py | 11 +++++++++-- .../test_vits_speaker_emb_with_prosody_encoder.py | 2 ++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/TTS/tts/layers/tacotron/capacitron_layers.py b/TTS/tts/layers/tacotron/capacitron_layers.py index 56fe44bc..5031bd28 100644 --- a/TTS/tts/layers/tacotron/capacitron_layers.py +++ b/TTS/tts/layers/tacotron/capacitron_layers.py @@ -38,6 +38,7 @@ class CapacitronVAE(nn.Module): # TODO: Test a multispeaker model! mlp_input_dimension += speaker_embedding_dim self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim) + self.prior_converted_to_device = False def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None): # Use reference @@ -59,8 +60,14 @@ class CapacitronVAE(nn.Module): # an MLP to produce the parameters for the approximate posterior distributions mu, sigma = self.post_encoder_mlp(enc_out) # convert to cpu because prior_distribution was created on cpu - mu = mu.cpu() - sigma = sigma.cpu() + # mu = mu.cpu() + # sigma = sigma.cpu() + # recreate the prior distribution on the right device + if not self.prior_converted_to_device: + self.prior_distribution = MVN( + torch.zeros(mu.size(1)).to(mu.device), torch.eye(mu.size(1)).to(mu.device) + ) + self.prior_converted_to_device = True # Sample from the posterior: z ~ q(z|x) self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma)) diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py index 3737a473..4fe3a8b1 100644 --- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py +++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py @@ -52,6 +52,8 @@ 
config.model_args.use_prosody_encoder_z_p_input = True config.model_args.prosody_encoder_type = "vae" config.model_args.detach_prosody_enc_input = True +config.mixed_precision = False + config.save_json(config_path) # train the model for one epoch