Skip to content

Commit

Permalink
fix single gpu training
Browse files Browse the repository at this point in the history
  • Loading branch information
djaniak committed Nov 30, 2023
1 parent 9da5d26 commit 90b406e
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions embeddings/model/lightning_module/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,11 @@ def predict(

if return_predictions:
assert predictions is not None
logits, preds = zip(*predictions)
logits, preds, labels = zip(*predictions)
probabilities = softmax(torch.cat(logits), dim=1)
preds = torch.cat(preds)
labels = torch.cat([x["labels"] for x in dataloader])
labels = torch.cat(labels)
# labels = torch.cat([x["labels"] for x in dataloader])
else:
files = sorted(os.listdir(predpath))
all_preds = []
Expand Down

0 comments on commit 90b406e

Please sign in to comment.