mirror of https://github.com/coqui-ai/TTS.git
bug fix loss
This commit is contained in:
parent
52b4bc6bed
commit
e257bd7278
|
@ -44,10 +44,11 @@ class L1LossMasked(nn.Module):
|
||||||
# target_flat: (batch * max_len, dim)
|
# target_flat: (batch * max_len, dim)
|
||||||
target_flat = target.view(-1, target.shape[-1])
|
target_flat = target.view(-1, target.shape[-1])
|
||||||
# losses_flat: (batch * max_len, dim)
|
# losses_flat: (batch * max_len, dim)
|
||||||
losses_flat = functional.l1_loss(input, target, size_average=False,
|
losses_flat = functional.l1_loss(input, target_flat, size_average=False,
|
||||||
reduce=False)
|
reduce=False)
|
||||||
# losses: (batch, max_len, dim)
|
# losses: (batch, max_len, dim)
|
||||||
losses = losses_flat.view(*target.size())
|
losses = losses_flat.view(*target.size())
|
||||||
|
|
||||||
# mask: (batch, max_len, 1)
|
# mask: (batch, max_len, 1)
|
||||||
mask = _sequence_mask(sequence_length=length,
|
mask = _sequence_mask(sequence_length=length,
|
||||||
max_len=target.size(1)).unsqueeze(2)
|
max_len=target.size(1)).unsqueeze(2)
|
||||||
|
|
Loading…
Reference in New Issue