compute normalized logp using torch primitives

Eren Gölge 2021-03-23 13:42:07 +01:00
parent 7a382a5c2b
commit 6b2e13bf62
2 changed files with 9 additions and 18 deletions
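
For context (notation mine, not part of the commit): both the removed four-term expansion in `AlignTTS.compute_log_probs` and the new broadcast computation evaluate a diagonal-Gaussian score for every (encoder step, decoder frame) pair. The identity the old code exploited:

```latex
% Diagonal-Gaussian log-likelihood, expanded so that the cross term
% becomes a matrix product (the old logp1..logp4 split):
\log\mathcal{N}(y;\mu,\sigma)
  = \sum_{d=1}^{D}\Big[-\tfrac{1}{2}\log 2\pi - \log\sigma_d
      - \frac{(y_d-\mu_d)^2}{2\sigma_d^2}\Big]
  = \underbrace{\sum_d\big(-\tfrac{1}{2}\log 2\pi - \log\sigma_d\big)}_{\texttt{logp1}}
    \underbrace{-\tfrac{1}{2}\sum_d\frac{y_d^2}{\sigma_d^2}}_{\texttt{logp2}}
    +\underbrace{\sum_d\frac{\mu_d y_d}{\sigma_d^2}}_{\texttt{logp3}}
    \underbrace{-\tfrac{1}{2}\sum_d\frac{\mu_d^2}{\sigma_d^2}}_{\texttt{logp4}}
```

The new code replaces this with a normalized variant: the squared-error and log-sigma terms are averaged over the feature dimension `D` instead of summed, the `2π` constant is dropped, and `-0.5 * log_sigma` stands in for `-log_sigma`; hence, presumably, "normalized logp" in the commit title.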

@@ -464,11 +464,6 @@ class MDNLoss(nn.Module):
         '''
         B, _, L = mu.size()
         T = melspec.size(2)
-        # x = melspec.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D]
-        # mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
-        # log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
-        # exponential = -0.5*torch.mean(mse_loss_custom(x, mu)/torch.pow(log_sigma.exp(), 2), dim=-1) # B, L, T
-        # log_prob_matrix = exponential -0.5 * log_sigma.mean(dim=-1)
         log_alpha = logp.new_ones(B, L, T)*(-1e4)
         log_alpha[:, 0, 0] = logp[:, 0, 0]
         for t in range(1, T):

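The hunk is truncated inside the loop. For orientation, here is a minimal sketch (my reconstruction, not the file's exact code) of the standard monotonic forward recursion such a loop computes over the `[B, L, T]` matrix `logp`: at decoder step `t`, text token `j` is reachable from `(j, t-1)` (stay) or `(j-1, t-1)` (advance), and `-1e4` serves as a finite stand-in for `log 0`.

```python
import torch
import torch.nn.functional as F

def mdn_forward(logp):
    """Forward pass over a [B, L, T] log-probability matrix (sketch)."""
    B, L, T = logp.shape
    log_alpha = logp.new_full((B, L, T), -1e4)
    log_alpha[:, 0, 0] = logp[:, 0, 0]
    for t in range(1, T):
        prev = log_alpha[:, :, t - 1]                      # stay on token j
        # shift by one token so row j holds alpha[j-1, t-1] ("advance")
        prev_shifted = F.pad(prev[:, :-1], (1, 0), value=-1e4)
        log_alpha[:, :, t] = torch.logsumexp(
            torch.stack([prev, prev_shifted], dim=-1), dim=-1) + logp[:, :, t]
    return log_alpha
```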
@@ -112,19 +112,15 @@ class AlignTTS(nn.Module):
     @staticmethod
     def compute_log_probs(mu, log_sigma, y):
         '''Faster way to compute log probability'''
-        scale = torch.exp(-2 * log_sigma)
-        # [B, T_en, 1]
-        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - log_sigma,
-                          [1]).unsqueeze(-1)
-        # [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
-        logp2 = torch.matmul(scale.transpose(1, 2), -0.5 * (y**2))
-        # [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
-        logp3 = torch.matmul((mu * scale).transpose(1, 2), y)
-        # [B, T_en, 1]
-        logp4 = torch.sum(-0.5 * (mu**2) * scale, [1]).unsqueeze(-1)
-        # [B, T_en, T_dec]
-        logp = logp1 + logp2 + logp3 + logp4
+        # pylint: disable=protected-access, c-extension-no-member
+        y = y.transpose(1, 2).unsqueeze(1)  # [B, 1, T1, D]
+        mu = mu.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
+        log_sigma = log_sigma.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
+        expanded_y, expanded_mu = torch.broadcast_tensors(y, mu)
+        exponential = -0.5 * torch.mean(torch._C._nn.mse_loss(
+            expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2),
+            dim=-1)  # B, L, T
+        logp = exponential - 0.5 * log_sigma.mean(dim=-1)
         return logp

     def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask):
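
For readers who want to avoid the private call: `torch._C._nn.mse_loss(a, b, 0)` is the C++ binding that `F.mse_loss(a, b, reduction='none')` dispatches to (reduction enum 0 is `'none'`), so the new method can be written with public API only. A self-contained sketch, with hypothetical sizes in the shape check:

```python
import torch
import torch.nn.functional as F

def compute_log_probs_public(mu, log_sigma, y):
    """Public-API version of the new computation (sketch).
    Shapes as in the diff: mu, log_sigma [B, D, T_en]; y [B, D, T_dec]."""
    y = y.transpose(1, 2).unsqueeze(1)                  # [B, 1, T_dec, D]
    mu = mu.transpose(1, 2).unsqueeze(2)                # [B, T_en, 1, D]
    log_sigma = log_sigma.transpose(1, 2).unsqueeze(2)  # [B, T_en, 1, D]
    # explicit broadcast keeps F.mse_loss from warning about size mismatch
    expanded_y, expanded_mu = torch.broadcast_tensors(y, mu)
    sq_err = F.mse_loss(expanded_y, expanded_mu, reduction='none')
    exponential = -0.5 * torch.mean(sq_err / log_sigma.exp().pow(2), dim=-1)
    return exponential - 0.5 * log_sigma.mean(dim=-1)   # [B, T_en, T_dec]

B, D, T_en, T_dec = 2, 80, 13, 57  # hypothetical sizes
mu, log_sigma = torch.randn(B, D, T_en), torch.randn(B, D, T_en)
y = torch.randn(B, D, T_dec)
assert compute_log_probs_public(mu, log_sigma, y).shape == (B, T_en, T_dec)
```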