mirror of https://github.com/coqui-ai/TTS.git
masked loss
This commit is contained in:
parent
b10abadada
commit
96efe83a41
|
@ -1,5 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import functional
|
from torch.nn import functional
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
|
|
||||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||||
|
@ -7,7 +8,7 @@ def _sequence_mask(sequence_length, max_len=None):
|
||||||
if max_len is None:
|
if max_len is None:
|
||||||
max_len = sequence_length.data.max()
|
max_len = sequence_length.data.max()
|
||||||
batch_size = sequence_length.size(0)
|
batch_size = sequence_length.size(0)
|
||||||
seq_range = torch.range(0, max_len - 1).long()
|
seq_range = torch.arange(0, max_len).long()
|
||||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||||
seq_range_expand = Variable(seq_range_expand)
|
seq_range_expand = Variable(seq_range_expand)
|
||||||
if sequence_length.is_cuda:
|
if sequence_length.is_cuda:
|
||||||
|
@ -31,18 +32,20 @@ def L1LossMasked(input, target, length):
|
||||||
Returns:
|
Returns:
|
||||||
loss: An average loss value masked by the length.
|
loss: An average loss value masked by the length.
|
||||||
"""
|
"""
|
||||||
|
input = input.contiguous()
|
||||||
|
target = target.contiguous()
|
||||||
|
|
||||||
# logits_flat: (batch * max_len, num_classes)
|
# logits_flat: (batch * max_len, dim)
|
||||||
input = input.view(-1, input.size(-1))
|
input = input.view(-1, input.size(-1))
|
||||||
# target_flat: (batch * max_len, 1)
|
# target_flat: (batch * max_len, dim)
|
||||||
target_flat = target.view(-1, 1)
|
target_flat = target.view(-1, 1)
|
||||||
# losses_flat: (batch * max_len, 1)
|
# losses_flat: (batch * max_len, dim)
|
||||||
losees_flat = functional.l1_loss(input, target, size_average=False,
|
losses_flat = functional.l1_loss(input, target, size_average=False,
|
||||||
reduce=False)
|
reduce=False)
|
||||||
# losses: (batch, max_len)
|
# losses: (batch, max_len)
|
||||||
losses = losses_flat.view(*target.size())
|
losses = losses_flat.view(*target.size())
|
||||||
# mask: (batch, max_len)
|
# mask: (batch, max_len)
|
||||||
mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
|
mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)
|
||||||
losses = losses * mask.float()
|
losses = losses * mask.float()
|
||||||
loss = losses.sum() / length.float().sum()
|
loss = losses.sum() / length.float().sum()
|
||||||
return loss
|
return loss / input.shape[0]
|
20
train.py
20
train.py
|
@ -26,7 +26,7 @@ from utils.model import get_param_size
|
||||||
from utils.visual import plot_alignment, plot_spectrogram
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
from datasets.LJSpeech import LJSpeechDataset
|
from datasets.LJSpeech import LJSpeechDataset
|
||||||
from models.tacotron import Tacotron
|
from models.tacotron import Tacotron
|
||||||
from losses import
|
from layers.losses import L1LossMasked
|
||||||
|
|
||||||
|
|
||||||
use_cuda = torch.cuda.is_available()
|
use_cuda = torch.cuda.is_available()
|
||||||
|
@ -95,7 +95,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
||||||
# convert inputs to variables
|
# convert inputs to variables
|
||||||
text_input_var = Variable(text_input)
|
text_input_var = Variable(text_input)
|
||||||
mel_spec_var = Variable(mel_input)
|
mel_spec_var = Variable(mel_input)
|
||||||
mel_length_var = Variable(mel_lengths)
|
mel_lengths_var = Variable(mel_lengths)
|
||||||
linear_spec_var = Variable(linear_input, volatile=True)
|
linear_spec_var = Variable(linear_input, volatile=True)
|
||||||
|
|
||||||
# sort sequence by length for curriculum learning
|
# sort sequence by length for curriculum learning
|
||||||
|
@ -105,6 +105,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
||||||
sorted_lengths = sorted_lengths.long().numpy()
|
sorted_lengths = sorted_lengths.long().numpy()
|
||||||
text_input_var = text_input_var[indices]
|
text_input_var = text_input_var[indices]
|
||||||
mel_spec_var = mel_spec_var[indices]
|
mel_spec_var = mel_spec_var[indices]
|
||||||
|
mel_lengths_var = mel_lengths_var[indices]
|
||||||
linear_spec_var = linear_spec_var[indices]
|
linear_spec_var = linear_spec_var[indices]
|
||||||
|
|
||||||
# dispatch data to GPU
|
# dispatch data to GPU
|
||||||
|
@ -119,11 +120,11 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
||||||
model.forward(text_input_var, mel_spec_var)
|
model.forward(text_input_var, mel_spec_var)
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
|
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
|
||||||
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths) \
|
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
|
||||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||||
linear_spec_var[: ,: ,:n_priority_freq],
|
linear_spec_var[: ,: ,:n_priority_freq],
|
||||||
mel_lengths)
|
mel_lengths_var)
|
||||||
loss = mel_loss + linear_loss
|
loss = mel_loss + linear_loss
|
||||||
|
|
||||||
# backpass and check the grad norm
|
# backpass and check the grad norm
|
||||||
|
@ -240,10 +241,10 @@ def evaluate(model, criterion, data_loader, current_step):
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
|
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
|
||||||
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths) \
|
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
|
||||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||||
linear_spec_var[: ,: ,:n_priority_freq],
|
linear_spec_var[: ,: ,:n_priority_freq],
|
||||||
mel_lengths)
|
mel_lengths_var)
|
||||||
loss = mel_loss + linear_loss
|
loss = mel_loss + linear_loss
|
||||||
|
|
||||||
step_time = time.time() - start_time
|
step_time = time.time() - start_time
|
||||||
|
@ -348,10 +349,7 @@ def main(args):
|
||||||
|
|
||||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||||
|
|
||||||
if use_cuda:
|
criterion = L1LossMasked
|
||||||
criterion = nn.L1Loss().cuda()
|
|
||||||
else:
|
|
||||||
criterion = nn.L1Loss()
|
|
||||||
|
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
checkpoint = torch.load(args.restore_path)
|
checkpoint = torch.load(args.restore_path)
|
||||||
|
|
Loading…
Reference in New Issue