From 379ff17a9a81542d36568b6369cc74f4c046c589 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Thu, 26 Apr 2018 05:35:25 -0700
Subject: [PATCH] optional priority_freq

---
 config.json | 1 +
 train.py    | 8 ++++++--
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/config.json b/config.json
index d43c16a0..3eb74365 100644
--- a/config.json
+++ b/config.json
@@ -17,6 +17,7 @@
     "eval_batch_size":32,
     "r": 5,
     "mk": 1.0,
+    "priority_freq": false,

     "griffin_lim_iters": 60,
     "power": 1.2,
diff --git a/train.py b/train.py
index 75a69081..7e41712a 100644
--- a/train.py
+++ b/train.py
@@ -54,6 +54,12 @@
 pickle.dump(c, open(tmp_path, "wb"))
 LOG_DIR = OUT_PATH
 tb = SummaryWriter(LOG_DIR)
+if c.priority_freq:
+    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
+    print(" > Using num priority freq. : {}".format(n_priority_freq))
+else:
+    print(" > Priority freq. is disabled.")
+

 def signal_handler(signal, frame):
     """Ctrl+C handler to remove empty experiment folder"""
@@ -71,7 +77,6 @@ def train(model, criterion, data_loader, optimizer, epoch):

     print(" | > Epoch {}/{}".format(epoch, c.epochs))
     progbar = Progbar(len(data_loader.dataset) / c.batch_size)
-    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
     progbar_display = {}
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -214,7 +219,6 @@ def evaluate(model, criterion, data_loader, current_step):

     print("\n | > Validation")
     progbar = Progbar(len(data_loader.dataset) / c.batch_size)
-    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
     with torch.no_grad():
         for num_iter, data in enumerate(data_loader):
             start_time = time.time()
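
The new module-level `n_priority_freq` counts how many linear-spectrogram bins fall below 3 kHz: 3000 Hz divided by the Nyquist frequency (`c.sample_rate * 0.5`), scaled by the number of frequency bins `c.num_freq`. The patch only computes and logs this count; how it feeds the loss is not part of the diff. Below is a minimal sketch of the usual way such a count is applied, an extra 0.5-weighted L1 term over the low-frequency bins. The `linear_spec_loss` helper, the tensor shapes, and the sample values are assumptions for illustration, not the loss code in train.py.

import torch


def linear_spec_loss(prediction, target, n_priority_freq=None):
    # Hypothetical helper for illustration only.
    # Plain L1 over the full linear spectrogram.
    l1 = torch.abs(prediction - target)
    if n_priority_freq is None:
        return l1.mean()
    # Weight the full band and the sub-3 kHz band equally, so the
    # perceptually important low frequencies get extra emphasis.
    return 0.5 * l1.mean() + 0.5 * l1[:, :, :n_priority_freq].mean()


# Assumed values: 22050 Hz audio, 1025 linear bins.
sample_rate, num_freq = 22050, 1025
n_priority_freq = int(3000 / (sample_rate * 0.5) * num_freq)  # -> 278 bins
prediction = torch.rand(2, 100, num_freq)   # (batch, frames, freq bins)
target = torch.rand(2, 100, num_freq)
loss = linear_spec_loss(prediction, target, n_priority_freq)

With `priority_freq` disabled, the sketch falls back to a plain L1, which mirrors the intent of the patch: the low-frequency emphasis becomes optional instead of always on.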