diff --git a/src/train_test/simple.py b/src/train_test/simple.py index 21631ccd..7a8ebb1a 100644 --- a/src/train_test/simple.py +++ b/src/train_test/simple.py @@ -38,9 +38,9 @@ def simple_train(model: BaseModel, optimizer:torch.optim.Optimizer, for _ in range(epochs): # Training loop for data in train_loader: - batch_pro = data['protein'] - batch_mol = data['ligand'] - labels = data['y'].reshape(-1,1) + batch_pro = data['protein'].device + batch_mol = data['ligand'].device + labels = data['y'].reshape(-1,1).device if device is not None: batch_pro = batch_pro.to(device) @@ -87,9 +87,9 @@ def simple_eval(model:BaseModel, data_loader:DataLoader, device:torch.device=Non with torch.no_grad(): for data in data_loader: - batch_pro = data['protein'] - batch_mol = data['ligand'] - labels = data['y'].reshape(-1,1) + batch_pro = data['protein'].device + batch_mol = data['ligand'].device + labels = data['y'].reshape(-1,1).device if device is not None: batch_pro = batch_pro.to(device)