Skip to content

Commit

Permalink
fix: send to default device pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
jyaacoub committed Dec 14, 2023
1 parent 2dc9a6d commit e549390
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/train_test/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def simple_train(model: BaseModel, optimizer:torch.optim.Optimizer,
for _ in range(epochs):
# Training loop
for data in train_loader:
batch_pro = data['protein']
batch_mol = data['ligand']
labels = data['y'].reshape(-1,1)
batch_pro = data['protein'].device
batch_mol = data['ligand'].device
labels = data['y'].reshape(-1,1).device

if device is not None:
batch_pro = batch_pro.to(device)
Expand Down Expand Up @@ -87,9 +87,9 @@ def simple_eval(model:BaseModel, data_loader:DataLoader, device:torch.device=Non

with torch.no_grad():
for data in data_loader:
batch_pro = data['protein']
batch_mol = data['ligand']
labels = data['y'].reshape(-1,1)
batch_pro = data['protein'].device
batch_mol = data['ligand'].device
labels = data['y'].reshape(-1,1).device

if device is not None:
batch_pro = batch_pro.to(device)
Expand Down

0 comments on commit e549390

Please sign in to comment.