mirror of https://github.com/coqui-ai/TTS.git
linter fixes
This commit is contained in:
parent
6b2e13bf62
commit
a3a840fd78
|
@ -46,7 +46,7 @@ jobs:
|
||||||
python3 setup.py egg_info
|
python3 setup.py egg_info
|
||||||
- name: Lint check
|
- name: Lint check
|
||||||
run: |
|
run: |
|
||||||
cardboardlinter -n auto
|
cardboardlinter
|
||||||
- name: Unit tests
|
- name: Unit tests
|
||||||
run: nosetests tests --nocapture --processes=0 --process-timeout=20 --process-restartworker
|
run: nosetests tests --nocapture --processes=0 --process-timeout=20 --process-restartworker
|
||||||
- name: Test scripts
|
- name: Test scripts
|
||||||
|
|
|
@ -140,7 +140,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
# forward pass model
|
# forward pass model
|
||||||
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
|
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
|
||||||
decoder_output, dur_output, dur_mas_output, alignments, mu, log_sigma, logp = model.forward(
|
decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward(
|
||||||
text_input,
|
text_input,
|
||||||
text_lengths,
|
text_lengths,
|
||||||
mel_targets,
|
mel_targets,
|
||||||
|
@ -149,9 +149,7 @@ if __name__ == '__main__':
|
||||||
phase=training_phase)
|
phase=training_phase)
|
||||||
|
|
||||||
# compute loss
|
# compute loss
|
||||||
loss_dict = criterion(mu,
|
loss_dict = criterion(logp,
|
||||||
log_sigma,
|
|
||||||
logp,
|
|
||||||
decoder_output,
|
decoder_output,
|
||||||
mel_targets,
|
mel_targets,
|
||||||
mel_lengths,
|
mel_lengths,
|
||||||
|
|
|
@ -455,22 +455,22 @@ class MDNLoss(nn.Module):
|
||||||
"""Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.
|
"""Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def forward(self, mu, log_sigma, logp, melspec, text_lengths, mel_lengths): # pylint: disable=no-self-use
|
def forward(self, logp, text_lengths, mel_lengths): # pylint: disable=no-self-use
|
||||||
'''
|
'''
|
||||||
Shapes:
|
Shapes:
|
||||||
mu: [B, D, T]
|
mu: [B, D, T]
|
||||||
log_sigma: [B, D, T]
|
log_sigma: [B, D, T]
|
||||||
mel_spec: [B, D, T]
|
mel_spec: [B, D, T]
|
||||||
'''
|
'''
|
||||||
B, _, L = mu.size()
|
B, T_seq, T_mel = logp.shape
|
||||||
T = melspec.size(2)
|
log_alpha = logp.new_ones(B, T_seq, T_mel)*(-1e4)
|
||||||
log_alpha = logp.new_ones(B, L, T)*(-1e4)
|
|
||||||
log_alpha[:, 0, 0] = logp[:, 0, 0]
|
log_alpha[:, 0, 0] = logp[:, 0, 0]
|
||||||
for t in range(1, T):
|
for t in range(1, T_mel):
|
||||||
prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t], (0, 0, 1, -1), value=-1e4)], dim=-1)
|
prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t],
|
||||||
|
(0, 0, 1, -1), value=-1e4)], dim=-1)
|
||||||
log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t]
|
log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t]
|
||||||
alpha_last = log_alpha[torch.arange(B), text_lengths-1, mel_lengths-1]
|
alpha_last = log_alpha[torch.arange(B), text_lengths-1, mel_lengths-1]
|
||||||
mdn_loss = -alpha_last.mean() / L
|
mdn_loss = -alpha_last.mean() / T_seq
|
||||||
return mdn_loss#, log_prob_matrix
|
return mdn_loss#, log_prob_matrix
|
||||||
|
|
||||||
|
|
||||||
|
@ -499,25 +499,24 @@ class AlignTTSLoss(nn.Module):
|
||||||
self.spec_loss_alpha = c.spec_loss_alpha
|
self.spec_loss_alpha = c.spec_loss_alpha
|
||||||
self.mdn_alpha = c.mdn_alpha
|
self.mdn_alpha = c.mdn_alpha
|
||||||
|
|
||||||
def forward(self, mu, log_sigma, logp, decoder_output, decoder_target,
|
def forward(self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target,
|
||||||
decoder_output_lens, dur_output, dur_target, input_lens, step,
|
input_lens, step, phase):
|
||||||
phase):
|
|
||||||
ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(
|
ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(
|
||||||
step)
|
step)
|
||||||
spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
|
spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
|
||||||
if phase == 0:
|
if phase == 0:
|
||||||
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
|
mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
|
||||||
elif phase == 1:
|
elif phase == 1:
|
||||||
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
|
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
|
||||||
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
||||||
elif phase == 2:
|
elif phase == 2:
|
||||||
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
|
mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
|
||||||
spec_loss = self.spec_lossX(decoder_output, decoder_target, decoder_output_lens)
|
spec_loss = self.spec_lossX(decoder_output, decoder_target, decoder_output_lens)
|
||||||
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
||||||
elif phase == 3:
|
elif phase == 3:
|
||||||
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
|
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
|
||||||
else:
|
else:
|
||||||
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
|
mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
|
||||||
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
|
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
|
||||||
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
||||||
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
|
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
||||||
|
|
Loading…
Reference in New Issue