linter fixes

This commit is contained in:
Eren Gölge 2021-03-23 16:43:49 +01:00
parent 6b2e13bf62
commit a3a840fd78
4 changed files with 15 additions and 20 deletions

View File

@ -46,7 +46,7 @@ jobs:
python3 setup.py egg_info
- name: Lint check
run: |
cardboardlinter -n auto
cardboardlinter
- name: Unit tests
run: nosetests tests --nocapture --processes=0 --process-timeout=20 --process-restartworker
- name: Test scripts

View File

@ -140,7 +140,7 @@ if __name__ == '__main__':
# forward pass model
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
decoder_output, dur_output, dur_mas_output, alignments, mu, log_sigma, logp = model.forward(
decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward(
text_input,
text_lengths,
mel_targets,
@ -149,9 +149,7 @@ if __name__ == '__main__':
phase=training_phase)
# compute loss
loss_dict = criterion(mu,
log_sigma,
logp,
loss_dict = criterion(logp,
decoder_output,
mel_targets,
mel_lengths,

View File

@ -455,22 +455,22 @@ class MDNLoss(nn.Module):
"""Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.
"""
def forward(self, mu, log_sigma, logp, melspec, text_lengths, mel_lengths): # pylint: disable=no-self-use
def forward(self, logp, text_lengths, mel_lengths): # pylint: disable=no-self-use
'''
Shapes:
mu: [B, D, T]
log_sigma: [B, D, T]
mel_spec: [B, D, T]
'''
B, _, L = mu.size()
T = melspec.size(2)
log_alpha = logp.new_ones(B, L, T)*(-1e4)
B, T_seq, T_mel = logp.shape
log_alpha = logp.new_ones(B, T_seq, T_mel)*(-1e4)
log_alpha[:, 0, 0] = logp[:, 0, 0]
for t in range(1, T):
prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t], (0, 0, 1, -1), value=-1e4)], dim=-1)
for t in range(1, T_mel):
prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t],
(0, 0, 1, -1), value=-1e4)], dim=-1)
log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t]
alpha_last = log_alpha[torch.arange(B), text_lengths-1, mel_lengths-1]
mdn_loss = -alpha_last.mean() / L
mdn_loss = -alpha_last.mean() / T_seq
return mdn_loss#, log_prob_matrix
@ -499,25 +499,24 @@ class AlignTTSLoss(nn.Module):
self.spec_loss_alpha = c.spec_loss_alpha
self.mdn_alpha = c.mdn_alpha
def forward(self, mu, log_sigma, logp, decoder_output, decoder_target,
decoder_output_lens, dur_output, dur_target, input_lens, step,
phase):
def forward(self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target,
input_lens, step, phase):
ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(
step)
spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
if phase == 0:
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
elif phase == 1:
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
elif phase == 2:
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
spec_loss = self.spec_lossX(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
elif phase == 3:
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
else:
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)

View File

@ -1,5 +1,3 @@
import math
import torch
import torch.nn as nn
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding