mirror of https://github.com/coqui-ai/TTS.git
fix tacotron for coqpit
This commit is contained in:
parent
65d7ad4250
commit
eaa130e813
|
@ -90,7 +90,7 @@ def format_data(data):
|
|||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
speaker_names = data[2]
|
||||
linear_input = data[3] if config.model in ["Tacotron"] else None
|
||||
linear_input = data[3] if config.model.lower() in ["tacotron"] else None
|
||||
mel_input = data[4]
|
||||
mel_lengths = data[5]
|
||||
stop_targets = data[6]
|
||||
|
@ -369,9 +369,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
|
|||
|
||||
# Sample audio
|
||||
if config.model in ["Tacotron", "TacotronGST"]:
|
||||
train_audio = ap.inv_spectrogram(const_speconfig.T)
|
||||
train_audio = ap.inv_spectrogram(const_spec.T)
|
||||
else:
|
||||
train_audio = ap.inv_melspectrogram(const_speconfig.T)
|
||||
train_audio = ap.inv_melspectrogram(const_spec.T)
|
||||
tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"])
|
||||
end_time = time.time()
|
||||
|
||||
|
@ -507,10 +507,10 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
|||
}
|
||||
|
||||
# Sample audio
|
||||
if config.model in ["Tacotron", "TacotronGST"]:
|
||||
eval_audio = ap.inv_spectrogram(const_speconfig.T)
|
||||
if config.model.lower() in ["tacotron"]:
|
||||
eval_audio = ap.inv_spectrogram(const_spec.T)
|
||||
else:
|
||||
eval_audio = ap.inv_melspectrogram(const_speconfig.T)
|
||||
eval_audio = ap.inv_melspectrogram(const_spec.T)
|
||||
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"])
|
||||
|
||||
# Plot Validation Stats
|
||||
|
@ -522,7 +522,10 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
|||
tb_logger.tb_eval_figures(global_step, eval_figures)
|
||||
|
||||
if args.rank == 0 and epoch > config.test_delay_epochs:
|
||||
if config.test_sentences_file is None:
|
||||
if config.test_sentences_file:
|
||||
with open(config.test_sentences_file, "r") as f:
|
||||
test_sentences = [s.strip() for s in f.readlines()]
|
||||
else:
|
||||
test_sentences = [
|
||||
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
"Be a voice, not an echo.",
|
||||
|
@ -530,9 +533,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
|||
"This cake is great. It's so delicious and moist.",
|
||||
"Prior to November 22, 1963.",
|
||||
]
|
||||
else:
|
||||
with open(config.test_sentences_file, "r") as f:
|
||||
test_sentences = [s.strip() for s in f.readlines()]
|
||||
|
||||
# test sentences
|
||||
test_audios = {}
|
||||
|
@ -544,14 +544,13 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
|||
if config.use_external_speaker_embedding_file and config.use_speaker_embedding
|
||||
else None
|
||||
)
|
||||
style_wav = config.get("gst_style_input")
|
||||
if style_wav is None and config.use_gst:
|
||||
style_wav = config.gst_style_input
|
||||
if style_wav is None and config.gst is not None:
|
||||
# inicialize GST with zero dict.
|
||||
style_wav = {}
|
||||
print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
|
||||
for i in range(config.gst["gst_num_style_tokens"]):
|
||||
style_wav[str(i)] = 0
|
||||
style_wav = config.get("gst_style_input")
|
||||
for idx, test_sentence in enumerate(test_sentences):
|
||||
try:
|
||||
wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
|
||||
|
@ -639,14 +638,10 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
if "scaler" in checkpoint and config.mixed_precision:
|
||||
print(" > Restoring AMP Scaler...")
|
||||
scaler.load_state_dict(checkpoint["scaler"])
|
||||
if config.reinit_layers:
|
||||
raise RuntimeError
|
||||
except (KeyError, RuntimeError):
|
||||
print(" > Partial model initialization...")
|
||||
model_dict = model.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
|
||||
# torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
|
||||
# print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
|
||||
model.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
|
||||
|
@ -743,12 +738,12 @@ if __name__ == "__main__":
|
|||
try:
|
||||
main(args)
|
||||
except KeyboardInterrupt:
|
||||
# remove_experiment_folder(OUT_PATH)
|
||||
remove_experiment_folder(OUT_PATH)
|
||||
try:
|
||||
sys.exit(0)
|
||||
except SystemExit:
|
||||
os._exit(0) # pylint: disable=protected-access
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# remove_experiment_folder(OUT_PATH)
|
||||
remove_experiment_folder(OUT_PATH)
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
|
|
@ -22,7 +22,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
r=c.r,
|
||||
postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
|
||||
decoder_output_dim=c.audio["num_mels"],
|
||||
gst=c.use_gst,
|
||||
gst=c.gst,
|
||||
memory_size=c.memory_size,
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
|
|
|
@ -65,28 +65,31 @@ def compute_style_mel(style_wav, ap, cuda=False):
|
|||
|
||||
|
||||
def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
|
||||
speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
|
||||
if "tacotron" in CONFIG.model.lower():
|
||||
if not CONFIG.use_gst:
|
||||
style_mel = None
|
||||
|
||||
if truncated:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
||||
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
|
||||
)
|
||||
else:
|
||||
if CONFIG.gst:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
|
||||
)
|
||||
else:
|
||||
if truncated:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
||||
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
|
||||
)
|
||||
else:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
|
||||
)
|
||||
elif "glow" in CONFIG.model.lower():
|
||||
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
|
||||
if hasattr(model, "module"):
|
||||
# distributed model
|
||||
postnet_output, _, _, _, alignments, _, _ = model.module.inference(
|
||||
inputs, inputs_lengths, g=speaker_embedding_g
|
||||
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
|
||||
)
|
||||
else:
|
||||
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
|
||||
postnet_output, _, _, _, alignments, _, _ = model.inference(
|
||||
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
|
||||
)
|
||||
postnet_output = postnet_output.permute(0, 2, 1)
|
||||
# these only belong to tacotron models.
|
||||
decoder_output = None
|
||||
|
@ -95,9 +98,13 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
|
|||
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
|
||||
if hasattr(model, "module"):
|
||||
# distributed model
|
||||
postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_embedding_g)
|
||||
postnet_output, alignments = model.module.inference(
|
||||
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
|
||||
)
|
||||
else:
|
||||
postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
|
||||
postnet_output, alignments = model.inference(
|
||||
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
|
||||
)
|
||||
postnet_output = postnet_output.permute(0, 2, 1)
|
||||
# these only belong to tacotron models.
|
||||
decoder_output = None
|
||||
|
@ -108,7 +115,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
|
|||
|
||||
|
||||
def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
|
||||
if CONFIG.use_gst and style_mel is not None:
|
||||
if CONFIG.gst and style_mel is not None:
|
||||
raise NotImplementedError(" [!] GST inference not implemented for TF")
|
||||
if truncated:
|
||||
raise NotImplementedError(" [!] Truncated inference not implemented for TF")
|
||||
|
@ -120,7 +127,7 @@ def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=No
|
|||
|
||||
|
||||
def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
|
||||
if CONFIG.use_gst and style_mel is not None:
|
||||
if CONFIG.gst and style_mel is not None:
|
||||
raise NotImplementedError(" [!] GST inference not implemented for TfLite")
|
||||
if truncated:
|
||||
raise NotImplementedError(" [!] Truncated inference not implemented for TfLite")
|
||||
|
@ -249,7 +256,7 @@ def synthesis(
|
|||
"""
|
||||
# GST processing
|
||||
style_mel = None
|
||||
if "use_gst" in CONFIG.keys() and CONFIG.use_gst and style_wav is not None:
|
||||
if CONFIG.has('gst') and CONFIG.gst and style_wav is not None:
|
||||
if isinstance(style_wav, dict):
|
||||
style_mel = style_wav
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue