From eaa130e813cf2f5c05e5664e331cbe4f723c64b4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 5 May 2021 02:30:25 +0200
Subject: [PATCH] fix tacotron for coqpit

---
 TTS/bin/train_tacotron.py      | 33 ++++++++++++----------------
 TTS/tts/utils/generic_utils.py |  2 +-
 TTS/tts/utils/synthesis.py     | 39 ++++++++++++++++++++--------------
 3 files changed, 38 insertions(+), 36 deletions(-)

diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py
index f5d74099..edf89858 100755
--- a/TTS/bin/train_tacotron.py
+++ b/TTS/bin/train_tacotron.py
@@ -90,7 +90,7 @@ def format_data(data):
     text_input = data[0]
     text_lengths = data[1]
     speaker_names = data[2]
-    linear_input = data[3] if config.model in ["Tacotron"] else None
+    linear_input = data[3] if config.model.lower() in ["tacotron"] else None
     mel_input = data[4]
     mel_lengths = data[5]
     stop_targets = data[6]
@@ -369,9 +369,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,

             # Sample audio
             if config.model in ["Tacotron", "TacotronGST"]:
-                train_audio = ap.inv_spectrogram(const_speconfig.T)
+                train_audio = ap.inv_spectrogram(const_spec.T)
             else:
-                train_audio = ap.inv_melspectrogram(const_speconfig.T)
+                train_audio = ap.inv_melspectrogram(const_spec.T)
             tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"])

         end_time = time.time()
@@ -507,10 +507,10 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
         }

         # Sample audio
-        if config.model in ["Tacotron", "TacotronGST"]:
-            eval_audio = ap.inv_spectrogram(const_speconfig.T)
+        if config.model.lower() in ["tacotron"]:
+            eval_audio = ap.inv_spectrogram(const_spec.T)
         else:
-            eval_audio = ap.inv_melspectrogram(const_speconfig.T)
+            eval_audio = ap.inv_melspectrogram(const_spec.T)
         tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"])

         # Plot Validation Stats
@@ -522,7 +522,10 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
         tb_logger.tb_eval_figures(global_step, eval_figures)

     if args.rank == 0 and epoch > config.test_delay_epochs:
-        if config.test_sentences_file is None:
+        if config.test_sentences_file:
+            with open(config.test_sentences_file, "r") as f:
+                test_sentences = [s.strip() for s in f.readlines()]
+        else:
             test_sentences = [
                 "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                 "Be a voice, not an echo.",
@@ -530,9 +533,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
                 "This cake is great. It's so delicious and moist.",
                 "Prior to November 22, 1963.",
             ]
-        else:
-            with open(config.test_sentences_file, "r") as f:
-                test_sentences = [s.strip() for s in f.readlines()]

         # test sentences
         test_audios = {}
@@ -544,14 +544,13 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
             if config.use_external_speaker_embedding_file and config.use_speaker_embedding
             else None
         )
-        style_wav = config.get("gst_style_input")
-        if style_wav is None and config.use_gst:
+        style_wav = config.gst_style_input
+        if style_wav is None and config.gst is not None:
             # inicialize GST with zero dict.
             style_wav = {}
             print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
             for i in range(config.gst["gst_num_style_tokens"]):
                 style_wav[str(i)] = 0
-        style_wav = config.get("gst_style_input")
         for idx, test_sentence in enumerate(test_sentences):
             try:
                 wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
@@ -639,14 +638,10 @@ def main(args): # pylint: disable=redefined-outer-name
             if "scaler" in checkpoint and config.mixed_precision:
                 print(" > Restoring AMP Scaler...")
                 scaler.load_state_dict(checkpoint["scaler"])
-            if config.reinit_layers:
-                raise RuntimeError
         except (KeyError, RuntimeError):
             print(" > Partial model initialization...")
             model_dict = model.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint["model"], c)
-            # torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
-            # print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
             model.load_state_dict(model_dict)
             del model_dict

@@ -743,12 +738,12 @@ if __name__ == "__main__":
     try:
         main(args)
     except KeyboardInterrupt:
-        # remove_experiment_folder(OUT_PATH)
+        remove_experiment_folder(OUT_PATH)
         try:
             sys.exit(0)
         except SystemExit:
             os._exit(0)  # pylint: disable=protected-access
     except Exception:  # pylint: disable=broad-except
-        # remove_experiment_folder(OUT_PATH)
+        remove_experiment_folder(OUT_PATH)
         traceback.print_exc()
         sys.exit(1)
diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py
index e6934bc9..8667b2ec 100644
--- a/TTS/tts/utils/generic_utils.py
+++ b/TTS/tts/utils/generic_utils.py
@@ -22,7 +22,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             r=c.r,
             postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
             decoder_output_dim=c.audio["num_mels"],
-            gst=c.use_gst,
+            gst=c.gst,
             memory_size=c.memory_size,
             attn_type=c.attention_type,
             attn_win=c.windowing,
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index f2cfbd43..281ef55d 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -65,28 +65,31 @@ def compute_style_mel(style_wav, ap, cuda=False):


 def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
-    speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
     if "tacotron" in CONFIG.model.lower():
-        if not CONFIG.use_gst:
-            style_mel = None
-
-        if truncated:
-            decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-            )
-        else:
+        if CONFIG.gst:
             decoder_output, postnet_output, alignments, stop_tokens = model.inference(
                 inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
             )
+        else:
+            if truncated:
+                decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
+                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
+                )
+            else:
+                decoder_output, postnet_output, alignments, stop_tokens = model.inference(
+                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
+                )
     elif "glow" in CONFIG.model.lower():
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
             postnet_output, _, _, _, alignments, _, _ = model.module.inference(
-                inputs, inputs_lengths, g=speaker_embedding_g
+                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
             )
         else:
-            postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
+            postnet_output, _, _, _, alignments, _, _ = model.inference(
+                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+            )
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
@@ -95,9 +98,13 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
-            postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_embedding_g)
+            postnet_output, alignments = model.module.inference(
+                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+            )
         else:
-            postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
+            postnet_output, alignments = model.inference(
+                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+            )
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
@@ -108,7 +115,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel


 def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
-    if CONFIG.use_gst and style_mel is not None:
+    if CONFIG.gst and style_mel is not None:
         raise NotImplementedError(" [!] GST inference not implemented for TF")
     if truncated:
         raise NotImplementedError(" [!] Truncated inference not implemented for TF")
@@ -120,7 +127,7 @@ def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=No


 def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
-    if CONFIG.use_gst and style_mel is not None:
+    if CONFIG.gst and style_mel is not None:
         raise NotImplementedError(" [!] GST inference not implemented for TfLite")
     if truncated:
         raise NotImplementedError(" [!] Truncated inference not implemented for TfLite")
@@ -249,7 +256,7 @@ def synthesis(
     """
     # GST processing
     style_mel = None
-    if "use_gst" in CONFIG.keys() and CONFIG.use_gst and style_wav is not None:
+    if CONFIG.has('gst') and CONFIG.gst and style_wav is not None:
         if isinstance(style_wav, dict):
             style_mel = style_wav
         else:
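
A note on the pattern behind these edits, with a minimal sketch: the recurring change is the move from dict-style config lookups (`config.get("gst_style_input")`, `"use_gst" in CONFIG.keys()`) to coqpit-style attribute access, where `config.gst` is a GST sub-config when the feature is enabled and `None` otherwise. The dataclasses below are hypothetical stand-ins for the project's coqpit config classes, not its actual API; only the field names and the fallback logic mirror the patch.

from dataclasses import dataclass
from typing import Optional


@dataclass
class GSTConfig:
    # hypothetical stand-in for the coqpit GST sub-config
    gst_num_style_tokens: int = 10

    def __getitem__(self, key):
        # keep dict-style indexing, e.g. config.gst["gst_num_style_tokens"]
        return getattr(self, key)


@dataclass
class TacotronConfig:
    # hypothetical stand-in for the coqpit Tacotron config
    model: str = "Tacotron2"
    gst: Optional[GSTConfig] = None  # sub-config when GST is enabled, else None
    gst_style_input: Optional[dict] = None


config = TacotronConfig(gst=GSTConfig())

# old, removed by the patch:  style_wav = config.get("gst_style_input")
# new, attribute access:
style_wav = config.gst_style_input
if style_wav is None and config.gst is not None:
    # same fallback as evaluate() in train_tacotron.py: one zero weight per style token
    style_wav = {str(i): 0 for i in range(config.gst["gst_num_style_tokens"])}

print(style_wav)  # {'0': 0, '1': 0, ..., '9': 0}

The same shift explains `gst=c.use_gst` becoming `gst=c.gst` in `setup_model()` and the `CONFIG.has('gst') and CONFIG.gst` guard in `synthesis()`: the boolean flag is gone, so the code now checks whether the sub-config exists.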