diff --git a/layers/losses.py b/layers/losses.py
index 176e2f09..7e5671b2 100644
--- a/layers/losses.py
+++ b/layers/losses.py
@@ -96,3 +96,32 @@ class AttentionEntropyLoss(nn.Module):
         entropy = torch.distributions.Categorical(probs=align).entropy()
         loss = (entropy / np.log(align.shape[1])).mean()
         return loss
+
+
+class BCELossMasked(nn.Module):
+
+    def __init__(self, pos_weight):
+        super(BCELossMasked, self).__init__()
+        self.pos_weight = pos_weight
+
+    def forward(self, x, target, length):
+        """
+        Args:
+            x: A FloatTensor of size (batch, max_len) containing the
+                unnormalized stop-token logits for each decoder step.
+            target: A FloatTensor of size (batch, max_len) containing the
+                binary (0 or 1) stop target for each decoder step.
+            length: A LongTensor of size (batch,) containing the valid
+                length of each sequence in the batch.
+        Returns:
+            loss: The BCE loss normalized by the number of unmasked
+                steps. Frames beyond each sequence length are zeroed
+                out and contribute no gradient.
+        """
+        # mask: (batch, max_len)
+        target.requires_grad = False
+        mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
+        loss = functional.binary_cross_entropy_with_logits(
+            x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum')
+        loss = loss / mask.sum()
+        return loss
diff --git a/train.py b/train.py
index 4cf366e3..1397b310 100644
--- a/train.py
+++ b/train.py
@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader
 from TTS.datasets.TTSDataset import MyDataset
 from distribute import (DistributedSampler, apply_gradient_allreduce,
                         init_distributed, reduce_tensor)
-from TTS.layers.losses import L1LossMasked, MSELossMasked
+from TTS.layers.losses import L1LossMasked, MSELossMasked, BCELossMasked
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import (
     NoamLR, check_update, count_parameters, create_experiment_folder,
@@ -167,7 +167,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
 
         # loss computation
         stop_loss = criterion_st(stop_tokens,
-                                 stop_targets) if c.stopnet else torch.zeros(1)
+                                 stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
         if c.loss_masking:
             decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
             if c.model in ["Tacotron", "TacotronGST"]:
@@ -365,7 +365,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
 
                 # loss computation
                 stop_loss = criterion_st(
-                    stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
+                    stop_tokens, stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
                 if c.loss_masking:
                     decoder_loss = criterion(decoder_output, mel_input,
                                              mel_lengths)
@@ -571,7 +571,7 @@ def main(args): # pylint: disable=redefined-outer-name
     else:
         criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                                ] else nn.MSELoss()
-    criterion_st = nn.BCEWithLogitsLoss(
+    criterion_st = BCELossMasked(
         pos_weight=torch.tensor(10)) if c.stopnet else None
 
     if args.restore_path:
diff --git a/utils/data.py b/utils/data.py
index a8b87cb5..f2d7538a 100644
--- a/utils/data.py
+++ b/utils/data.py
@@ -31,7 +31,7 @@ def prepare_tensor(inputs, out_steps):
 
 
 def _pad_stop_target(x, length):
-    _pad = 1.
+    _pad = 0.
     assert x.ndim == 1
     return np.pad(
         x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
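For context, a minimal sketch of how the new loss is exercised at training time. It is illustrative only and not part of the patch: the sequence_mask helper below is a stand-in for the one BCELossMasked imports inside TTS, the class body is a condensed copy of the one added to layers/losses.py above, and the batch shapes, lengths, and the 10.0 pos_weight are made-up example values.

import torch
from torch import nn
from torch.nn import functional


def sequence_mask(sequence_length, max_len=None):
    # Stand-in mask helper: True for valid decoder steps, False for padding.
    if max_len is None:
        max_len = sequence_length.max().item()
    steps = torch.arange(max_len, device=sequence_length.device)
    return steps.unsqueeze(0) < sequence_length.unsqueeze(1)


class BCELossMasked(nn.Module):
    # Condensed copy of the class added in layers/losses.py above
    # (docstring and requires_grad bookkeeping omitted).
    def __init__(self, pos_weight):
        super().__init__()
        self.pos_weight = pos_weight

    def forward(self, x, target, length):
        mask = sequence_mask(length, max_len=target.size(1)).float()
        loss = functional.binary_cross_entropy_with_logits(
            x * mask, target * mask,
            pos_weight=self.pos_weight, reduction='sum')
        return loss / mask.sum()


# Two example utterances, 6 decoder frames each, valid lengths 4 and 6.
stop_tokens = torch.randn(2, 6)                          # stopnet logits
stop_targets = torch.tensor([[0., 0., 0., 1., 0., 0.],   # frames 5-6 are padding
                             [0., 0., 0., 0., 0., 1.]])
mel_lengths = torch.tensor([4, 6])

criterion_st = BCELossMasked(pos_weight=torch.tensor(10.0))
stop_loss = criterion_st(stop_tokens, stop_targets, mel_lengths)
print(stop_loss.item())

Because both the logits and the targets are multiplied by the mask before the BCE is computed, whatever value is used to pad the stop targets never reaches the loss, which is why _pad_stop_target in utils/data.py can now pad with 0. instead of 1.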