diff --git a/layers/custom_layers.py b/layers/custom_layers.py
index 802091e8..d659efb2 100644
--- a/layers/custom_layers.py
+++ b/layers/custom_layers.py
@@ -4,23 +4,23 @@ from torch.autograd import Variable
 from torch import nn
 
 
-class StopProjection(nn.Module):
-    r""" Simple projection layer to predict the "stop token"
+# class StopProjection(nn.Module):
+#     r""" Simple projection layer to predict the "stop token"
 
-    Args:
-        in_features (int): size of the input vector
-        out_features (int or list): size of each output vector. aka number
-            of predicted frames.
-    """
+#     Args:
+#         in_features (int): size of the input vector
+#         out_features (int or list): size of each output vector. aka number
+#             of predicted frames.
+#     """
 
-    def __init__(self, in_features, out_features):
-        super(StopProjection, self).__init__()
-        self.linear = nn.Linear(in_features, out_features)
-        self.dropout = nn.Dropout(0.5)
-        self.sigmoid = nn.Sigmoid()
+#     def __init__(self, in_features, out_features):
+#         super(StopProjection, self).__init__()
+#         self.linear = nn.Linear(in_features, out_features)
+#         self.dropout = nn.Dropout(0.5)
+#         self.sigmoid = nn.Sigmoid()
 
-    def forward(self, inputs):
-        out = self.dropout(inputs)
-        out = self.linear(out)
-        out = self.sigmoid(out)
-        return out
\ No newline at end of file
+#     def forward(self, inputs):
+#         out = self.dropout(inputs)
+#         out = self.linear(out)
+#         out = self.sigmoid(out)
+#         return out
\ No newline at end of file
diff --git a/layers/tacotron.py b/layers/tacotron.py
index e9a40b24..51548287 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -5,7 +5,6 @@ from torch import nn
 from .attention import AttentionRNN
 from .attention import get_mask_from_lengths
-from .custom_layers import StopProjection
 
 
 class Prenet(nn.Module):
     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
