From ec579d02a1c2536285f2f4e237382cb4feeb7a83 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Thu, 31 Oct 2019 15:13:39 +0100 Subject: [PATCH] bug fix argparser --- distribute.py | 4 ++-- layers/common_layers.py | 7 +++++++ train.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/distribute.py b/distribute.py index fe175617..a5fdb373 100644 --- a/distribute.py +++ b/distribute.py @@ -1,5 +1,5 @@ # edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py -import os +import os, sys import math import time import subprocess @@ -130,7 +130,7 @@ def main(): type=str, help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', default='', - required=True) + required='--config_path' not in sys.argv) parser.add_argument( '--restore_path', type=str, diff --git a/layers/common_layers.py b/layers/common_layers.py index d5836a9f..b9571546 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -226,6 +226,13 @@ class Attention(nn.Module): return alpha def forward(self, query, inputs, processed_inputs, mask): + """ + shapes: + query: B x D_attn_rnn + inputs: B x T_en x D_en + processed_inputs:: B x T_en x D_attn + mask: B x T_en + """ if self.location_attention: attention, _ = self.get_location_attention( query, processed_inputs) diff --git a/train.py b/train.py index 6f37442a..7590ad19 100644 --- a/train.py +++ b/train.py @@ -649,7 +649,7 @@ if __name__ == '__main__': type=str, help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', default='', - required=True) + required='--config_path' not in sys.argv) parser.add_argument( '--restore_path', type=str,