diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index ccf34165..c6fdbebb 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -464,11 +464,6 @@ class MDNLoss(nn.Module):
         '''
         B, _, L = mu.size()
         T = melspec.size(2)
-        # x = melspec.transpose(1, 2).unsqueeze(1)  # [B, 1, T1, D]
-        # mu = mu.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
-        # log_sigma = log_sigma.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
-        # exponential = -0.5*torch.mean(mse_loss_custom(x, mu)/torch.pow(log_sigma.exp(), 2), dim=-1)  # B, L, T
-        # log_prob_matrix = exponential -0.5 * log_sigma.mean(dim=-1)
         log_alpha = logp.new_ones(B, L, T)*(-1e4)
         log_alpha[:, 0, 0] = logp[:, 0, 0]
         for t in range(1, T):
diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py
index a36a8ab9..08481e57 100644
--- a/TTS/tts/models/align_tts.py
+++ b/TTS/tts/models/align_tts.py
@@ -112,19 +112,23 @@ class AlignTTS(nn.Module):
 
     @staticmethod
     def compute_log_probs(mu, log_sigma, y):
-        '''Faster way to compute log probability'''
-        scale = torch.exp(-2 * log_sigma)
-        # [B, T_en, 1]
-        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - log_sigma,
-                          [1]).unsqueeze(-1)
-        # [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
-        logp2 = torch.matmul(scale.transpose(1, 2), -0.5 * (y**2))
-        # [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
-        logp3 = torch.matmul((mu * scale).transpose(1, 2), y)
-        # [B, T_en, 1]
-        logp4 = torch.sum(-0.5 * (mu**2) * scale, [1]).unsqueeze(-1)
-        # [B, T_en, T_dec]
-        logp = logp1 + logp2 + logp3 + logp4
+        '''Score each frame of `y` under each Gaussian in `mu`/`log_sigma`.
+
+        Inputs are transposed from [B, D, T] so the feature axis is last;
+        that axis is averaged out, giving `logp` of shape [B, T2, T1].
+        '''
+        y = y.transpose(1, 2).unsqueeze(1)  # [B, 1, T1, D]
+        mu = mu.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
+        log_sigma = log_sigma.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
+        # Plain broadcasting yields the same element-wise squared error as
+        # torch._C._nn.mse_loss(..., 0) (reduction "none") without relying on
+        # a private binding or a pylint suppression.
+        squared_error = (y - mu) ** 2  # [B, T2, T1, D]
+        exponential = -0.5 * torch.mean(
+            squared_error / torch.pow(log_sigma.exp(), 2), dim=-1)  # B, L, T
+        # NOTE(review): the implementation removed above weighted log_sigma
+        # by 1; here the weight is 0.5 -- confirm this is intended.
+        logp = exponential - 0.5 * log_sigma.mean(dim=-1)
         return logp
 
     def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask):