diff --git a/data/gan.py b/data/gan.py index f5705ea..e0f97b1 100644 --- a/data/gan.py +++ b/data/gan.py @@ -61,16 +61,19 @@ class GNet : self.logs = {} # self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu'] - self.GPU_CHIPS = None if 'gpu' not in args else args['gpu'] - if self.GPU_CHIPS is None: - self.GPU_CHIPS = [0] - if 'CUDA_VISIBLE_DEVICES' in os.environ : - os.environ.pop('CUDA_VISIBLE_DEVICES') - self.NUM_GPUS = 0 - else: - self.NUM_GPUS = len(self.GPU_CHIPS) + # self.GPU_CHIPS = None if 'gpu' not in args else args['gpu'] + # if self.GPU_CHIPS is None: + # self.GPU_CHIPS = [0] + # if 'CUDA_VISIBLE_DEVICES' in os.environ : + # os.environ.pop('CUDA_VISIBLE_DEVICES') + # self.NUM_GPUS = 0 + # else: + # self.NUM_GPUS = len(self.GPU_CHIPS) # os.environ['CUDA_VISIBLE_DEVICES'] = str(self.GPU_CHIPS[0]) - + self.NUM_GPUS = 0 if 'gpu' not in args else args['gpu'] + self.GPU_CHIPS = None if self.NUM_GPUS == 0 else [args['gpu']] + if self.GPU_CHIPS : + os.environ['CUDA_VISIBLE_DEVICES'] = str(self.GPU_CHIPS[0]) self.PARTITION = args['partition'] if 'partition' in args else None # if self.NUM_GPUS > 1 : # os.environ['CUDA_VISIBLE_DEVICES'] = "4"