mirror of https://github.com/coqui-ai/TTS.git
masked loss
This commit is contained in:
parent
32d9c734b2
commit
2617518d91
|
@ -47,5 +47,5 @@ def L1LossMasked(input, target, length):
|
|||
# mask: (batch, max_len)
|
||||
mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)
|
||||
losses = losses * mask.float()
|
||||
loss = losses.sum() / length.float().sum()
|
||||
return loss / input.shape[0]
|
||||
loss = losses.sum() / (length.float().sum() * target.shape[2])
|
||||
return loss
|
2
train.py
2
train.py
|
@ -240,7 +240,7 @@ def evaluate(model, criterion, data_loader, current_step):
|
|||
mel_output, linear_output, alignments = model.forward(text_input_var, mel_spec_var)
|
||||
|
||||
# loss computation
|
||||
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
|
||||
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
|
||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||
linear_spec_var[: ,: ,:n_priority_freq],
|
||||
|
|
Loading…
Reference in New Issue