mirror of https://github.com/coqui-ai/TTS.git
BCE masked loss and padding stop_tokens with 0s not 1s
This commit is contained in:
parent a5d919cec8
commit a68012aec2
layers/losses.py

@@ -96,3 +96,32 @@ class AttentionEntropyLoss(nn.Module):
         entropy = torch.distributions.Categorical(probs=align).entropy()
         loss = (entropy / np.log(align.shape[1])).mean()
         return loss
+
+
+class BCELossMasked(nn.Module):
+
+    def __init__(self, pos_weight):
+        super(BCELossMasked, self).__init__()
+        self.pos_weight = pos_weight
+
+    def forward(self, x, target, length):
+        """
+        Args:
+            x: A Variable containing a FloatTensor of size
+                (batch, max_len) which contains the
+                unnormalized probability for each class.
+            target: A Variable containing a LongTensor of size
+                (batch, max_len) which contains the index of the true
+                class for each corresponding step.
+            length: A Variable containing a LongTensor of size (batch,)
+                which contains the length of each data in a batch.
+        Returns:
+            loss: An average loss value in range [0, 1] masked by the length.
+        """
+        # mask: (batch, max_len, 1)
+        target.requires_grad = False
+        mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
+        loss = functional.binary_cross_entropy_with_logits(
+            x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum')
+        loss = loss / mask.sum()
+        return loss
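To see the new loss in action, here is a minimal self-contained sketch. The sequence_mask helper below is a stand-in written for this example (the repository imports its own); the class body mirrors the diff above, and the batch values are made up.

import torch
import torch.nn as nn
from torch.nn import functional

def sequence_mask(sequence_length, max_len):
    # Stand-in helper: (batch, max_len) mask, True for valid
    # decoder steps, False for padding.
    steps = torch.arange(max_len, device=sequence_length.device)
    return steps.unsqueeze(0) < sequence_length.unsqueeze(1)

class BCELossMasked(nn.Module):
    def __init__(self, pos_weight):
        super(BCELossMasked, self).__init__()
        self.pos_weight = pos_weight

    def forward(self, x, target, length):
        target.requires_grad = False
        mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
        # Zeroing logits and targets at padded steps kills their gradient;
        # each padded step still adds a constant log(2) before normalization.
        loss = functional.binary_cross_entropy_with_logits(
            x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum')
        return loss / mask.sum()

criterion_st = BCELossMasked(pos_weight=torch.tensor(10.))
stop_tokens = torch.randn(2, 6)        # stopnet logits (batch, max_len)
stop_targets = torch.zeros(2, 6)
stop_targets[0, 3] = 1.                # sequence 0 stops at step 3 (length 4)
stop_targets[1, 5] = 1.                # sequence 1 fills the batch (length 6)
mel_lengths = torch.tensor([4, 6])
print(criterion_st(stop_tokens, stop_targets, mel_lengths))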
train.py (8 changes: 4 additions, 4 deletions)

@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader
 from TTS.datasets.TTSDataset import MyDataset
 from distribute import (DistributedSampler, apply_gradient_allreduce,
                         init_distributed, reduce_tensor)
-from TTS.layers.losses import L1LossMasked, MSELossMasked
+from TTS.layers.losses import L1LossMasked, MSELossMasked, BCELossMasked
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import (
     NoamLR, check_update, count_parameters, create_experiment_folder,

@@ -167,7 +167,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,

         # loss computation
         stop_loss = criterion_st(stop_tokens,
-                                 stop_targets) if c.stopnet else torch.zeros(1)
+                                 stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
         if c.loss_masking:
             decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
             if c.model in ["Tacotron", "TacotronGST"]:

@@ -365,7 +365,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):

             # loss computation
             stop_loss = criterion_st(
-                stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
+                stop_tokens, stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
             if c.loss_masking:
                 decoder_loss = criterion(decoder_output, mel_input,
                                          mel_lengths)

@@ -571,7 +571,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     else:
         criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                                ] else nn.MSELoss()
-    criterion_st = nn.BCEWithLogitsLoss(
+    criterion_st = BCELossMasked(
         pos_weight=torch.tensor(10)) if c.stopnet else None

     if args.restore_path:
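For intuition about the criterion swap in main(), a toy before/after comparison. BCELossMasked here is the class from the sketch above, and all tensor values are illustrative, not from the repository.

import torch
import torch.nn as nn

stop_tokens = torch.randn(2, 6)
stop_targets = torch.zeros(2, 6)
stop_targets[0, 3] = 1.
stop_targets[1, 5] = 1.
mel_lengths = torch.tensor([4, 6])

# Old: nn.BCEWithLogitsLoss treats padded steps as real data.
old = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(10.))
print(old(stop_tokens, stop_targets))

# New: BCELossMasked takes the extra mel_lengths argument and
# normalizes only over each sequence's real steps.
new = BCELossMasked(pos_weight=torch.tensor(10.))
print(new(stop_tokens, stop_targets, mel_lengths))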
utils/data.py

@@ -31,7 +31,7 @@ def prepare_tensor(inputs, out_steps):
 
 
 def _pad_stop_target(x, length):
-    _pad = 1.
+    _pad = 0.
     assert x.ndim == 1
     return np.pad(
         x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
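Why this one-character change matters: stop targets padded with 1s looked like real stop flags, and the old unmasked BCEWithLogitsLoss with pos_weight=10 weighted each of those fake positives tenfold, which could bias the stopnet toward stopping early on mixed-length batches. A quick numpy check of the corrected helper, with made-up values:

import numpy as np

def _pad_stop_target(x, length):
    _pad = 0.  # was 1. before this commit
    assert x.ndim == 1
    return np.pad(
        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)

stop_target = np.array([0., 0., 0., 1.])   # stop flag on the last real frame
print(_pad_stop_target(stop_target, 6))    # [0. 0. 0. 1. 0. 0.] -- padding no longer "stops"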