bugfix in GST

This commit is contained in:
Edresson 2020-07-29 18:22:13 -03:00 committed by erogol
parent 8a1c113df6
commit 1d73566e4e
2 changed files with 4 additions and 4 deletions

View File

@ -508,7 +508,7 @@ def main(args): # pylint: disable=redefined-outer-name
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
if not speaker_mapping:
print("WARNING: speakers.json speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
if not speaker_mapping:
raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
@ -559,8 +559,6 @@ def main(args): # pylint: disable=redefined-outer-name
# setup criterion
criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
for name, _ in model.named_parameters():
print(name)
if args.restore_path:
checkpoint = torch.load(args.restore_path, map_location='cpu')
@ -575,6 +573,8 @@ def main(args): # pylint: disable=redefined-outer-name
print(" > Partial model initialization.")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
# torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
# print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
model.load_state_dict(model_dict)
del model_dict

View File

@ -96,7 +96,7 @@ class StyleTokenLayer(nn.Module):
self.key_dim = embedding_dim // num_heads
self.style_tokens = nn.Parameter(
torch.FloatTensor(num_style_tokens, self.key_dim))
nn.init.orthogonal_(self.style_tokens)
nn.init.normal_(self.style_tokens, mean=0, std=0.5)
self.attention = MultiHeadAttention(
query_dim=self.query_dim,
key_dim=self.key_dim,