From e774f68aeefd9dfac5a09847cb8def93a5e22184 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 12 Feb 2021 12:03:42 +0000 Subject: [PATCH] save used model characters to the checkpoints --- TTS/bin/train_glow_tts.py | 11 +++++++---- TTS/bin/train_speedy_speech.py | 11 +++++++---- TTS/bin/train_tacotron.py | 7 ++++++- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py index 9db2381e..a12c5581 100644 --- a/TTS/bin/train_glow_tts.py +++ b/TTS/bin/train_glow_tts.py @@ -268,7 +268,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, + save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters, model_loss=loss_dict['loss']) # wait all kernels to be completed @@ -467,7 +467,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): def main(args): # pylint: disable=redefined-outer-name # pylint: disable=global-variable-undefined - global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping + global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping # Audio processor ap = AudioProcessor(**c.audio) if 'characters' in c.keys(): @@ -477,7 +477,10 @@ def main(args): # pylint: disable=redefined-outer-name if num_gpus > 1: init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) - num_chars = len(phonemes) if c.use_phonemes else len(symbols) + + # set model characters + model_characters = phonemes if c.use_phonemes else symbols + num_chars = len(model_characters) # load data instances meta_data_train, meta_data_eval = load_meta_data(c.datasets) @@ -559,7 +562,7 @@ def main(args): # pylint: disable=redefined-outer-name if c.run_eval: target_loss = eval_avg_loss_dict['avg_loss'] best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r, - OUT_PATH) + OUT_PATH, model_characters) if __name__ == '__main__': diff --git a/TTS/bin/train_speedy_speech.py b/TTS/bin/train_speedy_speech.py index a9a83bbf..1f32c8f6 100644 --- a/TTS/bin/train_speedy_speech.py +++ b/TTS/bin/train_speedy_speech.py @@ -247,7 +247,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, + save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters, model_loss=loss_dict['loss']) # wait all kernels to be completed @@ -431,7 +431,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # FIXME: move args definition/parsing inside of main? def main(args): # pylint: disable=redefined-outer-name # pylint: disable=global-variable-undefined - global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping + global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping # Audio processor ap = AudioProcessor(**c.audio) if 'characters' in c.keys(): @@ -441,7 +441,10 @@ def main(args): # pylint: disable=redefined-outer-name if num_gpus > 1: init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) - num_chars = len(phonemes) if c.use_phonemes else len(symbols) + + # set model characters + model_characters = phonemes if c.use_phonemes else symbols + num_chars = len(model_characters) # load data instances meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True) @@ -523,7 +526,7 @@ def main(args): # pylint: disable=redefined-outer-name target_loss = eval_avg_loss_dict['avg_loss'] best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r, - OUT_PATH) + OUT_PATH, model_characters) if __name__ == '__main__': diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py index 0a53f2a1..a9c0881f 100644 --- a/TTS/bin/train_tacotron.py +++ b/TTS/bin/train_tacotron.py @@ -284,6 +284,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH, optimizer_st=optimizer_st, model_loss=loss_dict['postnet_loss'], + characters=model_characters, scaler=scaler.state_dict() if c.mixed_precision else None) # Diagnostic visualizations @@ -492,9 +493,11 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): def main(args): # pylint: disable=redefined-outer-name # pylint: disable=global-variable-undefined - global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping + global meta_data_train, meta_data_eval, speaker_mapping, symbols, phonemes, model_characters # Audio processor ap = AudioProcessor(**c.audio) + + # setup custom characters if set in config file. if 'characters' in c.keys(): symbols, phonemes = make_symbols(**c.characters) @@ -503,6 +506,7 @@ def main(args): # pylint: disable=redefined-outer-name init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) num_chars = len(phonemes) if c.use_phonemes else len(symbols) + model_characters = phonemes if c.use_phonemes else symbols # load data instances meta_data_train, meta_data_eval = load_meta_data(c.datasets) @@ -634,6 +638,7 @@ def main(args): # pylint: disable=redefined-outer-name epoch, c.r, OUT_PATH, + model_characters, scaler=scaler.state_dict() if c.mixed_precision else None )