mirror of https://github.com/coqui-ai/TTS.git
Recreate the prior distribution of Capacitron VAE on the right device
This commit is contained in:
parent
a822f21b78
commit
ec8c8dc5a2
|
@ -38,6 +38,7 @@ class CapacitronVAE(nn.Module):
|
|||
# TODO: Test a multispeaker model!
|
||||
mlp_input_dimension += speaker_embedding_dim
|
||||
self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim)
|
||||
self.prior_converted_to_device = False
|
||||
|
||||
def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None):
|
||||
# Use reference
|
||||
|
@ -59,8 +60,14 @@ class CapacitronVAE(nn.Module):
|
|||
# an MLP to produce the parameters for the approximate posterior distributions
|
||||
mu, sigma = self.post_encoder_mlp(enc_out)
|
||||
# convert to cpu because prior_distribution was created on cpu
|
||||
mu = mu.cpu()
|
||||
sigma = sigma.cpu()
|
||||
# mu = mu.cpu()
|
||||
# sigma = sigma.cpu()
|
||||
# recreate the prior distribution on the right device
|
||||
if not self.prior_converted_to_device:
|
||||
self.prior_distribution = MVN(
|
||||
torch.zeros(mu.size(1)).to(mu.device), torch.eye(mu.size(1)).to(mu.device)
|
||||
)
|
||||
self.prior_converted_to_device = True
|
||||
|
||||
# Sample from the posterior: z ~ q(z|x)
|
||||
self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma))
|
||||
|
|
|
@ -52,6 +52,8 @@ config.model_args.use_prosody_encoder_z_p_input = True
|
|||
config.model_args.prosody_encoder_type = "vae"
|
||||
config.model_args.detach_prosody_enc_input = True
|
||||
|
||||
config.mixed_precision = False
|
||||
|
||||
config.save_json(config_path)
|
||||
|
||||
# train the model for one epoch
|
||||
|
|
Loading…
Reference in New Issue