diff --git a/train.py b/train.py index b844024a..f2490b4f 100644 --- a/train.py +++ b/train.py @@ -255,12 +255,16 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch): avg_decoder_loss = 0 avg_stop_loss = 0 print("\n > Validation") - test_sentences = [ - "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", - "Be a voice, not an echo.", - "I'm sorry Dave. I'm afraid I can't do that.", - "This cake is great. It's so delicious and moist." - ] + if c.test_sentences_file is None: + test_sentences = [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist." + ] + else: + with open(c.test_sentences_file, "r") as f: + test_sentences = [s.strip() for s in f.readlines()] with torch.no_grad(): if data_loader is not None: for num_iter, data in enumerate(data_loader):