Add decoder shape comments for Tacotron2 and decouple gradient clipping for the stopnet from the rest of the network. Also rename some variables and fix a bug in alignment score logging.

This commit is contained in:
Eren Golge 2019-11-08 10:04:44 +01:00
parent 2966e3f2d1
commit 015f7780f4
4 changed files with 56 additions and 43 deletions

View File

@ -208,7 +208,14 @@ class Decoder(nn.Module):
return memory[:, :, self.memory_dim * (self.r - 1):]
def decode(self, memory):
'''
shapes:
- memory: B x r * self.memory_dim
'''
# self.context: B x D_en
# query_input: B x D_en + (r * self.memory_dim)
query_input = torch.cat((memory, self.context), -1)
# self.query and self.attention_rnn_cell_state : B x D_attn_rnn
self.query, self.attention_rnn_cell_state = self.attention_rnn(
query_input, (self.query, self.attention_rnn_cell_state))
self.query = F.dropout(self.query, self.p_attention_dropout,
@ -216,29 +223,28 @@ class Decoder(nn.Module):
self.attention_rnn_cell_state = F.dropout(
self.attention_rnn_cell_state, self.p_attention_dropout,
self.training)
# B x D_en
self.context = self.attention(self.query, self.inputs,
self.processed_inputs, self.mask)
memory = torch.cat((self.query, self.context), -1)
# B x (D_en + D_attn_rnn)
decoder_rnn_input = torch.cat((self.query, self.context), -1)
# self.decoder_hidden and self.decoder_cell: B x D_decoder_rnn
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
memory, (self.decoder_hidden, self.decoder_cell))
decoder_rnn_input, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout(self.decoder_hidden,
self.p_decoder_dropout, self.training)
self.decoder_cell = F.dropout(self.decoder_cell,
self.p_decoder_dropout, self.training)
# B x (D_decoder_rnn + D_en)
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
dim=1)
# B x (self.r * self.memory_dim)
decoder_output = self.linear_projection(decoder_hidden_context)
# B x (D_decoder_rnn + (self.r * self.memory_dim))
stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
if self.separate_stopnet:
stop_token = self.stopnet(stopnet_input.detach())
else:
stop_token = self.stopnet(stopnet_input)
# select outputs for the reduction rate self.r
decoder_output = decoder_output[:, :self.r * self.memory_dim]
return decoder_output, self.attention.attention_weights, stop_token
@ -257,8 +263,8 @@ class Decoder(nn.Module):
outputs, stop_tokens, alignments = [], [], []
while len(outputs) < memories.size(0) - 1:
memory = memories[len(outputs)]
mel_output, attention_weights, stop_token = self.decode(memory)
outputs += [mel_output.squeeze(1)]
decoder_output, attention_weights, stop_token = self.decode(memory)
outputs += [decoder_output.squeeze(1)]
stop_tokens += [stop_token.squeeze(1)]
alignments += [attention_weights]
@ -278,9 +284,9 @@ class Decoder(nn.Module):
memory = self.prenet(memory)
if speaker_embeddings is not None:
memory = torch.cat([memory, speaker_embeddings], dim=-1)
mel_output, alignment, stop_token = self.decode(memory)
decoder_output, alignment, stop_token = self.decode(memory)
stop_token = torch.sigmoid(stop_token.data)
outputs += [mel_output.squeeze(1)]
outputs += [decoder_output.squeeze(1)]
stop_tokens += [stop_token]
alignments += [alignment]
@ -290,7 +296,7 @@ class Decoder(nn.Module):
print(" | > Decoder stopped with 'max_decoder_steps")
break
memory = self._update_memory(mel_output)
memory = self._update_memory(decoder_output)
t += 1
outputs, stop_tokens, alignments = self._parse_outputs(
@ -314,9 +320,9 @@ class Decoder(nn.Module):
stop_flags = [True, False, False]
while True:
memory = self.prenet(self.memory_truncated)
mel_output, alignment, stop_token = self.decode(memory)
decoder_output, alignment, stop_token = self.decode(memory)
stop_token = torch.sigmoid(stop_token.data)
outputs += [mel_output.squeeze(1)]
outputs += [decoder_output.squeeze(1)]
stop_tokens += [stop_token]
alignments += [alignment]
@ -326,7 +332,7 @@ class Decoder(nn.Module):
print(" | > Decoder stopped with 'max_decoder_steps")
break
self.memory_truncated = mel_output
self.memory_truncated = decoder_output
t += 1
outputs, stop_tokens, alignments = self._parse_outputs(
@ -343,7 +349,7 @@ class Decoder(nn.Module):
self._init_states(inputs, mask=None)
memory = self.prenet(memory)
mel_output, stop_token, alignment = self.decode(memory)
decoder_output, stop_token, alignment = self.decode(memory)
stop_token = torch.sigmoid(stop_token.data)
memory = mel_output
return mel_output, stop_token, alignment
memory = decoder_output
return decoder_output, stop_token, alignment

View File

@ -198,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
loss.backward()
optimizer, current_lr = adam_weight_decay(optimizer)
grad_norm, _ = check_update(model.decoder, c.grad_clip)
grad_norm, grad_flag = check_update(model, c.grad_clip, ignore_stopnet=True)
optimizer.step()
# compute alignment score
@ -302,16 +302,15 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
end_time = time.time()
# print epoch stats
print(" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
print(" | > EPOCH END -- GlobalStep:{} "
"AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
"AvgStopLoss:{:.5f} AvgAlignScore:{:3f} EpochTime:{:.2f} "
"AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(
global_step, keep_avg['avg_postnet_loss'],
keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'],
keep_avg['avg_align_score'], epoch_time,
keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
flush=True)
# Plot Epoch Stats
if args.rank == 0:
# Plot Training Epoch Stats
@ -568,7 +567,7 @@ def main(args): # pylint: disable=redefined-outer-name
criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
] else nn.MSELoss()
criterion_st = nn.BCEWithLogitsLoss(
pos_weight=torch.tensor(20.0)) if c.stopnet else None
pos_weight=torch.tensor(10)) if c.stopnet else None
if args.restore_path:
checkpoint = torch.load(args.restore_path)

View File

@ -150,10 +150,13 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
return best_loss
def check_update(model, grad_clip):
def check_update(model, grad_clip, ignore_stopnet=False):
r'''Check model gradient against unexpected jumps and failures'''
skip_flag = False
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
if ignore_stopnet:
grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
if np.isinf(grad_norm):
print(" | > Gradient is INF !!")
skip_flag = True

View File

@ -11,22 +11,27 @@ class Logger(object):
def tb_model_weights(self, model, step):
layer_num = 1
for name, param in model.named_parameters():
self.writer.add_scalar(
"layer{}-{}/max".format(layer_num, name),
if param.numel() == 1:
self.writer.add_scalar(
"layer{}-{}/value".format(layer_num, name),
param.max(), step)
self.writer.add_scalar(
"layer{}-{}/min".format(layer_num, name),
param.min(), step)
self.writer.add_scalar(
"layer{}-{}/mean".format(layer_num, name),
param.mean(), step)
self.writer.add_scalar(
"layer{}-{}/std".format(layer_num, name),
param.std(), step)
self.writer.add_histogram(
"layer{}-{}/param".format(layer_num, name), param, step)
self.writer.add_histogram(
"layer{}-{}/grad".format(layer_num, name), param.grad, step)
else:
self.writer.add_scalar(
"layer{}-{}/max".format(layer_num, name),
param.max(), step)
self.writer.add_scalar(
"layer{}-{}/min".format(layer_num, name),
param.min(), step)
self.writer.add_scalar(
"layer{}-{}/mean".format(layer_num, name),
param.mean(), step)
self.writer.add_scalar(
"layer{}-{}/std".format(layer_num, name),
param.std(), step)
self.writer.add_histogram(
"layer{}-{}/param".format(layer_num, name), param, step)
self.writer.add_histogram(
"layer{}-{}/grad".format(layer_num, name), param.grad, step)
layer_num += 1
def dict_to_tb_scalar(self, scope_name, stats, step):