mirror of https://github.com/coqui-ai/TTS.git
bug fix train.py add stop token
This commit is contained in:
parent
8ed9f57a6d
commit
5a5b9263e2
22
train.py
22
train.py
|
@ -77,6 +77,11 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
|
|||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_targets = data[5]
|
||||
|
||||
# set stop targets view, we predict a single stop token per r frames prediction
|
||||
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
|
||||
|
||||
current_step = num_iter + args.restore_step + \
|
||||
epoch * len(data_loader) + 1
|
||||
|
@ -94,13 +99,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
|
|||
mel_input = mel_input.cuda()
|
||||
mel_lengths = mel_lengths.cuda()
|
||||
linear_input = linear_input.cuda()
|
||||
|
||||
stop_targets = stop_targets.cuda()
|
||||
|
||||
# forward pass
|
||||
mel_output, linear_output, alignments, stop_tokens =\
|
||||
model.forward(text_input, mel_input)
|
||||
|
||||
# loss computation
|
||||
stop_loss = criterion_st(stop_tokens, stop_target)
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
|
||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||
|
@ -208,11 +214,11 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
|
|||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
stop_targets = data[5]
|
||||
|
||||
# set stop targets view, we predict a single stop token per r frames prediction
|
||||
stop_target = stop_target.view(text_input.shape[0], stop_target.size(1) // c.r, -1)
|
||||
stop_target = (stop_target.sum(2) > 0.0).float()
|
||||
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
|
||||
|
||||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
|
@ -220,14 +226,14 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
|
|||
mel_input = mel_input.cuda()
|
||||
mel_lengths = mel_lengths.cuda()
|
||||
linear_input = linear_input.cuda()
|
||||
stop_target = stop_target.cuda()
|
||||
stop_targets = stop_targets.cuda()
|
||||
|
||||
# forward pass
|
||||
mel_output, linear_output, alignments, stop_tokens =\
|
||||
model.forward(text_input, mel_spec)
|
||||
model.forward(text_input, mel_input)
|
||||
|
||||
# loss computation
|
||||
stop_loss = criterion_st(stop_tokens, stop_target)
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
|
||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||
|
|
Loading…
Reference in New Issue