fix tacotron for coqpit

Eren Gölge 2021-05-05 02:30:25 +02:00
parent 65d7ad4250
commit eaa130e813
3 changed files with 38 additions and 36 deletions

View File

@@ -90,7 +90,7 @@ def format_data(data):
     text_input = data[0]
     text_lengths = data[1]
     speaker_names = data[2]
-    linear_input = data[3] if config.model in ["Tacotron"] else None
+    linear_input = data[3] if config.model.lower() in ["tacotron"] else None
     mel_input = data[4]
     mel_lengths = data[5]
     stop_targets = data[6]
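
Note: besides the coqpit migration, this makes the model-name check case-insensitive, so "Tacotron", "tacotron" or "TACOTRON" in the config all select linear-spectrogram targets (the separate "TacotronGST" name disappears elsewhere in this commit because GST is now a config field). A minimal sketch of the idea, using a plain dataclass as a stand-in for the coqpit config:

    from dataclasses import dataclass

    @dataclass
    class DummyConfig:
        # stand-in for the real coqpit config; only the field used here is modelled
        model: str = "Tacotron"

    config = DummyConfig()
    # case-insensitive membership test, as in the new line above
    needs_linear_target = config.model.lower() in ["tacotron"]
    print(needs_linear_target)  # True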
@@ -369,9 +369,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
             # Sample audio
             if config.model in ["Tacotron", "TacotronGST"]:
-                train_audio = ap.inv_spectrogram(const_speconfig.T)
+                train_audio = ap.inv_spectrogram(const_spec.T)
             else:
-                train_audio = ap.inv_melspectrogram(const_speconfig.T)
+                train_audio = ap.inv_melspectrogram(const_spec.T)
             tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"])
         end_time = time.time()
@@ -507,10 +507,10 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
             }
             # Sample audio
-            if config.model in ["Tacotron", "TacotronGST"]:
-                eval_audio = ap.inv_spectrogram(const_speconfig.T)
+            if config.model.lower() in ["tacotron"]:
+                eval_audio = ap.inv_spectrogram(const_spec.T)
             else:
-                eval_audio = ap.inv_melspectrogram(const_speconfig.T)
+                eval_audio = ap.inv_melspectrogram(const_spec.T)
             tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"])
         # Plot Validation Stats
@@ -522,7 +522,10 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
         tb_logger.tb_eval_figures(global_step, eval_figures)
     if args.rank == 0 and epoch > config.test_delay_epochs:
-        if config.test_sentences_file is None:
+        if config.test_sentences_file:
+            with open(config.test_sentences_file, "r") as f:
+                test_sentences = [s.strip() for s in f.readlines()]
+        else:
             test_sentences = [
                 "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                 "Be a voice, not an echo.",
@@ -530,9 +533,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
                 "This cake is great. It's so delicious and moist.",
                 "Prior to November 22, 1963.",
             ]
-        else:
-            with open(config.test_sentences_file, "r") as f:
-                test_sentences = [s.strip() for s in f.readlines()]
         # test sentences
         test_audios = {}
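
Read together with the previous hunk, these two hunks only reorder the branches so the sentences file is checked first; the resulting control flow is roughly (a sketch, surrounding training code omitted):

    if config.test_sentences_file:
        # a file was given: one test sentence per line
        with open(config.test_sentences_file, "r") as f:
            test_sentences = [s.strip() for s in f.readlines()]
    else:
        # otherwise fall back to the built-in defaults
        test_sentences = [
            "Be a voice, not an echo.",
            "Prior to November 22, 1963.",
        ]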
@@ -544,14 +544,13 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
             if config.use_external_speaker_embedding_file and config.use_speaker_embedding
             else None
         )
-        style_wav = config.get("gst_style_input")
-        if style_wav is None and config.use_gst:
+        style_wav = config.gst_style_input
+        if style_wav is None and config.gst is not None:
             # inicialize GST with zero dict.
             style_wav = {}
             print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
             for i in range(config.gst["gst_num_style_tokens"]):
                 style_wav[str(i)] = 0
-        style_wav = config.get("gst_style_input")
         for idx, test_sentence in enumerate(test_sentences):
             try:
                 wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
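
The style_wav handling is the coqpit-specific part of this hunk: dict-style config.get("gst_style_input") becomes plain attribute access, config.gst is a GST sub-config (or None when GST is disabled) rather than a use_gst boolean, and the stray second assignment to style_wav is dropped. A rough sketch of the resulting fallback, equivalent to the loop above but written as a dict comprehension:

    style_wav = config.gst_style_input  # coqpit attribute access instead of .get()
    if style_wav is None and config.gst is not None:
        # no style wav provided: use a zero weight for every style token
        style_wav = {str(i): 0 for i in range(config.gst["gst_num_style_tokens"])}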
@@ -639,14 +638,10 @@ def main(args):  # pylint: disable=redefined-outer-name
             if "scaler" in checkpoint and config.mixed_precision:
                 print(" > Restoring AMP Scaler...")
                 scaler.load_state_dict(checkpoint["scaler"])
-            if config.reinit_layers:
-                raise RuntimeError
         except (KeyError, RuntimeError):
             print(" > Partial model initialization...")
             model_dict = model.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint["model"], c)
-            # torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
-            # print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
             model.load_state_dict(model_dict)
             del model_dict
@@ -743,12 +738,12 @@ if __name__ == "__main__":
     try:
         main(args)
     except KeyboardInterrupt:
-        # remove_experiment_folder(OUT_PATH)
+        remove_experiment_folder(OUT_PATH)
         try:
             sys.exit(0)
         except SystemExit:
             os._exit(0)  # pylint: disable=protected-access
     except Exception:  # pylint: disable=broad-except
-        # remove_experiment_folder(OUT_PATH)
+        remove_experiment_folder(OUT_PATH)
         traceback.print_exc()
         sys.exit(1)

View File

@@ -22,7 +22,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             r=c.r,
             postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
             decoder_output_dim=c.audio["num_mels"],
-            gst=c.use_gst,
+            gst=c.gst,
             memory_size=c.memory_size,
             attn_type=c.attention_type,
             attn_win=c.windowing,
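
setup_model now hands the whole c.gst sub-config to the Tacotron constructor instead of a use_gst flag, so the model can read the GST settings itself and treat gst=None as "disabled". A hedged sketch of what such a config might look like; GSTConfig and its fields are illustrative only, except gst_num_style_tokens, which appears elsewhere in this commit:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class GSTConfig:
        gst_num_style_tokens: int = 10  # used by the zero-style fallback in the trainer
        gst_embedding_dim: int = 256    # illustrative only

    @dataclass
    class ModelConfig:
        model: str = "Tacotron"
        gst: Optional[GSTConfig] = None  # None means GST is disabled

    c = ModelConfig(gst=GSTConfig())
    tacotron_kwargs = dict(gst=c.gst)   # the sub-config is passed through, as above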

View File

@@ -65,28 +65,31 @@ def compute_style_mel(style_wav, ap, cuda=False):
 def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
-    speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
     if "tacotron" in CONFIG.model.lower():
-        if not CONFIG.use_gst:
-            style_mel = None
-        if truncated:
-            decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-            )
-        else:
+        if CONFIG.gst:
             decoder_output, postnet_output, alignments, stop_tokens = model.inference(
                 inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
             )
+        else:
+            if truncated:
+                decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
+                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
+                )
+            else:
+                decoder_output, postnet_output, alignments, stop_tokens = model.inference(
+                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
+                )
     elif "glow" in CONFIG.model.lower():
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
             postnet_output, _, _, _, alignments, _, _ = model.module.inference(
-                inputs, inputs_lengths, g=speaker_embedding_g
+                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
             )
         else:
-            postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
+            postnet_output, _, _, _, alignments, _, _ = model.inference(
+                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+            )
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
@@ -95,9 +98,13 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
-            postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_embedding_g)
+            postnet_output, alignments = model.module.inference(
+                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+            )
         else:
-            postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
+            postnet_output, alignments = model.inference(
+                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+            )
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
@@ -108,7 +115,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
 def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
-    if CONFIG.use_gst and style_mel is not None:
+    if CONFIG.gst and style_mel is not None:
         raise NotImplementedError(" [!] GST inference not implemented for TF")
     if truncated:
         raise NotImplementedError(" [!] Truncated inference not implemented for TF")
@@ -120,7 +127,7 @@ def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=No
 def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
-    if CONFIG.use_gst and style_mel is not None:
+    if CONFIG.gst and style_mel is not None:
         raise NotImplementedError(" [!] GST inference not implemented for TfLite")
     if truncated:
         raise NotImplementedError(" [!] Truncated inference not implemented for TfLite")
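
In both TF helpers the guard changes from CONFIG.use_gst to CONFIG.gst: since gst is now either a sub-config object or None, its truthiness stands in for the old boolean. A small sketch of why that works, with plain dataclasses standing in for coqpit configs:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class GST:
        gst_num_style_tokens: int = 10

    @dataclass
    class Config:
        gst: Optional[GST] = None

    assert not Config().gst           # GST disabled: None is falsy
    assert Config(gst=GST()).gst      # GST enabled: the sub-config object is truthy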
@@ -249,7 +256,7 @@ def synthesis(
     """
     # GST processing
     style_mel = None
-    if "use_gst" in CONFIG.keys() and CONFIG.use_gst and style_wav is not None:
+    if CONFIG.has('gst') and CONFIG.gst and style_wav is not None:
         if isinstance(style_wav, dict):
             style_mel = style_wav
         else: