From cef70923d18eff30d444c2600f1c916163c6a6df Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Tue, 27 Feb 2018 07:32:09 -0800
Subject: [PATCH] more tests

---
 tests/generic_utils_text.py | 38 +++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)
 create mode 100644 tests/generic_utils_text.py

diff --git a/tests/generic_utils_text.py b/tests/generic_utils_text.py
new file mode 100644
index 00000000..0461d263
--- /dev/null
+++ b/tests/generic_utils_text.py
@@ -0,0 +1,38 @@
+import unittest
+import torch as T
+
+from TTS.utils.generic_utils import save_checkpoint, save_best_model
+from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
+
+OUT_PATH = '/tmp/test.pth.tar'
+
+class ModelSavingTests(unittest.TestCase):
+
+    def test_save_checkpoint(self):
+        # create a dummy model and wrap it as in multi-GPU training
+        model = Prenet(128, out_features=[256, 128])
+        model = T.nn.DataParallel(model)
+
+        # save the model
+        save_checkpoint(model, None, 100,
+                        OUT_PATH, 1, 1)
+
+        # load the model to CPU
+        model_dict = T.load(OUT_PATH, map_location=lambda storage,
+                            loc: storage)
+        model.load_state_dict(model_dict['model'])
+
+    def test_save_best_model(self):
+        # create a dummy model and wrap it as in multi-GPU training
+        model = Prenet(256, out_features=[256, 256])
+        model = T.nn.DataParallel(model)
+
+        # save the model
+        best_loss = save_best_model(model, None, 0,
+                                    100, OUT_PATH,
+                                    10, 1)
+
+        # load the model to CPU
+        model_dict = T.load(OUT_PATH, map_location=lambda storage,
+                            loc: storage)
+        model.load_state_dict(model_dict['model'])
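
Note: the tests above call the helpers imported from TTS.utils.generic_utils purely
positionally, so the argument order is easy to misread. The sketch below is a minimal
reconstruction of what save_checkpoint and save_best_model might look like, inferred
only from the call sites in this patch; the parameter names, the checkpoint keys, and
the "save only on improvement" behaviour are assumptions, not the actual TTS code.

    import torch

    def save_checkpoint(model, optimizer, model_loss, out_path, current_step, epoch):
        # Assumed layout: the tests read the weights back under the 'model' key,
        # so store the (possibly DataParallel-wrapped) state_dict under that key.
        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict() if optimizer is not None else None,
            'loss': model_loss,
            'step': current_step,
            'epoch': epoch,
        }
        torch.save(state, out_path)

    def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                        current_step, epoch):
        # Assumed behaviour: checkpoint only when the new loss beats best_loss and
        # return the running best loss (the second test stores this return value).
        if model_loss < best_loss:
            best_loss = model_loss
            save_checkpoint(model, optimizer, model_loss, out_path,
                            current_step, epoch)
        return best_loss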