diff --git a/train.py b/train.py index 314d3af3..0c3cff03 100644 --- a/train.py +++ b/train.py @@ -592,7 +592,7 @@ def main(args): # pylint: disable=redefined-outer-name args.restore_step = 0 if use_cuda: - model = model.cuda() + model.cuda() criterion.cuda() if criterion_st: criterion_st.cuda()