-
Notifications
You must be signed in to change notification settings - Fork 0
/
random_oversampling_neigbhorhood.py
33 lines (27 loc) · 1.13 KB
/
random_oversampling_neigbhorhood.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import sys
sys.path.append('LORE')
import numpy as np
from LORE.neighbor_generator import *
class RandomOversamplingNeighborhood:
    """Sample a LORE random-oversampling neighbourhood around an instance.

    Thin wrapper over the LORE ``neighbor_generator`` helpers
    ``dataframe2explain`` and ``random_oversampling``: :meth:`fit` builds the
    seed dataframe once, and :meth:`neighborhoodSampling` then generates and
    labels a neighbourhood for any instance with the black-box ``model``.
    """

    def __init__(self, X, y, model, dataset):
        """Store the data, black box and LORE dataset description.

        Args:
            X: data handed to ``dataframe2explain`` in :meth:`fit`.
            y: labels for ``X``; stored on the instance but not used here.
            model: classifier exposing ``predict`` and ``predict_proba``.
            dataset: LORE dataset description (dict-like); must provide
                ``'feature_names'``, which selects the sampled columns.
        """
        self.X = X
        self.y = y
        self.model = model
        self.dataset = dataset

    def fit(self):
        """Build the LORE dataframe used as the seed for sampling.

        Returns:
            self, so the call can be chained (the result was previously
            ``None`` and is safe to ignore).
        """
        # The third argument (0) is presumably LORE's "index of the record to
        # explain"; only the dataframe is kept, the record itself is unused.
        dfZ, _ = dataframe2explain(self.X, self.dataset, 0, self.model)
        self.dfZ = dfZ
        return self

    def neighborhoodSampling(self, x, N_samples):
        """Generate a random-oversampling neighbourhood of ``x`` and label it.

        Args:
            x: the instance to build the neighbourhood around (array-like,
                flattened to a single row).
            N_samples: number of samples requested from
                ``random_oversampling``.

        Returns:
            ``(neighborhood_data, neighborhood_labels, neighborhood_proba)``:
            the instance stacked on top of the sampled rows, the model's
            predicted labels for every row, and, for every row, the model's
            probability of the label predicted for ``x`` (row 0).

        Raises:
            AttributeError: if :meth:`fit` has not been called yet.
        """
        if not hasattr(self, 'dfZ'):
            raise AttributeError(
                'RandomOversamplingNeighborhood.fit() must be called '
                'before neighborhoodSampling()')

        # Generate the random oversampling neighbourhood; only the dataframe
        # form is needed, restricted to the model's feature columns.
        Z_df, _ = random_oversampling(self.dfZ, x, self.model, self.dataset, N_samples)
        sampled_data = Z_df[self.dataset['feature_names']].values

        # Row 0 is the instance itself, followed by the generated samples.
        x_row = np.asarray(x).reshape(1, -1)
        neighborhood_data = np.r_[x_row, sampled_data]

        # Label the neighbourhood and keep, per row, the probability of the
        # class predicted for x.
        neighborhood_labels = self.model.predict(neighborhood_data)
        neighborhood_proba = self.model.predict_proba(neighborhood_data)
        column = self._label_column(getattr(self.model, 'classes_', None),
                                    neighborhood_labels[0])
        neighborhood_proba = neighborhood_proba[:, column]
        return neighborhood_data, neighborhood_labels, neighborhood_proba

    @staticmethod
    def _label_column(classes, label):
        """Return the ``predict_proba`` column index holding ``label``.

        scikit-learn orders ``predict_proba`` columns by ``classes_``, so a
        predicted label equals its column index only when ``classes_`` is
        exactly ``[0, 1, ..., k-1]``. When ``classes`` is given and contains
        ``label``, its position is returned. Otherwise the label is returned
        unchanged and used directly as the index (the original behaviour),
        e.g. for models without a ``classes_`` attribute.
        """
        if classes is not None:
            matches = np.flatnonzero(np.asarray(classes) == label)
            if matches.size:
                return int(matches[0])
        return label