@@ -233,8 +232,6 @@ class Decoder(nn.Module):
             [nn.GRUCell(256, 256) for _ in range(2)])
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
-        # RNN_state | attention_context -> |Linear| -> stop_token
-        self.stop_token = StopProjection(256 + in_features, r)
 
     def forward(self, inputs, memory=None):
         """
@@ -286,7 +283,6 @@ class Decoder(nn.Module):
 
         outputs = []
         alignments = []
-        stop_outputs = []
         t = 0
         memory_input = initial_memory
 
@@ -323,18 +319,13 @@ class Decoder(nn.Module):
                 decoder_input = decoder_rnn_hiddens[idx] + decoder_input
 
             output = decoder_input
-            stop_token_input = decoder_input
 
-            # stop token prediction
-            stop_token_input = torch.cat((output, current_context_vec), -1)
-            stop_output = self.stop_token(stop_token_input)
 
             # predict mel vectors from decoder vectors
             output = self.proj_to_mel(output)
 
             outputs += [output]
             alignments += [alignment]
-            stop_outputs += [stop_output]
 
             t += 1
 
@@ -354,9 +345,8 @@ class Decoder(nn.Module):
         # Back to batch first
         alignments = torch.stack(alignments).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        stop_outputs = torch.stack(stop_outputs).transpose(0, 1).contiguous()
 
-        return outputs, alignments, stop_outputs
+        return outputs, alignments
 
 
 def is_end_of_frames(output, eps=0.2): #0.2
diff --git a/train.py b/train.py
index 3531771f..3e1a75c7 100644
--- a/train.py
+++ b/train.py
@@ -63,12 +63,11 @@ def signal_handler(signal, frame):
     sys.exit(1)
 
 
-def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
+def train(model, criterion, data_loader, optimizer, epoch):
     model = model.train()
     epoch_time = 0
     avg_linear_loss = 0
     avg_mel_loss = 0
-    avg_stop_loss = 0
 
     print(" | > Epoch {}/{}".format(epoch, c.epochs))
     progbar = Progbar(len(data_loader.dataset) / c.batch_size)
@@ -81,7 +80,6 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
         text_lengths = data[1]
         linear_input = data[2]
         mel_input = data[3]
-        stop_targets = data[4]
 
         current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1
 
@@ -95,7 +93,6 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
-        stop_targets_var = Variable(stop_targets)
         linear_spec_var = Variable(linear_input, volatile=True)
 
         # sort sequence by length for curriculum learning
@@ -112,10 +109,9 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
             text_input_var = text_input_var.cuda()
             mel_spec_var = mel_spec_var.cuda()
             linear_spec_var = linear_spec_var.cuda()
-            stop_targets_var = stop_targets_var.cuda()
 
         # forward pass
-        mel_output, linear_output, alignments, stop_output =\
+        mel_output, linear_output, alignments =\
             model.forward(text_input_var, mel_spec_var)
 
         # loss computation
@@ -123,8 +119,7 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
         linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
                 + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                   linear_spec_var[: ,: ,:n_priority_freq])
-        stop_loss = critetion_stop(stop_output, stop_targets_var)
-        loss = mel_loss + linear_loss + 0.25*stop_loss
+        loss = mel_loss + linear_loss
 
         # backpass and check the grad norm
         loss.backward()
@@ -141,7 +136,6 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
 
         # update
         progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
                                            ('linear_loss', linear_loss.data[0]),
-                                           ('stop_loss', stop_loss.data[0]),
                                            ('mel_loss', mel_loss.data[0]),
                                            ('grad_norm', grad_norm)])
@@ -150,7 +144,6 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
         tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.data[0],
                       current_step)
         tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.data[0], current_step)
-        tb.add_scalar('TrainIterLoss/StopLoss', stop_loss.data[0], current_step)
         tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                       current_step)
         tb.add_scalar('Params/GradNorm', grad_norm, current_step)
@@ -191,21 +184,19 @@ def train(model, criterion, critetion_stop, data_loader, optimizer, epoch):
     avg_linear_loss /= (num_iter + 1)
     avg_mel_loss /= (num_iter + 1)
-    avg_stop_loss /= (num_iter + 1)
-    avg_total_loss = avg_mel_loss + avg_linear_loss + 0.25*avg_stop_loss
+    avg_total_loss = avg_mel_loss + avg_linear_loss
 
     # Plot Training Epoch Stats
     tb.add_scalar('TrainEpochLoss/TotalLoss', loss.data[0], current_step)
     tb.add_scalar('TrainEpochLoss/LinearLoss', linear_loss.data[0],
                   current_step)
     tb.add_scalar('TrainEpochLoss/MelLoss', mel_loss.data[0], current_step)
-    tb.add_scalar('TrainEpochLoss/StopLoss', stop_loss.data[0], current_step)
     tb.add_scalar('Time/EpochTime', epoch_time, epoch)
     epoch_time = 0
 
     return avg_linear_loss, current_step
 
 
-def evaluate(model, criterion, criterion_stop, data_loader, current_step):
+def evaluate(model, criterion, data_loader, current_step):
     model = model.eval()
     epoch_time = 0
 
@@ -215,7 +206,6 @@ def evaluate(model, criterion, criterion_stop, data_loader, current_step):
 
     avg_linear_loss = 0
     avg_mel_loss = 0
-    avg_stop_loss = 0
 
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -225,44 +215,38 @@ def evaluate(model, criterion, criterion_stop, data_loader, current_step):
         text_lengths = data[1]
         linear_input = data[2]
         mel_input = data[3]
-        stop_targets = data[4]
 
         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
         linear_spec_var = Variable(linear_input, volatile=True)
-        stop_targets_var = Variable(stop_targets)
 
         # dispatch data to GPU
         if use_cuda:
             text_input_var = text_input_var.cuda()
             mel_spec_var = mel_spec_var.cuda()
             linear_spec_var = linear_spec_var.cuda()
-            stop_targets_var = stop_targets_var.cuda()
 
         # forward pass
-        mel_output, linear_output, alignments, stop_output = model.forward(text_input_var, mel_spec_var)
+        mel_output, linear_output, alignments = model.forward(text_input_var, mel_spec_var)
 
         # loss computation
         mel_loss = criterion(mel_output, mel_spec_var)
         linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
                 + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                   linear_spec_var[: ,: ,:n_priority_freq])
-        stop_loss = criterion_stop(stop_output, stop_targets_var)
-        loss = mel_loss + linear_loss + 0.25*stop_loss
+        loss = mel_loss + linear_loss
 
         step_time = time.time() - start_time
         epoch_time += step_time
 
         # update
         progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
-                                           ('stop_loss', stop_loss.data[0]),
                                            ('linear_loss', linear_loss.data[0]),
                                            ('mel_loss', mel_loss.data[0])])
 
         avg_linear_loss += linear_loss.data[0]
         avg_mel_loss += mel_loss.data[0]
-        avg_stop_loss += stop_loss.data[0]
 
         # Diagnostic visualizations
         idx = np.random.randint(mel_input.shape[0])
@@ -294,14 +278,12 @@ def evaluate(model, criterion, criterion_stop, data_loader, current_step):
     # compute average losses
     avg_linear_loss /= (num_iter + 1)
     avg_mel_loss /= (num_iter + 1)
-    avg_stop_loss /= (num_iter + 1)
-    avg_total_loss = avg_mel_loss + avg_linear_loss + 0.25*avg_stop_loss
+    avg_total_loss = avg_mel_loss + avg_linear_loss
 
     # Plot Learning Stats
     tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step)
     tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
     tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
-    tb.add_scalar('ValEpochLoss/StopLoss', avg_stop_loss, current_step)
 
     return avg_linear_loss
 
@@ -359,10 +341,8 @@ def main(args):
 
     if use_cuda:
         criterion = nn.L1Loss().cuda()
-        criterion_stop = nn.BCELoss().cuda()
     else:
         criterion = nn.L1Loss()
-        criterion_stop = nn.BCELoss()
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
@@ -390,8 +370,8 @@ def main(args):
 
     best_loss = float('inf')
     for epoch in range(0, c.epochs):
-        train_loss, current_step = train(model, criterion, criterion_stop, train_loader, optimizer, epoch)
-        val_loss = evaluate(model, criterion, criterion_stop, val_loader, current_step)
+        train_loss, current_step = train(model, criterion, train_loader, optimizer, epoch)
+        val_loss = evaluate(model, criterion, val_loader, current_step)
         best_loss = save_best_model(model, optimizer, val_loss, best_loss,
                                     OUT_PATH, current_step, epoch)