mirror of https://github.com/coqui-ai/TTS.git
bugfix in GST
This commit is contained in:
parent
8a1c113df6
commit
1d73566e4e
|
@ -508,7 +508,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
prev_out_path = os.path.dirname(args.restore_path)
|
prev_out_path = os.path.dirname(args.restore_path)
|
||||||
speaker_mapping = load_speaker_mapping(prev_out_path)
|
speaker_mapping = load_speaker_mapping(prev_out_path)
|
||||||
if not speaker_mapping:
|
if not speaker_mapping:
|
||||||
print("WARNING: speakers.json speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
|
print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
|
||||||
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
|
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
|
||||||
if not speaker_mapping:
|
if not speaker_mapping:
|
||||||
raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
|
raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
|
||||||
|
@ -559,8 +559,6 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
# setup criterion
|
# setup criterion
|
||||||
criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
|
criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
|
||||||
for name, _ in model.named_parameters():
|
|
||||||
print(name)
|
|
||||||
|
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
||||||
|
@ -575,6 +573,8 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
print(" > Partial model initialization.")
|
print(" > Partial model initialization.")
|
||||||
model_dict = model.state_dict()
|
model_dict = model.state_dict()
|
||||||
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
||||||
|
# torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
|
||||||
|
# print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
|
||||||
model.load_state_dict(model_dict)
|
model.load_state_dict(model_dict)
|
||||||
del model_dict
|
del model_dict
|
||||||
|
|
||||||
|
|
|
@ -96,7 +96,7 @@ class StyleTokenLayer(nn.Module):
|
||||||
self.key_dim = embedding_dim // num_heads
|
self.key_dim = embedding_dim // num_heads
|
||||||
self.style_tokens = nn.Parameter(
|
self.style_tokens = nn.Parameter(
|
||||||
torch.FloatTensor(num_style_tokens, self.key_dim))
|
torch.FloatTensor(num_style_tokens, self.key_dim))
|
||||||
nn.init.orthogonal_(self.style_tokens)
|
nn.init.normal_(self.style_tokens, mean=0, std=0.5)
|
||||||
self.attention = MultiHeadAttention(
|
self.attention = MultiHeadAttention(
|
||||||
query_dim=self.query_dim,
|
query_dim=self.query_dim,
|
||||||
key_dim=self.key_dim,
|
key_dim=self.key_dim,
|
||||||
|
|
Loading…
Reference in New Issue