mirror of https://github.com/coqui-ai/TTS.git
compute normalized logp using torch primitives
This commit is contained in:
parent
7a382a5c2b
commit
6b2e13bf62
|
@ -464,11 +464,6 @@ class MDNLoss(nn.Module):
|
|||
'''
|
||||
B, _, L = mu.size()
|
||||
T = melspec.size(2)
|
||||
# x = melspec.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D]
|
||||
# mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
|
||||
# log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
|
||||
# exponential = -0.5*torch.mean(mse_loss_custom(x, mu)/torch.pow(log_sigma.exp(), 2), dim=-1) # B, L, T
|
||||
# log_prob_matrix = exponential -0.5 * log_sigma.mean(dim=-1)
|
||||
log_alpha = logp.new_ones(B, L, T)*(-1e4)
|
||||
log_alpha[:, 0, 0] = logp[:, 0, 0]
|
||||
for t in range(1, T):
|
||||
|
|
|
@ -112,19 +112,15 @@ class AlignTTS(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def compute_log_probs(mu, log_sigma, y):
|
||||
'''Faster way to compute log probability'''
|
||||
scale = torch.exp(-2 * log_sigma)
|
||||
# [B, T_en, 1]
|
||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - log_sigma,
|
||||
[1]).unsqueeze(-1)
|
||||
# [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
|
||||
logp2 = torch.matmul(scale.transpose(1, 2), -0.5 * (y**2))
|
||||
# [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
|
||||
logp3 = torch.matmul((mu * scale).transpose(1, 2), y)
|
||||
# [B, T_en, 1]
|
||||
logp4 = torch.sum(-0.5 * (mu**2) * scale, [1]).unsqueeze(-1)
|
||||
# [B, T_en, T_dec]
|
||||
logp = logp1 + logp2 + logp3 + logp4
|
||||
# pylint: disable=protected-access, c-extension-no-member
|
||||
y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D]
|
||||
mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
|
||||
log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
|
||||
expanded_y, expanded_mu = torch.broadcast_tensors(y, mu)
|
||||
exponential = -0.5 * torch.mean(torch._C._nn.mse_loss(
|
||||
expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2),
|
||||
dim=-1) # B, L, T
|
||||
logp = exponential - 0.5 * log_sigma.mean(dim=-1)
|
||||
return logp
|
||||
|
||||
def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask):
|
||||
|
|
Loading…
Reference in New Issue