Decoder shape comments for Tacotron2, decoupled grad clipping for the stopnet vs. the rest of the network, some variable renaming, and a bug fix for alignment-score logging

Eren Golge 2019-11-08 10:04:44 +01:00
parent 2966e3f2d1
commit 015f7780f4
4 changed files with 56 additions and 43 deletions

View File

@@ -208,7 +208,14 @@ class Decoder(nn.Module):
         return memory[:, :, self.memory_dim * (self.r - 1):]
 
     def decode(self, memory):
+        '''
+        shapes:
+           - memory: B x r * self.memory_dim
+        '''
+        # self.context: B x D_en
+        # query_input: B x D_en + (r * self.memory_dim)
         query_input = torch.cat((memory, self.context), -1)
+        # self.query and self.attention_rnn_cell_state : B x D_attn_rnn
         self.query, self.attention_rnn_cell_state = self.attention_rnn(
             query_input, (self.query, self.attention_rnn_cell_state))
         self.query = F.dropout(self.query, self.p_attention_dropout,
@@ -216,29 +223,28 @@ class Decoder(nn.Module):
         self.attention_rnn_cell_state = F.dropout(
             self.attention_rnn_cell_state, self.p_attention_dropout,
             self.training)
+        # B x D_en
         self.context = self.attention(self.query, self.inputs,
                                       self.processed_inputs, self.mask)
-        memory = torch.cat((self.query, self.context), -1)
+        # B x (D_en + D_attn_rnn)
+        decoder_rnn_input = torch.cat((self.query, self.context), -1)
+        # self.decoder_hidden and self.decoder_cell: B x D_decoder_rnn
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
-            memory, (self.decoder_hidden, self.decoder_cell))
+            decoder_rnn_input, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(self.decoder_hidden,
                                         self.p_decoder_dropout, self.training)
-        self.decoder_cell = F.dropout(self.decoder_cell,
-                                      self.p_decoder_dropout, self.training)
+        # B x (D_decoder_rnn + D_en)
         decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
                                            dim=1)
+        # B x (self.r * self.memory_dim)
         decoder_output = self.linear_projection(decoder_hidden_context)
+        # B x (D_decoder_rnn + (self.r * self.memory_dim))
         stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
         if self.separate_stopnet:
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
+        # select outputs for the reduction rate self.r
         decoder_output = decoder_output[:, :self.r * self.memory_dim]
         return decoder_output, self.attention.attention_weights, stop_token
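
The comments added above pin down the tensor shape at every step of decode(). As a quick sanity check, here is a minimal sketch that replays the same concatenations with made-up sizes; B, D_en, D_attn_rnn, D_decoder_rnn, memory_dim and r are hypothetical stand-ins for the configured model dimensions.

import torch

# Hypothetical dimensions; the real values come from the model config.
B, D_en, D_attn_rnn, D_decoder_rnn, memory_dim, r = 4, 512, 1024, 1024, 80, 2

memory = torch.randn(B, r * memory_dim)              # B x (r * memory_dim)
context = torch.randn(B, D_en)                       # self.context: B x D_en
query_input = torch.cat((memory, context), -1)
assert query_input.shape == (B, D_en + r * memory_dim)

query = torch.randn(B, D_attn_rnn)                   # attention RNN state
decoder_rnn_input = torch.cat((query, context), -1)
assert decoder_rnn_input.shape == (B, D_en + D_attn_rnn)

decoder_hidden = torch.randn(B, D_decoder_rnn)       # decoder RNN state
decoder_hidden_context = torch.cat((decoder_hidden, context), dim=1)
assert decoder_hidden_context.shape == (B, D_decoder_rnn + D_en)

decoder_output = torch.randn(B, r * memory_dim)      # after linear_projection
stopnet_input = torch.cat((decoder_hidden, decoder_output), dim=1)
assert stopnet_input.shape == (B, D_decoder_rnn + r * memory_dim)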
@@ -257,8 +263,8 @@ class Decoder(nn.Module):
         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
             memory = memories[len(outputs)]
-            mel_output, attention_weights, stop_token = self.decode(memory)
-            outputs += [mel_output.squeeze(1)]
+            decoder_output, attention_weights, stop_token = self.decode(memory)
+            outputs += [decoder_output.squeeze(1)]
             stop_tokens += [stop_token.squeeze(1)]
             alignments += [attention_weights]
@@ -278,9 +284,9 @@ class Decoder(nn.Module):
             memory = self.prenet(memory)
             if speaker_embeddings is not None:
                 memory = torch.cat([memory, speaker_embeddings], dim=-1)
-            mel_output, alignment, stop_token = self.decode(memory)
+            decoder_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
-            outputs += [mel_output.squeeze(1)]
+            outputs += [decoder_output.squeeze(1)]
             stop_tokens += [stop_token]
             alignments += [alignment]
@@ -290,7 +296,7 @@ class Decoder(nn.Module):
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
-            memory = self._update_memory(mel_output)
+            memory = self._update_memory(decoder_output)
             t += 1
 
         outputs, stop_tokens, alignments = self._parse_outputs(
@@ -314,9 +320,9 @@ class Decoder(nn.Module):
         stop_flags = [True, False, False]
         while True:
             memory = self.prenet(self.memory_truncated)
-            mel_output, alignment, stop_token = self.decode(memory)
+            decoder_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
-            outputs += [mel_output.squeeze(1)]
+            outputs += [decoder_output.squeeze(1)]
             stop_tokens += [stop_token]
             alignments += [alignment]
@@ -326,7 +332,7 @@ class Decoder(nn.Module):
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
-            self.memory_truncated = mel_output
+            self.memory_truncated = decoder_output
             t += 1
 
         outputs, stop_tokens, alignments = self._parse_outputs(
@@ -343,7 +349,7 @@ class Decoder(nn.Module):
         self._init_states(inputs, mask=None)
         memory = self.prenet(memory)
-        mel_output, stop_token, alignment = self.decode(memory)
+        decoder_output, stop_token, alignment = self.decode(memory)
         stop_token = torch.sigmoid(stop_token.data)
-        memory = mel_output
-        return mel_output, stop_token, alignment
+        memory = decoder_output
+        return decoder_output, stop_token, alignment
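
All of the inference hunks above follow the same autoregressive pattern the rename makes explicit: run the prenet, decode one frame block, check the stop token, and feed the decoder output back as the next step's memory. A runnable toy version of that loop, with stub callables standing in for the real prenet/decode modules (every name and value here is hypothetical):

import torch

memory_dim, max_decoder_steps = 80, 50               # made-up config values
prenet = lambda x: x                                 # identity stub
def decode(m):                                       # random stub
    return torch.randn(1, memory_dim), torch.randn(1, 10), torch.randn(1, 1)

outputs, stop_tokens, alignments, t = [], [], [], 0
memory = torch.zeros(1, memory_dim)                  # all-zero "go" frame
while True:
    decoder_output, alignment, stop_token = decode(prenet(memory))
    stop_token = torch.sigmoid(stop_token.data)
    outputs += [decoder_output]
    stop_tokens += [stop_token]
    alignments += [alignment]
    if stop_token.item() > 0.5 or t > max_decoder_steps:
        break
    memory = decoder_output                          # feed the prediction back
    t += 1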

View File

@@ -198,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         loss.backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
-        grad_norm, _ = check_update(model.decoder, c.grad_clip)
+        grad_norm, grad_flag = check_update(model, c.grad_clip, ignore_stopnet=True)
         optimizer.step()
 
         # compute alignment score
@@ -302,16 +302,15 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     end_time = time.time()
 
     # print epoch stats
-    print(" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
+    print(" | > EPOCH END -- GlobalStep:{} "
           "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
-          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
+          "AvgStopLoss:{:.5f} AvgAlignScore:{:3f} EpochTime:{:.2f} "
           "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(
               global_step, keep_avg['avg_postnet_loss'],
               keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'],
               keep_avg['avg_align_score'], epoch_time,
               keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
           flush=True)
 
     # Plot Epoch Stats
     if args.rank == 0:
         # Plot Training Epoch Stats
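
The logging fix above is in the format string, not the arguments: the old string labeled its eight slots GlobalStep, AvgTotalLoss, AvgPostnetLoss, and so on, while the argument list (unchanged in this hunk) passes keep_avg['avg_align_score'] in fifth position, so every value after global_step was printed under the wrong label. A toy reproduction with made-up numbers:

old = ("GlobalStep:{} AvgTotalLoss:{:.5f} AvgPostnetLoss:{:.5f} "
       "AvgDecoderLoss:{:.5f} AvgStopLoss:{:.5f} EpochTime:{:.2f} "
       "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}")
new = ("GlobalStep:{} AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
       "AvgStopLoss:{:.5f} AvgAlignScore:{:3f} EpochTime:{:.2f} "
       "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}")
args = (100, 0.5, 0.4, 0.1, 0.7, 60.0, 0.8, 0.1)  # step, postnet, decoder,
                                                  # stop, align, epoch/step/
                                                  # loader times (made up)
print(old.format(*args))  # postnet loss shows up under "AvgTotalLoss", etc.
print(new.format(*args))  # every value lands under its own label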
@@ -568,7 +567,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                            ] else nn.MSELoss()
     criterion_st = nn.BCEWithLogitsLoss(
-        pos_weight=torch.tensor(20.0)) if c.stopnet else None
+        pos_weight=torch.tensor(10)) if c.stopnet else None
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
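
pos_weight scales the positive-class term of BCEWithLogitsLoss, which is how the stopnet compensates for "stop" frames being far rarer than "continue" frames; this hunk softens that bias from 20.0 to 10. A small demonstration of the effect with toy logits:

import torch
import torch.nn as nn

logits = torch.tensor([0.2, -1.5, 3.0])   # raw stopnet outputs (made up)
targets = torch.tensor([1.0, 0.0, 1.0])   # 1 = stop frame, 0 = continue

for w in (20.0, 10.0):
    loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(w))(logits, targets)
    print(w, loss.item())  # the positive terms weigh half as much at 10.0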

View File

@@ -150,9 +150,12 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
     return best_loss
 
 
-def check_update(model, grad_clip):
+def check_update(model, grad_clip, ignore_stopnet=False):
     r'''Check model gradient against unexpected jumps and failures'''
     skip_flag = False
-    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+    if ignore_stopnet:
+        grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
+    else:
+        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
     if np.isinf(grad_norm):
         print(" | > Gradient is INF !!")

View File

@@ -11,6 +11,11 @@ class Logger(object):
     def tb_model_weights(self, model, step):
         layer_num = 1
         for name, param in model.named_parameters():
+            if param.numel() == 1:
+                self.writer.add_scalar(
+                    "layer{}-{}/value".format(layer_num, name),
+                    param.max(), step)
+            else:
                 self.writer.add_scalar(
                     "layer{}-{}/max".format(layer_num, name),
                     param.max(), step)
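
The added branch special-cases one-element parameters, for which separate per-layer statistics are redundant, and logs them as a single "value" scalar instead. A sketch of the pattern as a free function (log_param is hypothetical, and the real else branch may log more statistics than the max shown in the hunk):

import torch
from torch.utils.tensorboard import SummaryWriter

def log_param(writer, layer_num, name, param, step):
    if param.numel() == 1:               # scalar parameter: one value
        writer.add_scalar("layer{}-{}/value".format(layer_num, name),
                          param.max(), step)
    else:                                # tensor parameter: per-stat scalars
        writer.add_scalar("layer{}-{}/max".format(layer_num, name),
                          param.max(), step)

writer = SummaryWriter("/tmp/tb_demo")                       # made-up log dir
log_param(writer, 1, "gain", torch.tensor([0.3]), step=0)    # -> .../value
log_param(writer, 2, "weight", torch.randn(4, 4), step=0)    # -> .../max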