Recreate the prior distribution of Capacitron VAE on the right device

Edresson Casanova 2022-05-27 16:41:13 -03:00
parent a822f21b78
commit ec8c8dc5a2
2 changed files with 11 additions and 2 deletions
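Why this change: torch.distributions.MultivariateNormal keeps its loc and covariance tensors on whatever device they were created on, so a prior built on the CPU in __init__ cannot later be combined with posterior parameters living on the GPU. Below is a minimal sketch of the failure mode this commit addresses; the MVN alias matches the import used in the diff, while the sizes are illustrative assumptions:

    import torch
    from torch.distributions import kl_divergence
    from torch.distributions.multivariate_normal import MultivariateNormal as MVN

    dim, batch = 128, 4  # illustrative sizes

    # Prior created once on the CPU, as the module did before this commit
    prior = MVN(torch.zeros(dim), torch.eye(dim))

    if torch.cuda.is_available():
        # Posterior parameters produced by the encoder on the GPU
        mu = torch.randn(batch, dim, device="cuda")
        sigma = torch.rand(batch, dim, device="cuda") + 0.1
        posterior = MVN(mu, torch.diag_embed(sigma))
        try:
            kl_divergence(posterior, prior)  # mixes CUDA and CPU tensors
        except RuntimeError as err:
            print(f"device mismatch: {err}")

The previous workaround was to pull mu and sigma back to the CPU; this commit instead rebuilds the prior on the posterior's device.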


@@ -38,6 +38,7 @@ class CapacitronVAE(nn.Module):
             # TODO: Test a multispeaker model!
             mlp_input_dimension += speaker_embedding_dim
         self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim)
+        self.prior_converted_to_device = False

     def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None):
         # Use reference
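The new prior_converted_to_device flag implements a lazy, one-time migration: at construction time the module does not yet know which device it will run on, so it defers rebuilding the prior until the first forward pass, where the device of the incoming tensors is visible. A generic sketch of the same pattern; the class and attribute names here are illustrative, not from the repo:

    import torch
    from torch import nn
    from torch.distributions.multivariate_normal import MultivariateNormal as MVN

    class LazyPriorModule(nn.Module):
        """Rebuilds its prior on the device of the first input it sees."""

        def __init__(self, dim):
            super().__init__()
            self.dim = dim
            # A Distribution is not a Parameter or buffer, so model.to(device)
            # does not move it; the migration has to be tracked manually.
            self.prior = MVN(torch.zeros(dim), torch.eye(dim))
            self.prior_on_device = False

        def forward(self, x):
            # Lazily rebuild the prior once, on the device of the inputs
            if not self.prior_on_device:
                self.prior = MVN(
                    torch.zeros(self.dim, device=x.device),
                    torch.eye(self.dim, device=x.device),
                )
                self.prior_on_device = True
            return self.prior.log_prob(x)

One caveat of a boolean flag: if the model is moved to a different device after the first call, the stale flag keeps the old prior; comparing self.prior_distribution.loc.device against mu.device on every call would be a more robust variant.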
@@ -59,8 +60,14 @@ class CapacitronVAE(nn.Module):
         # an MLP to produce the parameters for the approximate posterior distributions
         mu, sigma = self.post_encoder_mlp(enc_out)
         # convert to cpu because prior_distribution was created on cpu
-        mu = mu.cpu()
-        sigma = sigma.cpu()
+        # mu = mu.cpu()
+        # sigma = sigma.cpu()
+        # recreate the prior distribution on the right device
+        if not self.prior_converted_to_device:
+            self.prior_distribution = MVN(
+                torch.zeros(mu.size(1)).to(mu.device), torch.eye(mu.size(1)).to(mu.device)
+            )
+            self.prior_converted_to_device = True

         # Sample from the posterior: z ~ q(z|x)
         self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma))
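With both distributions on the same device, the KL term of the Capacitron objective can be computed entirely on the GPU instead of round-tripping through the CPU. A self-contained sketch of the fixed behavior; the sizes are illustrative and the construction mirrors the diff:

    import torch
    from torch.distributions import kl_divergence
    from torch.distributions.multivariate_normal import MultivariateNormal as MVN

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dim, batch = 128, 4  # illustrative sizes

    # Prior recreated on the working device, as in this commit
    prior = MVN(torch.zeros(dim, device=device), torch.eye(dim, device=device))

    # Posterior built from encoder outputs on the same device
    mu = torch.randn(batch, dim, device=device)
    sigma = torch.rand(batch, dim, device=device) + 0.1
    posterior = MVN(mu, torch.diag_embed(sigma))

    kl = kl_divergence(posterior, prior)  # shape (batch,), no device error
    print(kl.mean())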


@@ -52,6 +52,8 @@ config.model_args.use_prosody_encoder_z_p_input = True
 config.model_args.prosody_encoder_type = "vae"
+config.model_args.detach_prosody_enc_input = True
+config.mixed_precision = False
 config.save_json(config_path)
 # train the model for one epoch
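The second file adjusts a prosody-encoder test recipe: the two new config flags are persisted by config.save_json(config_path) before the one-epoch training run that the trailing comment announces. A hypothetical sketch of how a saved config like this is consumed; TTS/bin/train_tts.py with --config_path is Coqui's standard training entry point, while the config_path value is an assumption and the real test also passes dataset and output arguments omitted here:

    import subprocess

    config_path = "tests/outputs/test_config.json"  # hypothetical location

    # Launch a one-epoch training run against the saved config
    subprocess.run(
        ["python", "TTS/bin/train_tts.py", "--config_path", config_path],
        check=True,
    )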