mirror of https://github.com/coqui-ai/TTS.git
linter fixes
This commit is contained in:
parent e97bb45aba
commit 669a2e1d73
@@ -126,7 +126,7 @@ class BCELossMasked(nn.Module):
         loss = loss / mask.sum()
         return loss


 class GuidedAttentionLoss(torch.nn.Module):
     def __init__(self, sigma=0.4):
         super(GuidedAttentionLoss, self).__init__()
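For context, GuidedAttentionLoss implements the soft diagonal attention prior of Tachibana et al. (2017): attention mass far from the input/output diagonal is penalized, which speeds up alignment learning early in training. A minimal single-utterance sketch of the mask (the class above computes it batched and masked; names here are illustrative):

    import torch

    def guided_attention_mask(ilen, olen, sigma=0.4):
        # soft diagonal prior: W[t, n] = 1 - exp(-((n/N - t/T)^2) / (2 * sigma^2))
        t = torch.arange(olen).float().unsqueeze(1) / olen  # decoder step, normalized
        n = torch.arange(ilen).float().unsqueeze(0) / ilen  # encoder step, normalized
        return 1.0 - torch.exp(-((n - t) ** 2) / (2 * sigma ** 2))

    # penalize attention mass that lands far from the diagonal:
    # att is a (T_dec, T_enc) alignment matrix from the decoder
    # ga_loss = (att * guided_attention_mask(T_enc, T_dec)).mean()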
@@ -179,6 +179,7 @@ class TacotronLoss(torch.nn.Module):
         if c.ga_alpha > 0:
             self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma)
         # stopnet loss
+        # pylint: disable=not-callable
         self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None

     def forward(self, postnet_output, decoder_output, mel_input, linear_input,
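The added directive silences pylint's false positive that flags torch.tensor as not callable. Functionally, BCELossMasked applies binary cross-entropy to the stop-token predictions, where pos_weight upweights the rare positive "stop" frames and a length mask excludes padding, as in the loss / mask.sum() normalization in the first hunk. A sketch of that pattern, assuming (B, T) logits and per-utterance frame counts in lengths (not the repo's exact implementation):

    import torch
    import torch.nn.functional as F

    def masked_stop_loss(logits, targets, lengths, pos_weight=10.0):
        # True for real frames, False for padding
        mask = torch.arange(logits.size(1))[None, :] < lengths[:, None]
        # pos_weight > 1 upweights the rare positive (stop) frames
        loss = F.binary_cross_entropy_with_logits(
            logits, targets, pos_weight=torch.tensor(pos_weight), reduction='none')
        # average only over real frames, mirroring loss / mask.sum() above
        return (loss * mask).sum() / mask.sum()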
@@ -213,7 +214,7 @@ class TacotronLoss(torch.nn.Module):
         if not self.config.separate_stopnet and self.config.stopnet:
             loss += stop_loss
         return_dict['stopnet_loss'] = stop_loss

         # backward decoder loss (if enabled)
         if self.config.bidirectional_decoder:
             if self.config.loss_masking:
@@ -224,13 +225,13 @@ class TacotronLoss(torch.nn.Module):
             loss += decoder_b_loss + decoder_c_loss
             return_dict['decoder_b_loss'] = decoder_b_loss
             return_dict['decoder_c_loss'] = decoder_c_loss

         # guided attention loss (if enabled)
         if self.config.ga_alpha > 0:
             ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)
             loss += ga_loss * self.ga_alpha
             return_dict['ga_loss'] = ga_loss * self.ga_alpha

         return_dict['loss'] = loss
         return return_dict
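The forward pass follows one pattern throughout: every enabled criterion contributes a (possibly weighted) term to the scalar training loss, and each term is also logged under its own key in return_dict for monitoring. Reduced to its essentials (names hypothetical, not the repo's API):

    def compose_loss(terms, weights):
        # terms: {'decoder_loss': tensor, 'ga_loss': tensor, ...}
        # weights: optional per-term scales, e.g. {'ga_loss': ga_alpha}
        return_dict = {}
        total = 0.0
        for name, value in terms.items():
            weighted = value * weights.get(name, 1.0)
            return_dict[name] = weighted   # logged individually
            total = total + weighted       # accumulated into the training loss
        return_dict['loss'] = total
        return return_dict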
train.py (24 changed lines)
@@ -214,7 +214,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
           "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} "
           "LoaderTime:{:.2f} LR:{:.6f}".format(
               num_iter, batch_n_iter, global_step, loss_dict['postnet_loss'].item(),
               loss_dict['decoder_loss'].item(), loss_dict['stopnet_loss'].item(),
               loss_dict['ga_loss'].item(), grad_norm, grad_norm_st, avg_text_length,
               avg_spec_length, step_time, loader_time, current_lr),
           flush=True)
@@ -232,7 +232,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
         'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
-        'avg_stop_loss': loss_dict['stopnet_loss'].item() if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()),
+        'avg_stop_loss': loss_dict['stopnet_loss'] if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()),
         'avg_step_time': step_time,
         'avg_loader_time': loader_time
     }
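This entry has to handle stopnet_loss arriving either as a plain Python float (stopnet disabled) or as a 0-dim tensor; the fix stops the float branch from calling .item(), which only tensors have. A small helper would make the branching easier to read (hypothetical, not part of the commit):

    def to_float(value):
        # plain floats pass through; 0-dim tensors are unwrapped via .item()
        return value if isinstance(value, float) else float(value.item())

    # usage: 'avg_stop_loss': to_float(loss_dict['stopnet_loss'])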
@@ -399,16 +399,16 @@ def evaluate(model, criterion, ap, global_step, epoch):
           " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
           "StopLoss: {:.5f} - {:.5f} GALoss: {:.5f} - {:.5f} AlignScore: {:.4f} - {:.4f}"
           .format(loss_dict['loss'].item(),
                   loss_dict['postnet_loss'].item(),
                   keep_avg['avg_postnet_loss'],
                   loss_dict['decoder_loss'].item(),
                   keep_avg['avg_decoder_loss'],
                   loss_dict['stopnet_loss'].item(),
                   keep_avg['avg_stop_loss'],
                   loss_dict['ga_loss'].item(),
                   keep_avg['avg_ga_loss'],
                   align_score, keep_avg['avg_align_score']),
           flush=True)

     if args.rank == 0:
         # Diagnostic visualizations
@@ -50,7 +50,7 @@ class AudioProcessor(object):
         self.clip_norm = clip_norm
         self.do_trim_silence = do_trim_silence
         self.trim_db = trim_db
-        self.sound_norm = sound_norm
+        self.do_sound_norm = sound_norm
         # setup stft parameters
         if hop_length is None:
             self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
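The rename to self.do_sound_norm is more than style: the old attribute shadowed the sound_norm method defined on the same class, so any call through the instance would hit the stored flag instead of the function. It also brings the flag in line with the existing do_trim_silence naming. A minimal reproduction of the hazard:

    class Demo:
        def __init__(self, sound_norm):
            self.sound_norm = sound_norm   # instance attribute wins the lookup...

        def sound_norm(self, x):           # ...so this method becomes unreachable
            return x / abs(x).max() * 0.9

    d = Demo(sound_norm=True)
    # d.sound_norm(x) raises TypeError: 'bool' object is not callable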
@@ -232,17 +232,26 @@ class AudioProcessor(object):
         return librosa.effects.trim(
             wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[0]

-    def sound_norm(self, x):
+    @staticmethod
+    def sound_norm(x):
         return x / abs(x).max() * 0.9

     ### save and load ###
     def load_wav(self, filename, sr=None):
         if sr is None:
             x, sr = sf.read(filename)
         else:
             x, sr = librosa.load(filename, sr=sr)
+        if self.do_trim_silence:
+            try:
+                x = self.trim_silence(x)
+            except ValueError:
+                print(f' [!] File cannot be trimmed for silence - {filename}')
+        assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
+        if self.do_sound_norm:
+            x = self.sound_norm(x)
         return x

     def save_wav(self, wav, path):
         wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
         scipy.io.wavfile.write(path, self.sample_rate, wav_norm.astype(np.int16))
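With this change load_wav becomes the single entry point for reading audio: read at the native or requested rate, optionally trim silence, assert the sample rate, and optionally peak-normalize. A standalone sketch of the same flow (simplified; the class method above is the authoritative version):

    import librosa
    import soundfile as sf

    def load_wav(filename, expected_sr, trim_db=60, do_sound_norm=True):
        x, sr = sf.read(filename)                 # read at the file's native rate
        assert sr == expected_sr, "%s vs %s" % (sr, expected_sr)
        try:
            x, _ = librosa.effects.trim(x, top_db=trim_db)  # drop leading/trailing silence
        except ValueError:
            print(f' [!] File cannot be trimmed for silence - {filename}')
        if do_sound_norm:
            x = x / abs(x).max() * 0.9            # peak-normalize to 0.9
        return x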
@@ -263,20 +272,6 @@ class AudioProcessor(object):
         x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
         return x

-    def load_wav(self, filename, sr=None):
-        if sr is None:
-            x, sr = sf.read(filename)
-        else:
-            x, sr = librosa.load(filename, sr=sr)
-        if self.do_trim_silence:
-            try:
-                x = self.trim_silence(x)
-            except ValueError:
-                print(f' [!] File cannot be trimmed for silence - {filename}')
-        assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
-        if self.sound_norm:
-            x = x / abs(x).max() * 0.9
-        return x

     @staticmethod
     def encode_16bits(x):
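The first context line of that hunk is the µ-law expansion (decoding) step; seen next to its inverse, the companding pair is easier to recognize. A sketch with the conventional mu = 255 for 8-bit µ-law (only the decode line appears in the diff):

    import numpy as np

    def mulaw_encode(wav, mu=255):
        # compress: finer resolution for quiet samples, coarser for loud ones
        return np.sign(wav) * np.log1p(mu * np.abs(wav)) / np.log1p(mu)

    def mulaw_decode(wav, mu=255):
        # expand: exact inverse of the compression (the line in the hunk above)
        return np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)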