-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathGenerate predictions
41 lines (34 loc) · 1.61 KB
/
Generate predictions
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def generate_predictions(self,
                         loader: DataLoader,
                         device: str,
                         save_path: Optional[str]) -> BrainyPredictions:
    """
    Run the model over ``loader`` and collect its predictions.

    For every batch, the brain data and images are encoded with
    ``self.encode_brain`` / ``self.encode_image``. Top-1 "brainy" predictions
    are derived from the two embeddings via ``get_brainy_predictions``. They
    are recorded together with the outputs of ``self.image_model_frozen`` and
    ``self.brain_model_frozen``, the image indices and the participant ids.
    Everything runs under ``torch.no_grad()``.

    Note: this calls ``self.eval()`` and does not switch the module back to
    its previous mode afterwards.

    :param loader: Data loader yielding
        ``(brain_data, images, labels, img_idx, participants)`` batches.
    :param device: Device that ``brain_data``, ``images`` and ``labels`` are
        moved to before encoding.
    :param save_path: Path of the file to save the predictions to. If
        ``None``, the predictions are only returned, not saved.
    :return: BrainyPredictions object holding the accumulated predictions.
    """
    self.eval()
    predictions: BrainyPredictions = BrainyPredictions()
    with torch.no_grad():
        for brain_data, images, labels, img_idx, participants in loader:
            brain_data = brain_data.to(device)
            images = images.to(device)
            labels = labels.to(device)

            brain_embeddings = self.encode_brain(brain_data)
            image_embeddings = self.encode_image(images)
            brainy_predictions, true_labels = get_brainy_predictions(
                brain_embeddings=brain_embeddings,
                image_embeddings=image_embeddings,
                labels=labels,
                top_k=1)
            predictions.fill(
                target=true_labels,
                preds_brainy=brainy_predictions,
                preds_image=self.image_model_frozen(images).tolist(),
                preds_brain=self.brain_model_frozen(brain_data).tolist(),
                image_idx=img_idx.tolist(),
                participant=participants.tolist(),
            )
    # save_path is Optional: treat None as "do not persist" rather than
    # passing None through to BrainyPredictions.save.
    if save_path is not None:
        predictions.save(file_path=save_path)
    return predictions