Skip to content

Commit

Permalink
cUDA SETTINGS
Browse files Browse the repository at this point in the history
  • Loading branch information
Owyii committed Apr 13, 2024
1 parent 79fb198 commit 82fdc9e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions mACHINE-LEARNINGS/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def execute(train_set_size, batch_size, lr, epochs, is_verbose, weight_decay):
device = torch.device('cuda')
else:
device = torch.device('cpu')

"""
DATA LOADING
- Load all data: train, test, validation
Expand Down Expand Up @@ -64,15 +64,15 @@ def execute(train_set_size, batch_size, lr, epochs, is_verbose, weight_decay):
optimizer.zero_grad()
output = model(image.to(device))

label = label.unsqueeze(1)
label = label.unsqueeze(1).to(device)

loss = loss_function(output, label.float())
loss.backward()
optimizer.step()

train_loss += loss.item()

predictions = torch.where(output > .5, 1, 0)
predictions = torch.where(output > .5, 1, 0).to(device)
acc += (label == predictions).sum()/len(label)

if IS_VERBOSE:
Expand All @@ -93,12 +93,12 @@ def execute(train_set_size, batch_size, lr, epochs, is_verbose, weight_decay):
for batch_num, (image, label) in enumerate(dataloaders["test"]):
output = model(image.to(device))

label = label.unsqueeze(1)
label = label.unsqueeze(1).to(device)
loss = loss_function(output, label.float())

test_loss += loss.item()

predictions = torch.where(output > .5, 1, 0)
predictions = torch.where(output > .5, 1, 0).to(device)
acc += (label == predictions).sum()/len(label)

if IS_VERBOSE:
Expand Down

0 comments on commit 82fdc9e

Please sign in to comment.