Skip to content

Commit 76b8535

Browse files
committed
add gpu id config
1 parent 0d2daab commit 76b8535

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

Diff for: cifar.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@
7272
parser.add_argument('--manualSeed', type=int, help='manual seed')
7373
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
7474
help='evaluate model on validation set')
75+
#Device options
76+
parser.add_argument('--gpu_id', default='1', type=str,
77+
help='id(s) for CUDA_VISIBLE_DEVICES')
7578

7679
args = parser.parse_args()
7780
state = {k: v for k, v in args._get_kwargs()}
@@ -80,6 +83,7 @@
8083
assert args.dataset == 'cifar10' or args.dataset == 'cifar100', 'Dataset can only be cifar10 or cifar100.'
8184

8285
# Use CUDA
86+
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
8387
use_cuda = torch.cuda.is_available()
8488

8589
# Random seed
@@ -343,4 +347,4 @@ def adjust_learning_rate(optimizer, epoch):
343347
param_group['lr'] = state['lr']
344348

345349
if __name__ == '__main__':
346-
main()
350+
main()

Diff for: imagenet.py

+4
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,15 @@
8686
help='evaluate model on validation set')
8787
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
8888
help='use pre-trained model')
89+
#Device options
90+
parser.add_argument('--gpu_id', default='1', type=str,
91+
help='id(s) for CUDA_VISIBLE_DEVICES')
8992

9093
args = parser.parse_args()
9194
state = {k: v for k, v in args._get_kwargs()}
9295

9396
# Use CUDA
97+
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
9498
use_cuda = torch.cuda.is_available()
9599

96100
# Random seed

0 commit comments

Comments
 (0)