mirror of https://github.com/coqui-ai/TTS.git
Add decoder shape comments for Tacotron2 and decouple gradient clipping of the stopnet from the rest of the network. Also rename some variables and fix a bug in alignment score logging.
This commit is contained in:
parent
2966e3f2d1
commit
015f7780f4
|
@ -208,7 +208,14 @@ class Decoder(nn.Module):
|
||||||
return memory[:, :, self.memory_dim * (self.r - 1):]
|
return memory[:, :, self.memory_dim * (self.r - 1):]
|
||||||
|
|
||||||
def decode(self, memory):
|
def decode(self, memory):
|
||||||
|
'''
|
||||||
|
shapes:
|
||||||
|
- memory: B x r * self.memory_dim
|
||||||
|
'''
|
||||||
|
# self.context: B x D_en
|
||||||
|
# query_input: B x D_en + (r * self.memory_dim)
|
||||||
query_input = torch.cat((memory, self.context), -1)
|
query_input = torch.cat((memory, self.context), -1)
|
||||||
|
# self.query and self.attention_rnn_cell_state : B x D_attn_rnn
|
||||||
self.query, self.attention_rnn_cell_state = self.attention_rnn(
|
self.query, self.attention_rnn_cell_state = self.attention_rnn(
|
||||||
query_input, (self.query, self.attention_rnn_cell_state))
|
query_input, (self.query, self.attention_rnn_cell_state))
|
||||||
self.query = F.dropout(self.query, self.p_attention_dropout,
|
self.query = F.dropout(self.query, self.p_attention_dropout,
|
||||||
|
@ -216,29 +223,28 @@ class Decoder(nn.Module):
|
||||||
self.attention_rnn_cell_state = F.dropout(
|
self.attention_rnn_cell_state = F.dropout(
|
||||||
self.attention_rnn_cell_state, self.p_attention_dropout,
|
self.attention_rnn_cell_state, self.p_attention_dropout,
|
||||||
self.training)
|
self.training)
|
||||||
|
# B x D_en
|
||||||
self.context = self.attention(self.query, self.inputs,
|
self.context = self.attention(self.query, self.inputs,
|
||||||
self.processed_inputs, self.mask)
|
self.processed_inputs, self.mask)
|
||||||
|
# B x (D_en + D_attn_rnn)
|
||||||
memory = torch.cat((self.query, self.context), -1)
|
decoder_rnn_input = torch.cat((self.query, self.context), -1)
|
||||||
|
# self.decoder_hidden and self.decoder_cell: B x D_decoder_rnn
|
||||||
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
|
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
|
||||||
memory, (self.decoder_hidden, self.decoder_cell))
|
decoder_rnn_input, (self.decoder_hidden, self.decoder_cell))
|
||||||
self.decoder_hidden = F.dropout(self.decoder_hidden,
|
self.decoder_hidden = F.dropout(self.decoder_hidden,
|
||||||
self.p_decoder_dropout, self.training)
|
self.p_decoder_dropout, self.training)
|
||||||
self.decoder_cell = F.dropout(self.decoder_cell,
|
# B x (D_decoder_rnn + D_en)
|
||||||
self.p_decoder_dropout, self.training)
|
|
||||||
|
|
||||||
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
|
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
|
||||||
dim=1)
|
dim=1)
|
||||||
|
# B x (self.r * self.memory_dim)
|
||||||
decoder_output = self.linear_projection(decoder_hidden_context)
|
decoder_output = self.linear_projection(decoder_hidden_context)
|
||||||
|
# B x (D_decoder_rnn + (self.r * self.memory_dim))
|
||||||
stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
|
stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
|
||||||
|
|
||||||
if self.separate_stopnet:
|
if self.separate_stopnet:
|
||||||
stop_token = self.stopnet(stopnet_input.detach())
|
stop_token = self.stopnet(stopnet_input.detach())
|
||||||
else:
|
else:
|
||||||
stop_token = self.stopnet(stopnet_input)
|
stop_token = self.stopnet(stopnet_input)
|
||||||
|
# select outputs for the reduction rate self.r
|
||||||
decoder_output = decoder_output[:, :self.r * self.memory_dim]
|
decoder_output = decoder_output[:, :self.r * self.memory_dim]
|
||||||
return decoder_output, self.attention.attention_weights, stop_token
|
return decoder_output, self.attention.attention_weights, stop_token
|
||||||
|
|
||||||
|
@ -257,8 +263,8 @@ class Decoder(nn.Module):
|
||||||
outputs, stop_tokens, alignments = [], [], []
|
outputs, stop_tokens, alignments = [], [], []
|
||||||
while len(outputs) < memories.size(0) - 1:
|
while len(outputs) < memories.size(0) - 1:
|
||||||
memory = memories[len(outputs)]
|
memory = memories[len(outputs)]
|
||||||
mel_output, attention_weights, stop_token = self.decode(memory)
|
decoder_output, attention_weights, stop_token = self.decode(memory)
|
||||||
outputs += [mel_output.squeeze(1)]
|
outputs += [decoder_output.squeeze(1)]
|
||||||
stop_tokens += [stop_token.squeeze(1)]
|
stop_tokens += [stop_token.squeeze(1)]
|
||||||
alignments += [attention_weights]
|
alignments += [attention_weights]
|
||||||
|
|
||||||
|
@ -278,9 +284,9 @@ class Decoder(nn.Module):
|
||||||
memory = self.prenet(memory)
|
memory = self.prenet(memory)
|
||||||
if speaker_embeddings is not None:
|
if speaker_embeddings is not None:
|
||||||
memory = torch.cat([memory, speaker_embeddings], dim=-1)
|
memory = torch.cat([memory, speaker_embeddings], dim=-1)
|
||||||
mel_output, alignment, stop_token = self.decode(memory)
|
decoder_output, alignment, stop_token = self.decode(memory)
|
||||||
stop_token = torch.sigmoid(stop_token.data)
|
stop_token = torch.sigmoid(stop_token.data)
|
||||||
outputs += [mel_output.squeeze(1)]
|
outputs += [decoder_output.squeeze(1)]
|
||||||
stop_tokens += [stop_token]
|
stop_tokens += [stop_token]
|
||||||
alignments += [alignment]
|
alignments += [alignment]
|
||||||
|
|
||||||
|
@ -290,7 +296,7 @@ class Decoder(nn.Module):
|
||||||
print(" | > Decoder stopped with 'max_decoder_steps")
|
print(" | > Decoder stopped with 'max_decoder_steps")
|
||||||
break
|
break
|
||||||
|
|
||||||
memory = self._update_memory(mel_output)
|
memory = self._update_memory(decoder_output)
|
||||||
t += 1
|
t += 1
|
||||||
|
|
||||||
outputs, stop_tokens, alignments = self._parse_outputs(
|
outputs, stop_tokens, alignments = self._parse_outputs(
|
||||||
|
@ -314,9 +320,9 @@ class Decoder(nn.Module):
|
||||||
stop_flags = [True, False, False]
|
stop_flags = [True, False, False]
|
||||||
while True:
|
while True:
|
||||||
memory = self.prenet(self.memory_truncated)
|
memory = self.prenet(self.memory_truncated)
|
||||||
mel_output, alignment, stop_token = self.decode(memory)
|
decoder_output, alignment, stop_token = self.decode(memory)
|
||||||
stop_token = torch.sigmoid(stop_token.data)
|
stop_token = torch.sigmoid(stop_token.data)
|
||||||
outputs += [mel_output.squeeze(1)]
|
outputs += [decoder_output.squeeze(1)]
|
||||||
stop_tokens += [stop_token]
|
stop_tokens += [stop_token]
|
||||||
alignments += [alignment]
|
alignments += [alignment]
|
||||||
|
|
||||||
|
@ -326,7 +332,7 @@ class Decoder(nn.Module):
|
||||||
print(" | > Decoder stopped with 'max_decoder_steps")
|
print(" | > Decoder stopped with 'max_decoder_steps")
|
||||||
break
|
break
|
||||||
|
|
||||||
self.memory_truncated = mel_output
|
self.memory_truncated = decoder_output
|
||||||
t += 1
|
t += 1
|
||||||
|
|
||||||
outputs, stop_tokens, alignments = self._parse_outputs(
|
outputs, stop_tokens, alignments = self._parse_outputs(
|
||||||
|
@ -343,7 +349,7 @@ class Decoder(nn.Module):
|
||||||
self._init_states(inputs, mask=None)
|
self._init_states(inputs, mask=None)
|
||||||
|
|
||||||
memory = self.prenet(memory)
|
memory = self.prenet(memory)
|
||||||
mel_output, stop_token, alignment = self.decode(memory)
|
decoder_output, stop_token, alignment = self.decode(memory)
|
||||||
stop_token = torch.sigmoid(stop_token.data)
|
stop_token = torch.sigmoid(stop_token.data)
|
||||||
memory = mel_output
|
memory = decoder_output
|
||||||
return mel_output, stop_token, alignment
|
return decoder_output, stop_token, alignment
|
||||||
|
|
9
train.py
9
train.py
|
@ -198,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer, current_lr = adam_weight_decay(optimizer)
|
optimizer, current_lr = adam_weight_decay(optimizer)
|
||||||
grad_norm, _ = check_update(model.decoder, c.grad_clip)
|
grad_norm, grad_flag = check_update(model, c.grad_clip, ignore_stopnet=True)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# compute alignment score
|
# compute alignment score
|
||||||
|
@ -302,16 +302,15 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
# print epoch stats
|
# print epoch stats
|
||||||
print(" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
print(" | > EPOCH END -- GlobalStep:{} "
|
||||||
"AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
|
"AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
|
||||||
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
"AvgStopLoss:{:.5f} AvgAlignScore:{:3f} EpochTime:{:.2f} "
|
||||||
"AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(
|
"AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(
|
||||||
global_step, keep_avg['avg_postnet_loss'],
|
global_step, keep_avg['avg_postnet_loss'],
|
||||||
keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'],
|
keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'],
|
||||||
keep_avg['avg_align_score'], epoch_time,
|
keep_avg['avg_align_score'], epoch_time,
|
||||||
keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
|
keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
# Plot Epoch Stats
|
# Plot Epoch Stats
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
# Plot Training Epoch Stats
|
# Plot Training Epoch Stats
|
||||||
|
@ -568,7 +567,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
|
criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
|
||||||
] else nn.MSELoss()
|
] else nn.MSELoss()
|
||||||
criterion_st = nn.BCEWithLogitsLoss(
|
criterion_st = nn.BCEWithLogitsLoss(
|
||||||
pos_weight=torch.tensor(20.0)) if c.stopnet else None
|
pos_weight=torch.tensor(10)) if c.stopnet else None
|
||||||
|
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
checkpoint = torch.load(args.restore_path)
|
checkpoint = torch.load(args.restore_path)
|
||||||
|
|
|
@ -150,9 +150,12 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
|
||||||
return best_loss
|
return best_loss
|
||||||
|
|
||||||
|
|
||||||
def check_update(model, grad_clip):
|
def check_update(model, grad_clip, ignore_stopnet=False):
|
||||||
r'''Check model gradient against unexpected jumps and failures'''
|
r'''Check model gradient against unexpected jumps and failures'''
|
||||||
skip_flag = False
|
skip_flag = False
|
||||||
|
if ignore_stopnet:
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
|
||||||
|
else:
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
||||||
if np.isinf(grad_norm):
|
if np.isinf(grad_norm):
|
||||||
print(" | > Gradient is INF !!")
|
print(" | > Gradient is INF !!")
|
||||||
|
|
|
@ -11,6 +11,11 @@ class Logger(object):
|
||||||
def tb_model_weights(self, model, step):
|
def tb_model_weights(self, model, step):
|
||||||
layer_num = 1
|
layer_num = 1
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
|
if param.numel() == 1:
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"layer{}-{}/value".format(layer_num, name),
|
||||||
|
param.max(), step)
|
||||||
|
else:
|
||||||
self.writer.add_scalar(
|
self.writer.add_scalar(
|
||||||
"layer{}-{}/max".format(layer_num, name),
|
"layer{}-{}/max".format(layer_num, name),
|
||||||
param.max(), step)
|
param.max(), step)
|
||||||
|
|
Loading…
Reference in New Issue