From 1d73566e4e2fe9a4f5e799b80f8506a0f92768e3 Mon Sep 17 00:00:00 2001 From: Edresson Date: Wed, 29 Jul 2020 18:22:13 -0300 Subject: [PATCH] bugfix in GST --- mozilla_voice_tts/bin/train_tts.py | 6 +++--- mozilla_voice_tts/tts/layers/gst_layers.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mozilla_voice_tts/bin/train_tts.py b/mozilla_voice_tts/bin/train_tts.py index 0642f290..daa517b9 100644 --- a/mozilla_voice_tts/bin/train_tts.py +++ b/mozilla_voice_tts/bin/train_tts.py @@ -508,7 +508,7 @@ def main(args): # pylint: disable=redefined-outer-name prev_out_path = os.path.dirname(args.restore_path) speaker_mapping = load_speaker_mapping(prev_out_path) if not speaker_mapping: - print("WARNING: speakers.json speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file") + print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file") speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file) if not speaker_mapping: raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file") @@ -559,8 +559,6 @@ def main(args): # pylint: disable=redefined-outer-name # setup criterion criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4) - for name, _ in model.named_parameters(): - print(name) if args.restore_path: checkpoint = torch.load(args.restore_path, map_location='cpu') @@ -575,6 +573,8 @@ def main(args): # pylint: disable=redefined-outer-name print(" > Partial model initialization.") model_dict = model.state_dict() model_dict = set_init_dict(model_dict, checkpoint['model'], c) + # torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt')) + # print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt')) model.load_state_dict(model_dict) del model_dict diff --git a/mozilla_voice_tts/tts/layers/gst_layers.py b/mozilla_voice_tts/tts/layers/gst_layers.py index 01f90697..a49b14a2 100644 --- a/mozilla_voice_tts/tts/layers/gst_layers.py +++ b/mozilla_voice_tts/tts/layers/gst_layers.py @@ -96,7 +96,7 @@ class StyleTokenLayer(nn.Module): self.key_dim = embedding_dim // num_heads self.style_tokens = nn.Parameter( torch.FloatTensor(num_style_tokens, self.key_dim)) - nn.init.orthogonal_(self.style_tokens) + nn.init.normal_(self.style_tokens, mean=0, std=0.5) self.attention = MultiHeadAttention( query_dim=self.query_dim, key_dim=self.key_dim,