From e54939040f1489a0bffd78672f27980ade55542d Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Thu, 14 Dec 2023 14:49:38 -0500 Subject: [PATCH] fix: send to default device pytorch --- src/train_test/simple.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/train_test/simple.py b/src/train_test/simple.py index 21631ccd..7a8ebb1a 100644 --- a/src/train_test/simple.py +++ b/src/train_test/simple.py @@ -38,9 +38,9 @@ def simple_train(model: BaseModel, optimizer:torch.optim.Optimizer, for _ in range(epochs): # Training loop for data in train_loader: - batch_pro = data['protein'] - batch_mol = data['ligand'] - labels = data['y'].reshape(-1,1) + batch_pro = data['protein'].device + batch_mol = data['ligand'].device + labels = data['y'].reshape(-1,1).device if device is not None: batch_pro = batch_pro.to(device) @@ -87,9 +87,9 @@ def simple_eval(model:BaseModel, data_loader:DataLoader, device:torch.device=Non with torch.no_grad(): for data in data_loader: - batch_pro = data['protein'] - batch_mol = data['ligand'] - labels = data['y'].reshape(-1,1) + batch_pro = data['protein'].device + batch_mol = data['ligand'].device + labels = data['y'].reshape(-1,1).device if device is not None: batch_pro = batch_pro.to(device)