-
Notifications
You must be signed in to change notification settings - Fork 0
/
sampling_methods.py
73 lines (55 loc) · 2.74 KB
/
sampling_methods.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from abc import ABCMeta, abstractmethod
from statistics import mode
import torch
import numpy as np
def calc_entropy(probs, num_classes=10):
    """Compute the Shannon entropy of each row of class probabilities,
    normalized to [0, 1] by the maximum possible entropy log(num_classes).

    Args:
        probs: Tensor of shape (batch, num_classes) with per-class
            probabilities (rows assumed to sum to ~1 — TODO confirm caller).
        num_classes: Number of classes; fixes the normalization constant.

    Returns:
        1-D tensor of length batch with normalized entropies.
    """
    eps = 0.000001  # To avoid log(0) problems
    log_prob = torch.log(probs + eps)
    entropy = -(probs * log_prob).sum(dim=1)
    # Cast explicitly to float: torch.log on an integer tensor relies on
    # implicit type promotion, which older torch releases do not support.
    max_entropy = torch.log(torch.tensor(float(num_classes)))
    return entropy / max_entropy
def select_top_k(indices, num_samples):
    """Keep the first `num_samples` entries of an already-ranked index sequence."""
    head = indices[:num_samples]
    return head
def select_uniform(indices, num_samples, entropies=None):
    """Pick up to `num_samples` entries evenly strided across `indices`.

    Args:
        indices: 1-D array/tensor of candidate indices (assumed ranked).
        num_samples: Number of entries to keep.
        entropies: Unused; kept for signature compatibility with the other
            selection functions in `selection_func_mapping`.

    Returns:
        The evenly-strided subset of `indices`. When `indices` has fewer
        than `num_samples` entries, all of them are returned.
    """
    total_samples = indices.shape[0]
    # Guard the stride at >= 1: the original `int(total/num)` stride is 0
    # when more samples are requested than exist, and np.arange rejects
    # a zero step.
    step = max(total_samples // num_samples, 1)
    selection_idx = np.arange(0, total_samples, step)[:num_samples]
    return indices[selection_idx]
# Maps a selection-criterion name (the `select_crit` constructor argument)
# to the function that performs the selection.
selection_func_mapping = {"top_k": select_top_k, "uniform": select_uniform}
class UncertaintySampling(metaclass=ABCMeta):
    """Abstract base for uncertainty-driven sample prioritization.

    Subclasses implement `get_uncertainty` to score samples; the selection
    criterion chosen at construction ('top_k' or 'uniform') decides which
    of the ranked samples are kept.
    """

    def __init__(self, dataset_indices, select_crit='top_k', discount_factor=0, num_classes=10, use_target=False, target_entropies=None):
        print(f"Prioritizer Initialization")
        print(f"Selection Criterion: {select_crit}")
        self.discount_factor = discount_factor
        self.dataset_indices = dataset_indices  # To map from original dataset to training subset
        self.dataset_len = len(self.dataset_indices)
        self.target_entropies = target_entropies
        self.num_classes = num_classes
        self.use_target = use_target
        # KeyError here on an unknown criterion name is intentional fail-fast.
        self.selection_function = selection_func_mapping[select_crit]

    def get_indices(self, indices_ordered, num_samples, epoch):
        """Select `num_samples` entries from already-ranked indices.

        `epoch` is unused here but kept as part of the shared interface.
        """
        return self.selection_function(indices_ordered, num_samples)

    @abstractmethod
    def get_uncertainty(self, probs):
        """Return a per-sample uncertainty score for `probs`.

        Signature aligned with the concrete subclasses, which all take a
        probability tensor (the original abstract stub took no argument).
        """
class LabelEntropy(UncertaintySampling):
    """Prioritizes samples by the entropy of their predicted label distribution."""

    def get_uncertainty(self, probs):
        """Normalized label entropy of `probs` as the uncertainty score."""
        return calc_entropy(probs, self.num_classes)

    def query(self, samples_uncertainty, num_samples, epoch):
        """Rank the training subset by entropy (descending) and select.

        When `use_target` is set, ranking uses the stored target-model
        entropies for this `epoch` instead of `samples_uncertainty`.
        Ordered indices of both current and target models stored.
        """
        if self.use_target:
            scores = self.target_entropies[:, epoch]
        else:
            scores = samples_uncertainty
        ranked = torch.argsort(scores[self.dataset_indices], descending=True)
        return self.selection_function(ranked, num_samples)
class RandomSampling(UncertaintySampling):
    """Baseline prioritizer: returns a uniformly random subset on each query."""

    def query(self, samples_uncertainty, num_samples, epoch):
        """Ignore the uncertainty scores and pick `num_samples` positions at random."""
        shuffled = torch.randperm(self.dataset_len)
        return shuffled[:num_samples]

    def get_uncertainty(self, probs):
        """Normalized label entropy of `probs` (not consulted by `query`)."""
        return calc_entropy(probs, self.num_classes)
def prioritizer_factory(al_method):
    """Look up the prioritizer class for an active-learning method name.

    Raises:
        KeyError: if `al_method` is neither 'entropy' nor 'random'.
    """
    return {'entropy': LabelEntropy, 'random': RandomSampling}[al_method]