Skip to content

Commit

Permalink
enable multi gpu support
Browse files Browse the repository at this point in the history
  • Loading branch information
mchaker authored Jul 10, 2022
1 parent 9cf6793 commit 6bb55e9
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions multiligand_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_default_args(args, cmdline_args):
return args

def load_rec_and_model(args):
device = torch.device("cuda:0" if torch.cuda.is_available() and args.device == 'cuda' else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() and args.device == 'cuda' else "cpu")
print(f"device = {device}")
# sys.exit()
checkpoint = torch.load(args.checkpoint, map_location=device)
Expand Down Expand Up @@ -272,4 +272,4 @@ def main(arglist = None):
write_while_inferring(lig_loader, model, args)

if __name__ == '__main__':
main()
main()

0 comments on commit 6bb55e9

Please sign in to comment.