mirror of https://github.com/coqui-ai/TTS.git
save used model characters to the checkpoints
parent 8ec28b1ac2
commit 62aeacbdd1
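
Editor's note: the hunks below thread a new `model_characters` argument through the training scripts' calls to `save_checkpoint` and `save_best_model`. The checkpoint I/O helpers themselves are not part of this excerpt, so the following is only a minimal sketch of what persisting the character set alongside the weights could look like; the function name and checkpoint keys are illustrative, not taken from the repository.

import torch


def save_checkpoint_sketch(model, optimizer, global_step, epoch, r, output_folder,
                           characters, model_loss=None, **extra_state):
    """Hypothetical helper: bundle the character set with the usual training state."""
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "step": global_step,
        "epoch": epoch,
        "r": r,
        # the new piece: the symbols/phonemes the embedding layer was trained on
        "characters": characters,
        "model_loss": model_loss,
    }
    state.update(extra_state)  # e.g. scaler=... for mixed-precision runs
    torch.save(state, f"{output_folder}/checkpoint_{global_step}.pth.tar")
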
@@ -268,7 +268,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
         if global_step % c.save_step == 0:
             if c.checkpoint:
                 # save model
-                save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
+                save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
                                 model_loss=loss_dict['loss'])

         # wait all kernels to be completed
@@ -467,7 +467,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):

 def main(args): # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
+    global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
     # Audio processor
     ap = AudioProcessor(**c.audio)
     if 'characters' in c.keys():
@@ -477,7 +477,10 @@ def main(args): # pylint: disable=redefined-outer-name
     if num_gpus > 1:
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
-    num_chars = len(phonemes) if c.use_phonemes else len(symbols)
+
+    # set model characters
+    model_characters = phonemes if c.use_phonemes else symbols
+    num_chars = len(model_characters)

     # load data instances
     meta_data_train, meta_data_eval = load_meta_data(c.datasets)
@@ -559,7 +562,7 @@ def main(args): # pylint: disable=redefined-outer-name
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
-                                    OUT_PATH)
+                                    OUT_PATH, model_characters)


 if __name__ == '__main__':
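
Editor's note: in the main() changes above, the character inventory is chosen once from the config (phonemes when c.use_phonemes is set, plain symbols otherwise) and reused both for num_chars and for the checkpoint calls. A toy illustration of that selection, with made-up inventories rather than the real TTS symbol tables:

def select_model_characters(symbols, phonemes, use_phonemes):
    # mirrors `model_characters = phonemes if c.use_phonemes else symbols`
    return phonemes if use_phonemes else symbols


# toy inventories, for illustration only
symbols = list("abcdefghijklmnopqrstuvwxyz !,.?")
phonemes = list("abdefhijklmnoprstuvwzæðŋɹʃʒθ !,.?")

model_characters = select_model_characters(symbols, phonemes, use_phonemes=True)
num_chars = len(model_characters)  # sizes the model's character embedding
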
@@ -247,7 +247,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
         if global_step % c.save_step == 0:
             if c.checkpoint:
                 # save model
-                save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
+                save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
                                 model_loss=loss_dict['loss'])

         # wait all kernels to be completed
@@ -431,7 +431,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
 # FIXME: move args definition/parsing inside of main?
 def main(args): # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
+    global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
     # Audio processor
     ap = AudioProcessor(**c.audio)
     if 'characters' in c.keys():
@@ -441,7 +441,10 @@ def main(args): # pylint: disable=redefined-outer-name
     if num_gpus > 1:
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
-    num_chars = len(phonemes) if c.use_phonemes else len(symbols)
+
+    # set model characters
+    model_characters = phonemes if c.use_phonemes else symbols
+    num_chars = len(model_characters)

     # load data instances
     meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)
@@ -523,7 +526,7 @@ def main(args): # pylint: disable=redefined-outer-name
         target_loss = eval_avg_loss_dict['avg_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer,
                                     global_step, epoch, c.r,
-                                    OUT_PATH)
+                                    OUT_PATH, model_characters)


 if __name__ == '__main__':
@@ -284,6 +284,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
                 save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
                                 optimizer_st=optimizer_st,
                                 model_loss=loss_dict['postnet_loss'],
+                                characters=model_characters,
                                 scaler=scaler.state_dict() if c.mixed_precision else None)

         # Diagnostic visualizations
@@ -492,9 +493,11 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):

 def main(args): # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
+    global meta_data_train, meta_data_eval, speaker_mapping, symbols, phonemes, model_characters
     # Audio processor
     ap = AudioProcessor(**c.audio)
+
+    # setup custom characters if set in config file.
     if 'characters' in c.keys():
         symbols, phonemes = make_symbols(**c.characters)

@@ -503,6 +506,7 @@ def main(args): # pylint: disable=redefined-outer-name
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
+    model_characters = phonemes if c.use_phonemes else symbols

     # load data instances
     meta_data_train, meta_data_eval = load_meta_data(c.datasets)
@@ -634,6 +638,7 @@ def main(args): # pylint: disable=redefined-outer-name
                                        epoch,
                                        c.r,
                                        OUT_PATH,
+                                       model_characters,
                                        scaler=scaler.state_dict() if c.mixed_precision else None
                                        )

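
Editor's note: across all three training scripts the change is the same, namely that the character set now travels inside the checkpoint file. The consumer side is not shown in this diff, so what follows is only a hedged sketch of how a loader could read the stored set back and guard against restoring weights with a mismatched vocabulary. The key name "characters" follows the keyword used in the Tacotron hunk above; everything else is illustrative.

import torch


def load_model_characters(checkpoint_path, expected_characters=None):
    """Illustrative loader: fetch the character set stored with the weights."""
    state = torch.load(checkpoint_path, map_location="cpu")
    saved_characters = state.get("characters")
    if (expected_characters is not None and saved_characters is not None
            and list(saved_characters) != list(expected_characters)):
        raise ValueError(
            "Checkpoint was trained with a different character set; "
            "rebuild the symbol table from the checkpoint, not from the current config.")
    return saved_characters
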