mirror of https://github.com/coqui-ai/TTS.git

linter fixes

parent 6b2e13bf62
commit a3a840fd78
@@ -46,7 +46,7 @@ jobs:
         python3 setup.py egg_info
     - name: Lint check
       run: |
-        cardboardlinter -n auto
+        cardboardlinter
     - name: Unit tests
       run: nosetests tests --nocapture --processes=0 --process-timeout=20 --process-restartworker
     - name: Test scripts
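Note on this hunk: the lint step now runs cardboardlinter with its defaults rather than the -n auto parallelism flag. For reproducing the check locally, cardboardlint's usual invocation is along the lines of "cardboardlinter --refspec master -n auto" from the repository root; treat the --refspec flag as an assumption from cardboardlint's documentation, not something this commit pins down.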
@@ -140,7 +140,7 @@ if __name__ == '__main__':

         # forward pass model
         with torch.cuda.amp.autocast(enabled=c.mixed_precision):
-            decoder_output, dur_output, dur_mas_output, alignments, mu, log_sigma, logp = model.forward(
+            decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward(
                 text_input,
                 text_lengths,
                 mel_targets,
@@ -149,9 +149,7 @@ if __name__ == '__main__':
                 phase=training_phase)

         # compute loss
-        loss_dict = criterion(mu,
-                              log_sigma,
-                              logp,
+        loss_dict = criterion(logp,
                               decoder_output,
                               mel_targets,
                               mel_lengths,
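Both training-script hunks are the same pylint cleanup: once criterion stops taking mu and log_sigma, those two return values of model.forward are never read, so they are unpacked into underscore placeholders. A minimal sketch of the pattern, with a made-up function for illustration:

def forward_pass():
    # stand-in for model.forward; returns more values than the caller needs
    return "decoder_out", "mu", "log_sigma", "logp"

# pylint's unused-variable check flags named results that are never read;
# underscore placeholders keep the unpacking shape and silence the warning.
decoder_out, _, _, logp = forward_pass()
print(decoder_out, logp)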
@@ -455,22 +455,22 @@ class MDNLoss(nn.Module):
     """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.
     """

-    def forward(self, mu, log_sigma, logp, melspec, text_lengths, mel_lengths):  # pylint: disable=no-self-use
+    def forward(self, logp, text_lengths, mel_lengths):  # pylint: disable=no-self-use
         '''
         Shapes:
             mu: [B, D, T]
             log_sigma: [B, D, T]
             mel_spec: [B, D, T]
         '''
-        B, _, L = mu.size()
-        T = melspec.size(2)
-        log_alpha = logp.new_ones(B, L, T)*(-1e4)
+        B, T_seq, T_mel = logp.shape
+        log_alpha = logp.new_ones(B, T_seq, T_mel)*(-1e4)
         log_alpha[:, 0, 0] = logp[:, 0, 0]
-        for t in range(1, T):
-            prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t], (0, 0, 1, -1), value=-1e4)], dim=-1)
+        for t in range(1, T_mel):
+            prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t],
+                                                                          (0, 0, 1, -1), value=-1e4)], dim=-1)
             log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t]
         alpha_last = log_alpha[torch.arange(B), text_lengths-1, mel_lengths-1]
-        mdn_loss = -alpha_last.mean() / L
+        mdn_loss = -alpha_last.mean() / T_seq
         return mdn_loss#, log_prob_matrix

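After this refactor, MDNLoss.forward infers every shape it needs from logp alone: a [B, T_seq, T_mel] matrix of alignment log-probabilities, a forward-algorithm recursion over log_alpha, and the terminal value normalized by the text length. (The Shapes docstring still describes the removed mu/log_sigma arguments; it is left untouched by the commit.) The self-contained sketch below reproduces the recursion outside the class on dummy tensors; the loop body is copied from the diff, while the dummy shapes and the log_softmax normalization are assumptions made for the demo.

import torch
import torch.nn.functional as functional

def mdn_loss(logp, text_lengths, mel_lengths):
    # Forward-algorithm recursion over alignment log-probabilities,
    # mirroring the refactored MDNLoss.forward above.
    B, T_seq, T_mel = logp.shape
    log_alpha = logp.new_ones(B, T_seq, T_mel) * (-1e4)
    log_alpha[:, 0, 0] = logp[:, 0, 0]
    for t in range(1, T_mel):
        # each text position is reachable from itself or its predecessor;
        # the shifted copy is padded with -1e4 (effectively log 0)
        prev_step = torch.cat([log_alpha[:, :, t-1:t],
                               functional.pad(log_alpha[:, :, t-1:t],
                                              (0, 0, 1, -1), value=-1e4)], dim=-1)
        # the + 1e-4 term is kept verbatim from the source
        log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t]
    alpha_last = log_alpha[torch.arange(B), text_lengths - 1, mel_lengths - 1]
    return -alpha_last.mean() / T_seq

# dummy batch: 2 utterances, 11 text tokens, 37 mel frames
logp = torch.randn(2, 11, 37, requires_grad=True).log_softmax(dim=1)
loss = mdn_loss(logp, torch.tensor([11, 9]), torch.tensor([37, 30]))
loss.backward()  # differentiable scalar loss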
@@ -499,25 +499,24 @@ class AlignTTSLoss(nn.Module):
         self.spec_loss_alpha = c.spec_loss_alpha
         self.mdn_alpha = c.mdn_alpha

-    def forward(self, mu, log_sigma, logp, decoder_output, decoder_target,
-                decoder_output_lens, dur_output, dur_target, input_lens, step,
-                phase):
+    def forward(self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target,
+                input_lens, step, phase):
         ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(
             step)
         spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
         if phase == 0:
-            mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
+            mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
         elif phase == 1:
             spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
             ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
         elif phase == 2:
-            mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
+            mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
             spec_loss = self.spec_lossX(decoder_output, decoder_target, decoder_output_lens)
             ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
         elif phase == 3:
             dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
         else:
-            mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
+            mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
             spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
             ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
             dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
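With every branch now calling self.mdn_loss with the same three arguments, the only thing that varies across training phases is which loss terms are active. As a hedged summary of the if/elif chain (the phase labels in the comments are my reading of the branches, not names from the source):

def active_losses(phase):
    # Mirrors the dispatch in AlignTTSLoss.forward after this commit.
    if phase == 0:
        return ["mdn"]                      # alignment warm-up only
    if phase == 1:
        return ["spec", "ssim"]             # decoder only
    if phase == 2:
        return ["mdn", "spec", "ssim"]      # joint alignment + decoder
    if phase == 3:
        return ["dur"]                      # duration predictor only
    return ["mdn", "spec", "ssim", "dur"]   # all terms in the final phase

print(active_losses(2))  # ['mdn', 'spec', 'ssim']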
@@ -1,5 +1,3 @@
-import math
-
 import torch
 import torch.nn as nn
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding