Skip to content

Commit

Permalink
[FIX] incorrect variable name
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Feb 2, 2025
1 parent 1b630ba commit 9815bc0
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/scportrait/pipeline/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,7 @@ def _load_model(self):

model = ConvNextModel.from_pretrained("facebook/convnext-xlarge-224-22k")
model.eval()
model.to(self.device)
model.to(self.inference_device)

return model

Expand Down Expand Up @@ -1338,7 +1338,7 @@ def inference(self, dataloader, model_fun, partial=False):
batch_size = self.batch_size

images, label, class_id = next(data_iter)
images["pixel_values"] = images["pixel_values"][0].to(self.device)
images["pixel_values"] = images["pixel_values"][0].to(self.inference_device)
o = model_fun(**images)
result = o.pooler_output.cpu().detach()

Expand All @@ -1364,7 +1364,9 @@ def inference(self, dataloader, model_fun, partial=False):

for i in tqdm(range(len(dataloader) - 1)):
images, label, class_id = next(data_iter)
images["pixel_values"] = images["pixel_values"][0].to(self.device)
images["pixel_values"] = images["pixel_values"][0].to(
self.inference_device
)

o = model_fun(**images)
result = o.pooler_output.cpu().detach()
Expand Down

0 comments on commit 9815bc0

Please sign in to comment.