fix tacotron for coqpit

Eren Gölge 2021-05-05 02:30:25 +02:00
parent 65d7ad4250
commit eaa130e813
3 changed files with 38 additions and 36 deletions

View File

@@ -90,7 +90,7 @@ def format_data(data):
     text_input = data[0]
     text_lengths = data[1]
     speaker_names = data[2]
-    linear_input = data[3] if config.model in ["Tacotron"] else None
+    linear_input = data[3] if config.model.lower() in ["tacotron"] else None
     mel_input = data[4]
     mel_lengths = data[5]
     stop_targets = data[6]
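
Note on the change above: lowering the model name makes the linear-input check case-insensitive, so configs that spell the model "tacotron" or "TACOTRON" are treated the same as "Tacotron". A self-contained illustration (not from the repo):

    for name in ("Tacotron", "tacotron", "TACOTRON"):
        print(name in ["Tacotron"], name.lower() in ["tacotron"])
    # old check: True only for the first spelling; new check: True for all three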
@@ -369,9 +369,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
         # Sample audio
         if config.model in ["Tacotron", "TacotronGST"]:
-            train_audio = ap.inv_spectrogram(const_speconfig.T)
+            train_audio = ap.inv_spectrogram(const_spec.T)
         else:
-            train_audio = ap.inv_melspectrogram(const_speconfig.T)
+            train_audio = ap.inv_melspectrogram(const_spec.T)
         tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"])

     end_time = time.time()
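
The `const_speconfig` fixed here looks like fallout from the coqpit migration's rename of the old config variable `c` to `config`: a blind `"c." -> "config."` substitution also rewrites the `c.` inside `const_spec.T`. A one-line check of that hypothesis:

    assert "const_spec.T".replace("c.", "config.") == "const_speconfig.T"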
@@ -507,10 +507,10 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
         }
         # Sample audio
-        if config.model in ["Tacotron", "TacotronGST"]:
-            eval_audio = ap.inv_spectrogram(const_speconfig.T)
+        if config.model.lower() in ["tacotron"]:
+            eval_audio = ap.inv_spectrogram(const_spec.T)
         else:
-            eval_audio = ap.inv_melspectrogram(const_speconfig.T)
+            eval_audio = ap.inv_melspectrogram(const_spec.T)
         tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"])

         # Plot Validation Stats
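
For context: Tacotron predicts linear spectrograms (hence ap.inv_spectrogram), while the other models predict mels (ap.inv_melspectrogram). The new condition also drops "TacotronGST" from the list, consistent with GST becoming a nested gst config rather than a separate model name. The dispatch, as a hedged standalone sketch (assuming ap is the repo's AudioProcessor and spec is a frames-by-bins array):

    def spec_to_wav(config, ap, spec):
        # Tacotron -> linear spectrogram; everything else -> mel spectrogram
        if config.model.lower() in ["tacotron"]:
            return ap.inv_spectrogram(spec.T)
        return ap.inv_melspectrogram(spec.T)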
@@ -522,7 +522,10 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
     tb_logger.tb_eval_figures(global_step, eval_figures)

     if args.rank == 0 and epoch > config.test_delay_epochs:
-        if config.test_sentences_file is None:
+        if config.test_sentences_file:
+            with open(config.test_sentences_file, "r") as f:
+                test_sentences = [s.strip() for s in f.readlines()]
+        else:
             test_sentences = [
                 "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                 "Be a voice, not an echo.",
@@ -530,9 +533,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
                 "This cake is great. It's so delicious and moist.",
                 "Prior to November 22, 1963.",
             ]
-        else:
-            with open(config.test_sentences_file, "r") as f:
-                test_sentences = [s.strip() for s in f.readlines()]

         # test sentences
         test_audios = {}
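
Besides reordering the branches, the condition changes from `is None` to plain truthiness, so an empty-string test_sentences_file now falls back to the built-in sentences instead of attempting open(""). Roughly, with a hypothetical helper name:

    def load_test_sentences(path, defaults):
        if path:  # truthy check skips both None and ""
            with open(path, "r") as f:
                return [line.strip() for line in f]
        return list(defaults)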
@@ -544,14 +544,13 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
             if config.use_external_speaker_embedding_file and config.use_speaker_embedding
             else None
         )
-        style_wav = config.get("gst_style_input")
-        if style_wav is None and config.use_gst:
+        style_wav = config.gst_style_input
+        if style_wav is None and config.gst is not None:
             # inicialize GST with zero dict.
             style_wav = {}
             print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
             for i in range(config.gst["gst_num_style_tokens"]):
                 style_wav[str(i)] = 0
-            style_wav = config.get("gst_style_input")
         for idx, test_sentence in enumerate(test_sentences):
             try:
                 wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
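
Two fixes land in this hunk: attribute access replaces dict-style config.get(...) (coqpit configs expose fields as attributes), and the stray trailing style_wav = config.get("gst_style_input") is deleted; that line appears to have reset style_wav to None right after the zero dict was built, defeating the fallback. The fallback itself amounts to (sketch, token count assumed):

    gst_num_style_tokens = 10  # assumed; the script reads it from config.gst
    style_wav = {str(i): 0 for i in range(gst_num_style_tokens)}
    # {'0': 0, '1': 0, ..., '9': 0} -> zero weight for every GST style token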
@@ -639,14 +638,10 @@ def main(args): # pylint: disable=redefined-outer-name
         if "scaler" in checkpoint and config.mixed_precision:
             print(" > Restoring AMP Scaler...")
             scaler.load_state_dict(checkpoint["scaler"])
-        if config.reinit_layers:
-            raise RuntimeError
     except (KeyError, RuntimeError):
         print(" > Partial model initialization...")
         model_dict = model.state_dict()
         model_dict = set_init_dict(model_dict, checkpoint["model"], c)
-        # torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
-        # print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
         model.load_state_dict(model_dict)
         del model_dict
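
The deleted config.reinit_layers check existed only to force the partial-initialization path by raising RuntimeError inside the enclosing try; without it, partial init now runs only on a genuine KeyError/RuntimeError during restore. A hedged sketch of what a set_init_dict-style partial restore does (not the repo's exact implementation):

    def partial_restore(model, checkpoint_state):
        # model is a torch.nn.Module; checkpoint_state a saved state dict
        model_dict = model.state_dict()
        # keep only checkpoint tensors whose name and shape match the new model
        compatible = {
            k: v
            for k, v in checkpoint_state.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        model_dict.update(compatible)
        model.load_state_dict(model_dict)
        print(f" > {len(compatible)} / {len(model_dict)} layers restored.")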
@@ -743,12 +738,12 @@ if __name__ == "__main__":
     try:
         main(args)
     except KeyboardInterrupt:
-        # remove_experiment_folder(OUT_PATH)
+        remove_experiment_folder(OUT_PATH)
         try:
             sys.exit(0)
         except SystemExit:
             os._exit(0)  # pylint: disable=protected-access
     except Exception:  # pylint: disable=broad-except
-        # remove_experiment_folder(OUT_PATH)
+        remove_experiment_folder(OUT_PATH)
         traceback.print_exc()
         sys.exit(1)
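
Un-commenting remove_experiment_folder means interrupted or failed runs now clean up their experiment folder instead of leaving it behind. The nested try/except around sys.exit is the usual trick to guarantee the process dies even if SystemExit gets caught upstream; in isolation:

    import os
    import sys

    def bail_out(cleanup):
        cleanup()  # e.g. remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # hard exit, bypasses any further exception handling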

View File

@@ -22,7 +22,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             r=c.r,
             postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
             decoder_output_dim=c.audio["num_mels"],
-            gst=c.use_gst,
+            gst=c.gst,
             memory_size=c.memory_size,
             attn_type=c.attention_type,
             attn_win=c.windowing,
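
After the coqpit migration, gst is no longer a use_gst boolean plus a sibling dict; it is a nested config object (or None when disabled), so it can be passed to the model constructor whole. A minimal stand-in with plain dataclasses (coqpit configs are dataclass-based; field defaults here are assumed):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class GSTConfig:
        gst_style_input: Optional[str] = None
        gst_num_style_tokens: int = 10  # assumed default

    @dataclass
    class TacotronConfig:
        model: str = "Tacotron"
        gst: Optional[GSTConfig] = None  # None means GST is disabled

    config = TacotronConfig(gst=GSTConfig())
    if config.gst is not None:  # the new-style check used across this commit
        print(config.gst.gst_num_style_tokens)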

View File

@@ -65,28 +65,31 @@ def compute_style_mel(style_wav, ap, cuda=False):
 def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
-    speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
     if "tacotron" in CONFIG.model.lower():
-        if not CONFIG.use_gst:
-            style_mel = None
-        if truncated:
-            decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-            )
-        else:
+        if CONFIG.gst:
             decoder_output, postnet_output, alignments, stop_tokens = model.inference(
                 inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
             )
+        else:
+            if truncated:
+                decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
+                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
+                )
+            else:
+                decoder_output, postnet_output, alignments, stop_tokens = model.inference(
+                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
+                )
     elif "glow" in CONFIG.model.lower():
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
             postnet_output, _, _, _, alignments, _, _ = model.module.inference(
-                inputs, inputs_lengths, g=speaker_embedding_g
+                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
             )
         else:
-            postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
+            postnet_output, _, _, _, alignments, _, _ = model.inference(
+                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+            )
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
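
Beyond renaming use_gst to gst, the restructured branch changes behavior: previously a GST model with truncated=True went through inference_truncated, and style_mel was dropped whenever GST was off; now CONFIG.gst takes precedence and GST models always run full inference with the style mel. Condensed into a hypothetical helper, the new tacotron path reads:

    def tacotron_infer(model, inputs, gst, truncated, style_mel=None, **kwargs):
        if gst:
            return model.inference(inputs, style_mel=style_mel, **kwargs)
        if truncated:
            return model.inference_truncated(inputs, **kwargs)
        return model.inference(inputs, **kwargs)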
@@ -95,9 +98,13 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
-            postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_embedding_g)
+            postnet_output, alignments = model.module.inference(
+                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+            )
         else:
-            postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
+            postnet_output, alignments = model.inference(
+                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+            )
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
@@ -108,7 +115,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):

 def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
-    if CONFIG.use_gst and style_mel is not None:
+    if CONFIG.gst and style_mel is not None:
         raise NotImplementedError(" [!] GST inference not implemented for TF")
     if truncated:
         raise NotImplementedError(" [!] Truncated inference not implemented for TF")
@@ -120,7 +127,7 @@ def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):

 def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
-    if CONFIG.use_gst and style_mel is not None:
+    if CONFIG.gst and style_mel is not None:
         raise NotImplementedError(" [!] GST inference not implemented for TfLite")
     if truncated:
         raise NotImplementedError(" [!] Truncated inference not implemented for TfLite")
@@ -249,7 +256,7 @@ def synthesis(
     """
     # GST processing
     style_mel = None
-    if "use_gst" in CONFIG.keys() and CONFIG.use_gst and style_wav is not None:
+    if CONFIG.has('gst') and CONFIG.gst and style_wav is not None:
         if isinstance(style_wav, dict):
             style_mel = style_wav
         else:
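
CONFIG.has('gst') guards configs that predate the gst field, and the isinstance check lets callers pass style_wav either as a dict of per-token weights or as a path to a reference wav (which compute_style_mel converts to a mel spectrogram). Illustrative inputs, values hypothetical:

    style_wav_as_weights = {"0": 0.25, "1": 0.0, "2": -0.1}  # style-token index -> weight
    style_wav_as_path = "ref_style.wav"  # mel extracted via compute_style_mel()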