diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 780d8e7f..4046b06b 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -208,7 +208,14 @@ class Decoder(nn.Module):
         return memory[:, :, self.memory_dim * (self.r - 1):]

     def decode(self, memory):
+        '''
+        shapes:
+           - memory: B x r * self.memory_dim
+        '''
+        # self.context: B x D_en
+        # query_input: B x D_en + (r * self.memory_dim)
         query_input = torch.cat((memory, self.context), -1)
+        # self.query and self.attention_rnn_cell_state : B x D_attn_rnn
         self.query, self.attention_rnn_cell_state = self.attention_rnn(
             query_input, (self.query, self.attention_rnn_cell_state))
         self.query = F.dropout(self.query, self.p_attention_dropout,
@@ -216,29 +223,28 @@ class Decoder(nn.Module):
         self.attention_rnn_cell_state = F.dropout(
             self.attention_rnn_cell_state, self.p_attention_dropout,
             self.training)
-
+        # B x D_en
         self.context = self.attention(self.query, self.inputs,
                                       self.processed_inputs, self.mask)
-
-        memory = torch.cat((self.query, self.context), -1)
+        # B x (D_en + D_attn_rnn)
+        decoder_rnn_input = torch.cat((self.query, self.context), -1)
+        # self.decoder_hidden and self.decoder_cell: B x D_decoder_rnn
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
-            memory, (self.decoder_hidden, self.decoder_cell))
+            decoder_rnn_input, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(self.decoder_hidden,
                                         self.p_decoder_dropout, self.training)
-        self.decoder_cell = F.dropout(self.decoder_cell,
-                                      self.p_decoder_dropout, self.training)
-
+        # B x (D_decoder_rnn + D_en)
         decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
                                            dim=1)
-
+        # B x (self.r * self.memory_dim)
         decoder_output = self.linear_projection(decoder_hidden_context)
-
+        # B x (D_decoder_rnn + (self.r * self.memory_dim))
         stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
-
         if self.separate_stopnet:
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
+        # select outputs for the reduction rate self.r
         decoder_output = decoder_output[:, :self.r * self.memory_dim]
         return decoder_output, self.attention.attention_weights, stop_token

@@ -257,8 +263,8 @@ class Decoder(nn.Module):
         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
             memory = memories[len(outputs)]
-            mel_output, attention_weights, stop_token = self.decode(memory)
-            outputs += [mel_output.squeeze(1)]
+            decoder_output, attention_weights, stop_token = self.decode(memory)
+            outputs += [decoder_output.squeeze(1)]
             stop_tokens += [stop_token.squeeze(1)]
             alignments += [attention_weights]

@@ -278,9 +284,9 @@ class Decoder(nn.Module):
             memory = self.prenet(memory)
             if speaker_embeddings is not None:
                 memory = torch.cat([memory, speaker_embeddings], dim=-1)
-            mel_output, alignment, stop_token = self.decode(memory)
+            decoder_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
-            outputs += [mel_output.squeeze(1)]
+            outputs += [decoder_output.squeeze(1)]
             stop_tokens += [stop_token]
             alignments += [alignment]

@@ -290,7 +296,7 @@ class Decoder(nn.Module):
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break

-            memory = self._update_memory(mel_output)
+            memory = self._update_memory(decoder_output)
             t += 1

         outputs, stop_tokens, alignments = self._parse_outputs(
@@ -314,9 +320,9 @@ class Decoder(nn.Module):
         stop_flags = [True, False, False]
         while True:
             memory = self.prenet(self.memory_truncated)
-            mel_output, alignment, stop_token = self.decode(memory)
+            decoder_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
-            outputs += [mel_output.squeeze(1)]
+            outputs += [decoder_output.squeeze(1)]
             stop_tokens += [stop_token]
             alignments += [alignment]

@@ -326,7 +332,7 @@ class Decoder(nn.Module):
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break

-            self.memory_truncated = mel_output
+            self.memory_truncated = decoder_output
             t += 1

         outputs, stop_tokens, alignments = self._parse_outputs(
@@ -343,7 +349,7 @@ class Decoder(nn.Module):
             self._init_states(inputs, mask=None)

         memory = self.prenet(memory)
-        mel_output, stop_token, alignment = self.decode(memory)
+        decoder_output, stop_token, alignment = self.decode(memory)
         stop_token = torch.sigmoid(stop_token.data)
-        memory = mel_output
-        return mel_output, stop_token, alignment
+        memory = decoder_output
+        return decoder_output, stop_token, alignment
diff --git a/train.py b/train.py
index 976577fd..50942a7f 100644
--- a/train.py
+++ b/train.py
@@ -198,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,

         loss.backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
-        grad_norm, _ = check_update(model.decoder, c.grad_clip)
+        grad_norm, grad_flag = check_update(model, c.grad_clip, ignore_stopnet=True)
         optimizer.step()

         # compute alignment score
@@ -302,16 +302,15 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     end_time = time.time()

     # print epoch stats
-    print(" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
+    print(" | > EPOCH END -- GlobalStep:{} "
           "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
-          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
+          "AvgStopLoss:{:.5f} AvgAlignScore:{:3f} EpochTime:{:.2f} "
           "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(
               global_step, keep_avg['avg_postnet_loss'],
               keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'],
               keep_avg['avg_align_score'], epoch_time,
               keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
           flush=True)
-
     # Plot Epoch Stats
     if args.rank == 0:
         # Plot Training Epoch Stats
@@ -568,7 +567,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                            ] else nn.MSELoss()
     criterion_st = nn.BCEWithLogitsLoss(
-        pos_weight=torch.tensor(20.0)) if c.stopnet else None
+        pos_weight=torch.tensor(10)) if c.stopnet else None

     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 972513d6..2cad3a7c 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -150,10 +150,13 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
     return best_loss


-def check_update(model, grad_clip):
+def check_update(model, grad_clip, ignore_stopnet=False):
     r'''Check model gradient against unexpected jumps and failures'''
     skip_flag = False
-    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+    if ignore_stopnet:
+        grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
+    else:
+        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
     if np.isinf(grad_norm):
         print(" | > Gradient is INF !!")
         skip_flag = True
diff --git a/utils/logger.py b/utils/logger.py
index be523d36..f83c9d0f 100644
--- a/utils/logger.py
+++ b/utils/logger.py
@@ -11,22 +11,27 @@ class Logger(object):
     def tb_model_weights(self, model, step):
         layer_num = 1
         for name, param in model.named_parameters():
-            self.writer.add_scalar(
-                "layer{}-{}/max".format(layer_num, name),
+            if param.numel() == 1:
+                self.writer.add_scalar(
+                    "layer{}-{}/value".format(layer_num, name),
                 param.max(), step)
-            self.writer.add_scalar(
-                "layer{}-{}/min".format(layer_num, name),
-                param.min(), step)
-            self.writer.add_scalar(
-                "layer{}-{}/mean".format(layer_num, name),
-                param.mean(), step)
-            self.writer.add_scalar(
-                "layer{}-{}/std".format(layer_num, name),
-                param.std(), step)
-            self.writer.add_histogram(
-                "layer{}-{}/param".format(layer_num, name), param, step)
-            self.writer.add_histogram(
-                "layer{}-{}/grad".format(layer_num, name), param.grad, step)
+            else:
+                self.writer.add_scalar(
+                    "layer{}-{}/max".format(layer_num, name),
+                    param.max(), step)
+                self.writer.add_scalar(
+                    "layer{}-{}/min".format(layer_num, name),
+                    param.min(), step)
+                self.writer.add_scalar(
+                    "layer{}-{}/mean".format(layer_num, name),
+                    param.mean(), step)
+                self.writer.add_scalar(
+                    "layer{}-{}/std".format(layer_num, name),
+                    param.std(), step)
+                self.writer.add_histogram(
+                    "layer{}-{}/param".format(layer_num, name), param, step)
+                self.writer.add_histogram(
+                    "layer{}-{}/grad".format(layer_num, name), param.grad, step)
             layer_num += 1

     def dict_to_tb_scalar(self, scope_name, stats, step):
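
Note on the check_update() change: gradient clipping now runs over the whole model rather than model.decoder, while skipping every parameter whose name contains 'stopnet'. The stop-token network has its own loss (criterion_st) and optimizer (optimizer_st) in train.py, so its gradients are presumably meant to stay outside the reported norm and the clipping threshold. A minimal standalone sketch of the same filtering pattern, not the repo's code (TinyTacotron and its layer sizes are hypothetical stand-ins):

# Sketch of the ignore_stopnet filtering used by check_update();
# TinyTacotron and the layer sizes below are made up for illustration.
import torch
import torch.nn as nn


class TinyTacotron(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = nn.Linear(8, 8)
        self.stopnet = nn.Linear(8, 1)  # trained separately, so left unclipped

    def forward(self, x):
        hidden = self.decoder(x)
        return hidden, self.stopnet(hidden)


model = TinyTacotron()
frames, stop_logits = model(torch.randn(4, 8))
(frames.sum() + stop_logits.sum()).backward()

grad_clip = 1.0
# Clip every parameter whose name does not contain 'stopnet', mirroring
# check_update(model, grad_clip, ignore_stopnet=True) in the patch.
params = [p for name, p in model.named_parameters() if 'stopnet' not in name]
grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
print("clipped grad norm (stopnet excluded): {:.4f}".format(float(grad_norm)))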
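
Note on the logger.py change: per-parameter statistics are now guarded by param.numel() == 1, since min/max/mean and histograms collapse to one number for a single-element tensor (and std of one value is NaN), so such parameters are logged as a single scalar instead. A standalone sketch of the same guard, written against torch.utils.tensorboard rather than the repo's Logger wrapper; the log directory and parameter names are hypothetical:

# Sketch of the numel() == 1 guard from the Logger change.
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/tb_weights_example")  # hypothetical log dir
named_params = {
    "decoder.linear_projection.weight": torch.randn(80, 1024),
    "stopnet.bias_example": torch.tensor([0.25]),  # numel() == 1
}
step = 0
for name, param in named_params.items():
    if param.numel() == 1:
        # histograms and std are meaningless for one value; log the value itself
        writer.add_scalar("{}/value".format(name), param.max(), step)
    else:
        writer.add_scalar("{}/max".format(name), param.max(), step)
        writer.add_scalar("{}/min".format(name), param.min(), step)
        writer.add_histogram("{}/param".format(name), param, step)
writer.close()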