bug fix argparser

This commit is contained in:
Eren Golge 2019-10-31 15:13:39 +01:00
parent 9b6021318d
commit ec579d02a1
3 changed files with 10 additions and 3 deletions

View File

@ -1,5 +1,5 @@
# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py # edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
import os import os, sys
import math import math
import time import time
import subprocess import subprocess
@ -130,7 +130,7 @@ def main():
type=str, type=str,
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default='', default='',
required=True) required='--config_path' not in sys.argv)
parser.add_argument( parser.add_argument(
'--restore_path', '--restore_path',
type=str, type=str,

View File

@ -226,6 +226,13 @@ class Attention(nn.Module):
return alpha return alpha
def forward(self, query, inputs, processed_inputs, mask): def forward(self, query, inputs, processed_inputs, mask):
"""
shapes:
query: B x D_attn_rnn
inputs: B x T_en x D_en
processed_inputs:: B x T_en x D_attn
mask: B x T_en
"""
if self.location_attention: if self.location_attention:
attention, _ = self.get_location_attention( attention, _ = self.get_location_attention(
query, processed_inputs) query, processed_inputs)

View File

@ -649,7 +649,7 @@ if __name__ == '__main__':
type=str, type=str,
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default='', default='',
required=True) required='--config_path' not in sys.argv)
parser.add_argument( parser.add_argument(
'--restore_path', '--restore_path',
type=str, type=str,