mirror of https://github.com/coqui-ai/TTS.git
Recreate the prior distribution of Capacitron VAE on the right device
This commit is contained in:
parent
a822f21b78
commit
ec8c8dc5a2
|
@ -38,6 +38,7 @@ class CapacitronVAE(nn.Module):
|
||||||
# TODO: Test a multispeaker model!
|
# TODO: Test a multispeaker model!
|
||||||
mlp_input_dimension += speaker_embedding_dim
|
mlp_input_dimension += speaker_embedding_dim
|
||||||
self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim)
|
self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim)
|
||||||
|
self.prior_converted_to_device = False
|
||||||
|
|
||||||
def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None):
|
def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None):
|
||||||
# Use reference
|
# Use reference
|
||||||
|
@ -59,8 +60,14 @@ class CapacitronVAE(nn.Module):
|
||||||
# an MLP to produce the parameters for the approximate posterior distributions
|
# an MLP to produce the parameters for the approximate posterior distributions
|
||||||
mu, sigma = self.post_encoder_mlp(enc_out)
|
mu, sigma = self.post_encoder_mlp(enc_out)
|
||||||
# convert to cpu because prior_distribution was created on cpu
|
# convert to cpu because prior_distribution was created on cpu
|
||||||
mu = mu.cpu()
|
# mu = mu.cpu()
|
||||||
sigma = sigma.cpu()
|
# sigma = sigma.cpu()
|
||||||
|
# recreate the prior distribution on the right device
|
||||||
|
if not self.prior_converted_to_device:
|
||||||
|
self.prior_distribution = MVN(
|
||||||
|
torch.zeros(mu.size(1)).to(mu.device), torch.eye(mu.size(1)).to(mu.device)
|
||||||
|
)
|
||||||
|
self.prior_converted_to_device = True
|
||||||
|
|
||||||
# Sample from the posterior: z ~ q(z|x)
|
# Sample from the posterior: z ~ q(z|x)
|
||||||
self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma))
|
self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma))
|
||||||
|
|
|
@ -52,6 +52,8 @@ config.model_args.use_prosody_encoder_z_p_input = True
|
||||||
config.model_args.prosody_encoder_type = "vae"
|
config.model_args.prosody_encoder_type = "vae"
|
||||||
config.model_args.detach_prosody_enc_input = True
|
config.model_args.detach_prosody_enc_input = True
|
||||||
|
|
||||||
|
config.mixed_precision = False
|
||||||
|
|
||||||
config.save_json(config_path)
|
config.save_json(config_path)
|
||||||
|
|
||||||
# train the model for one epoch
|
# train the model for one epoch
|
||||||
|
|
Loading…
Reference in New Issue