forked from saper0/revisiting_robustness
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
33 lines (25 loc) · 1.11 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from typing import Optional, Union
import numpy as np
import torch.nn as nn
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked
# Hook torchtyping into typeguard so that the `TensorType[...]` annotations
# below are checked at call time. This has to run before any function is
# decorated with `@typechecked`.
patch_typeguard()
@typechecked
def accuracy(logits: TensorType["n", "c"], labels: TensorType["n"],
             split_idx: Optional[Union[np.ndarray, int]] = None) -> float:
    """Compute the fraction of nodes whose predicted class equals its label.

    The predicted class of a node is the index of its largest logit. When
    `split_idx` is provided, only the selected nodes contribute to the
    result; otherwise all nodes are used.

    Args:
        logits (TensorType["n", "c"]): per-node class scores; the prediction
            is taken as `.argmax(1)`.
        labels (TensorType["n"]): target label of each node.
        split_idx (np.ndarray|int, optional): index or array of indices
            selecting the nodes to evaluate. Defaults to None.

    Returns:
        float: Accuracy of the predictions w.r.t. the given labels.
    """
    predictions = logits.argmax(1)
    targets = labels

    # Restrict both sides to the requested split before comparing.
    if split_idx is not None:
        predictions = predictions[split_idx]
        targets = labels[split_idx]

    hits = predictions == targets
    return hits.float().mean().item()