From 9815bc068ca8230cf0eb0bf6f5b44620bd221469 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sun, 2 Feb 2025 11:49:34 +0100 Subject: [PATCH] [FIX] incorrect variable name --- src/scportrait/pipeline/classification.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/scportrait/pipeline/classification.py b/src/scportrait/pipeline/classification.py index 594c6b0..71f19d5 100644 --- a/src/scportrait/pipeline/classification.py +++ b/src/scportrait/pipeline/classification.py @@ -1201,7 +1201,7 @@ def _load_model(self): model = ConvNextModel.from_pretrained("facebook/convnext-xlarge-224-22k") model.eval() - model.to(self.device) + model.to(self.inference_device) return model @@ -1338,7 +1338,7 @@ def inference(self, dataloader, model_fun, partial=False): batch_size = self.batch_size images, label, class_id = next(data_iter) - images["pixel_values"] = images["pixel_values"][0].to(self.device) + images["pixel_values"] = images["pixel_values"][0].to(self.inference_device) o = model_fun(**images) result = o.pooler_output.cpu().detach() @@ -1364,7 +1364,9 @@ def inference(self, dataloader, model_fun, partial=False): for i in tqdm(range(len(dataloader) - 1)): images, label, class_id = next(data_iter) - images["pixel_values"] = images["pixel_values"][0].to(self.device) + images["pixel_values"] = images["pixel_values"][0].to( + self.inference_device + ) o = model_fun(**images) result = o.pooler_output.cpu().detach()