mirror of https://github.com/coqui-ai/TTS.git
Decoder shape comments for Tacotron2, decoupled gradient clipping for the stopnet and the rest of the network, some variable renaming, and a bug fix for alignment score logging
This commit is contained in:
parent 2966e3f2d1
commit 015f7780f4
@@ -208,7 +208,14 @@ class Decoder(nn.Module):
         return memory[:, :, self.memory_dim * (self.r - 1):]
 
     def decode(self, memory):
+        '''
+        shapes:
+           - memory: B x r * self.memory_dim
+        '''
+        # self.context: B x D_en
+        # query_input: B x D_en + (r * self.memory_dim)
         query_input = torch.cat((memory, self.context), -1)
+        # self.query and self.attention_rnn_cell_state : B x D_attn_rnn
         self.query, self.attention_rnn_cell_state = self.attention_rnn(
             query_input, (self.query, self.attention_rnn_cell_state))
         self.query = F.dropout(self.query, self.p_attention_dropout,
@@ -216,29 +223,28 @@ class Decoder(nn.Module):
         self.attention_rnn_cell_state = F.dropout(
             self.attention_rnn_cell_state, self.p_attention_dropout,
             self.training)
+        # B x D_en
         self.context = self.attention(self.query, self.inputs,
                                       self.processed_inputs, self.mask)
-        memory = torch.cat((self.query, self.context), -1)
+        # B x (D_en + D_attn_rnn)
+        decoder_rnn_input = torch.cat((self.query, self.context), -1)
+        # self.decoder_hidden and self.decoder_cell: B x D_decoder_rnn
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
-            memory, (self.decoder_hidden, self.decoder_cell))
+            decoder_rnn_input, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(self.decoder_hidden,
                                         self.p_decoder_dropout, self.training)
         self.decoder_cell = F.dropout(self.decoder_cell,
                                       self.p_decoder_dropout, self.training)
+        # B x (D_decoder_rnn + D_en)
         decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
                                            dim=1)
+        # B x (self.r * self.memory_dim)
         decoder_output = self.linear_projection(decoder_hidden_context)
+        # B x (D_decoder_rnn + (self.r * self.memory_dim))
         stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
         if self.separate_stopnet:
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
         # select outputs for the reduction rate self.r
         decoder_output = decoder_output[:, :self.r * self.memory_dim]
         return decoder_output, self.attention.attention_weights, stop_token
@@ -257,8 +263,8 @@ class Decoder(nn.Module):
         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
             memory = memories[len(outputs)]
-            mel_output, attention_weights, stop_token = self.decode(memory)
-            outputs += [mel_output.squeeze(1)]
+            decoder_output, attention_weights, stop_token = self.decode(memory)
+            outputs += [decoder_output.squeeze(1)]
             stop_tokens += [stop_token.squeeze(1)]
             alignments += [attention_weights]
 
@@ -278,9 +284,9 @@ class Decoder(nn.Module):
             memory = self.prenet(memory)
             if speaker_embeddings is not None:
                 memory = torch.cat([memory, speaker_embeddings], dim=-1)
-            mel_output, alignment, stop_token = self.decode(memory)
+            decoder_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
-            outputs += [mel_output.squeeze(1)]
+            outputs += [decoder_output.squeeze(1)]
             stop_tokens += [stop_token]
             alignments += [alignment]
 
@@ -290,7 +296,7 @@ class Decoder(nn.Module):
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
 
-            memory = self._update_memory(mel_output)
+            memory = self._update_memory(decoder_output)
             t += 1
 
         outputs, stop_tokens, alignments = self._parse_outputs(
@@ -314,9 +320,9 @@ class Decoder(nn.Module):
         stop_flags = [True, False, False]
         while True:
             memory = self.prenet(self.memory_truncated)
-            mel_output, alignment, stop_token = self.decode(memory)
+            decoder_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
-            outputs += [mel_output.squeeze(1)]
+            outputs += [decoder_output.squeeze(1)]
             stop_tokens += [stop_token]
             alignments += [alignment]
 
@@ -326,7 +332,7 @@ class Decoder(nn.Module):
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
 
-            self.memory_truncated = mel_output
+            self.memory_truncated = decoder_output
             t += 1
 
         outputs, stop_tokens, alignments = self._parse_outputs(
@@ -343,7 +349,7 @@ class Decoder(nn.Module):
             self._init_states(inputs, mask=None)
 
         memory = self.prenet(memory)
-        mel_output, stop_token, alignment = self.decode(memory)
+        decoder_output, stop_token, alignment = self.decode(memory)
         stop_token = torch.sigmoid(stop_token.data)
-        memory = mel_output
-        return mel_output, stop_token, alignment
+        memory = decoder_output
+        return decoder_output, stop_token, alignment
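As a quick cross-check of the shape comments added above, the concatenations in decode() can be reproduced with dummy tensors. This is a standalone sketch; the batch size and the D_en / D_attn_rnn / D_decoder_rnn / r / memory_dim values below are arbitrary examples, not values taken from the repo's config.

import torch

# Hypothetical sizes, chosen only to illustrate the commented shapes.
B, D_en, D_attn_rnn, D_decoder_rnn = 4, 512, 1024, 1024
r, memory_dim = 2, 80

memory = torch.randn(B, r * memory_dim)          # B x (r * memory_dim)
context = torch.randn(B, D_en)                   # B x D_en
query = torch.randn(B, D_attn_rnn)               # B x D_attn_rnn
decoder_hidden = torch.randn(B, D_decoder_rnn)   # B x D_decoder_rnn

# query_input: B x (D_en + r * memory_dim)
query_input = torch.cat((memory, context), -1)
assert query_input.shape == (B, D_en + r * memory_dim)

# decoder_rnn_input: B x (D_en + D_attn_rnn)
decoder_rnn_input = torch.cat((query, context), -1)
assert decoder_rnn_input.shape == (B, D_en + D_attn_rnn)

# decoder_hidden_context: B x (D_decoder_rnn + D_en)
decoder_hidden_context = torch.cat((decoder_hidden, context), dim=1)
assert decoder_hidden_context.shape == (B, D_decoder_rnn + D_en)

# stopnet_input: B x (D_decoder_rnn + r * memory_dim)
decoder_output = torch.randn(B, r * memory_dim)  # stand-in for the linear_projection output
stopnet_input = torch.cat((decoder_hidden, decoder_output), dim=1)
assert stopnet_input.shape == (B, D_decoder_rnn + r * memory_dim)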
train.py
@@ -198,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
 
         loss.backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
-        grad_norm, _ = check_update(model.decoder, c.grad_clip)
+        grad_norm, grad_flag = check_update(model, c.grad_clip, ignore_stopnet=True)
         optimizer.step()
 
         # compute alignment score
@@ -302,16 +302,15 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     end_time = time.time()
 
     # print epoch stats
-    print(" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
+    print(" | > EPOCH END -- GlobalStep:{} "
           "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
-          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
+          "AvgStopLoss:{:.5f} AvgAlignScore:{:3f} EpochTime:{:.2f} "
          "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(
              global_step, keep_avg['avg_postnet_loss'],
              keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'],
+             keep_avg['avg_align_score'], epoch_time,
              keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
          flush=True)
 
-    # Plot Epoch Stats
     if args.rank == 0:
+        # Plot Training Epoch Stats
@@ -568,7 +567,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                            ] else nn.MSELoss()
     criterion_st = nn.BCEWithLogitsLoss(
-        pos_weight=torch.tensor(20.0)) if c.stopnet else None
+        pos_weight=torch.tensor(10)) if c.stopnet else None
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
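For context on the pos_weight change above: BCEWithLogitsLoss scales the loss on positive (stop) frames by pos_weight, which compensates for stop frames being far rarer than non-stop frames in the stop-token targets. A minimal sketch with made-up logits and targets; only the weight value 10 comes from the diff, everything else is illustrative.

import torch
import torch.nn as nn

# Fake stop-token logits and targets for 2 sequences x 5 decoder steps.
# Target 1 marks the "stop" frames, which are rare in real data.
logits = torch.tensor([[-2.0, -1.5, -1.0, 0.5, 2.0],
                       [-2.5, -2.0, -0.5, 1.0, 1.5]])
targets = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0],
                        [0.0, 0.0, 0.0, 0.0, 1.0]])

plain = nn.BCEWithLogitsLoss()
weighted = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(10.0))

# The weighted criterion penalises errors on stop frames ~10x more heavily.
print(plain(logits, targets).item(), weighted(logits, targets).item())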
@@ -150,10 +150,13 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
     return best_loss
 
 
-def check_update(model, grad_clip):
+def check_update(model, grad_clip, ignore_stopnet=False):
     r'''Check model gradient against unexpected jumps and failures'''
     skip_flag = False
-    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+    if ignore_stopnet:
+        grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
+    else:
+        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
     if np.isinf(grad_norm):
         print(" | > Gradient is INF !!")
         skip_flag = True
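The ignore_stopnet branch above clips every parameter whose name does not contain 'stopnet' as a single group, so the stopnet, which has its own optimizer in train.py, is left untouched by the main clip. A minimal sketch of that behaviour with a hypothetical toy model; the module names, sizes, and clip threshold are assumptions for illustration only.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Stand-in for the TTS model: a main branch plus a small stopnet head."""
    def __init__(self):
        super().__init__()
        self.decoder = nn.Linear(8, 8)
        self.stopnet = nn.Linear(8, 1)

    def forward(self, x):
        h = self.decoder(x)
        return h, self.stopnet(h)

model = ToyModel()
out, stop = model(torch.randn(4, 8))
(out.sum() + stop.sum()).backward()

grad_clip = 1.0
# Clip everything except stopnet parameters, mirroring ignore_stopnet=True.
main_params = [p for n, p in model.named_parameters() if 'stopnet' not in n]
grad_norm = torch.nn.utils.clip_grad_norm_(main_params, grad_clip)

# The stopnet gradients are untouched here; they can be clipped separately.
stop_norm = torch.nn.utils.clip_grad_norm_(model.stopnet.parameters(), grad_clip)
print(float(grad_norm), float(stop_norm))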
@@ -11,22 +11,27 @@ class Logger(object):
     def tb_model_weights(self, model, step):
         layer_num = 1
         for name, param in model.named_parameters():
-            self.writer.add_scalar(
-                "layer{}-{}/max".format(layer_num, name),
-                param.max(), step)
-            self.writer.add_scalar(
-                "layer{}-{}/min".format(layer_num, name),
-                param.min(), step)
-            self.writer.add_scalar(
-                "layer{}-{}/mean".format(layer_num, name),
-                param.mean(), step)
-            self.writer.add_scalar(
-                "layer{}-{}/std".format(layer_num, name),
-                param.std(), step)
-            self.writer.add_histogram(
-                "layer{}-{}/param".format(layer_num, name), param, step)
-            self.writer.add_histogram(
-                "layer{}-{}/grad".format(layer_num, name), param.grad, step)
+            if param.numel() == 1:
+                self.writer.add_scalar(
+                    "layer{}-{}/value".format(layer_num, name),
+                    param.max(), step)
+            else:
+                self.writer.add_scalar(
+                    "layer{}-{}/max".format(layer_num, name),
+                    param.max(), step)
+                self.writer.add_scalar(
+                    "layer{}-{}/min".format(layer_num, name),
+                    param.min(), step)
+                self.writer.add_scalar(
+                    "layer{}-{}/mean".format(layer_num, name),
+                    param.mean(), step)
+                self.writer.add_scalar(
+                    "layer{}-{}/std".format(layer_num, name),
+                    param.std(), step)
+                self.writer.add_histogram(
+                    "layer{}-{}/param".format(layer_num, name), param, step)
+                self.writer.add_histogram(
+                    "layer{}-{}/grad".format(layer_num, name), param.grad, step)
             layer_num += 1
 
     def dict_to_tb_scalar(self, scope_name, stats, step):
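The param.numel() == 1 branch above exists because scalar parameters make poor candidates for min/max/std summaries; in particular, the standard deviation of a single element is NaN. A quick standalone illustration, with arbitrary example values.

import torch

scalar_param = torch.nn.Parameter(torch.tensor(1.0))
weight_param = torch.nn.Parameter(torch.randn(4, 4))

# For a single-element parameter, std is NaN and min == max == mean,
# so per-layer min/max/std summaries would degrade into noise or NaNs.
print(scalar_param.numel(), scalar_param.std())   # 1, tensor(nan, ...)
print(weight_param.numel(), weight_param.std())   # 16, a finite value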