-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
69c67df
commit 0d3d2d5
Showing
1 changed file
with
81 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,93 @@ | ||
from bamt.utils.MathUtils import get_brave_matrix, get_proximity_matrix | ||
import math | ||
Check warning on line 1 in bamt/networks/big_brave_bn.py GitHub Actions / Qodana for PythonUnsatisfied package requirements
|
||
|
||
import numpy as np | ||
import pandas as pd | ||
from sklearn.metrics import mutual_info_score | ||
from sklearn.preprocessing import OrdinalEncoder | ||
|
||
|
||
class BigBraveBN: | ||
def __init__(self, n_nearest=5, threshold=0.3, proximity_metric="MI"): | ||
self.n_nearest = n_nearest | ||
self.threshold = threshold | ||
self.proximity_metric = proximity_metric | ||
def __init__(self): | ||
self.possible_edges = [] | ||
|
||
def set_possible_edges_by_brave(self, df): | ||
"""Returns list of possible edges for structure learning | ||
def set_possible_edges_by_brave( | ||
self, | ||
df: pd.DataFrame, | ||
n_nearest: int = 5, | ||
threshold: float = 0.3, | ||
proximity_metric: str = "MI", | ||
) -> None: | ||
"""Sets possible edges for structure learning based on Brave coefficients.""" | ||
proximity_matrix = self._get_proximity_matrix(df, proximity_metric) | ||
brave_matrix = self._get_brave_matrix(df.columns, proximity_matrix, n_nearest) | ||
|
||
threshold_value = brave_matrix.max(numeric_only=True).max() * threshold | ||
filtered_brave_matrix = brave_matrix[brave_matrix > threshold_value].stack() | ||
self.possible_edges = filtered_brave_matrix.index.tolist() | ||
|
||
@staticmethod | ||
def _get_n_nearest( | ||
data: pd.DataFrame, columns: list, corr: bool = False, number_close: int = 5 | ||
) -> list: | ||
"""Returns N nearest neighbors for every column of dataframe.""" | ||
groups = [] | ||
for c in columns: | ||
close_ind = data[c].sort_values(ascending=not corr).index.tolist() | ||
groups.append(close_ind[: number_close + 1]) | ||
return groups | ||
|
||
Args: | ||
df (DataFrame): data | ||
@staticmethod | ||
def _get_proximity_matrix(df: pd.DataFrame, proximity_metric: str) -> pd.DataFrame: | ||
"""Returns matrix of proximity for the dataframe.""" | ||
encoder = OrdinalEncoder() | ||
df_coded = df.copy() | ||
columns_to_encode = list(df_coded.select_dtypes(include=["category", "object"])) | ||
df_coded[columns_to_encode] = encoder.fit_transform(df_coded[columns_to_encode]) | ||
|
||
Returns: | ||
Possible edges: list of possible edges | ||
""" | ||
if proximity_metric == "MI": | ||
df_distance = pd.DataFrame( | ||
np.zeros((len(df.columns), len(df.columns))), | ||
columns=df.columns, | ||
index=df.columns, | ||
) | ||
for c1 in df.columns: | ||
for c2 in df.columns: | ||
dist = mutual_info_score(df_coded[c1].values, df_coded[c2].values) | ||
df_distance.loc[c1, c2] = dist | ||
return df_distance | ||
|
||
proximity_matrix = get_proximity_matrix( | ||
df, proximity_metric=self.proximity_metric | ||
elif proximity_metric == "pearson": | ||
return df_coded.corr(method="pearson") | ||
|
||
def _get_brave_matrix( | ||
self, df_columns: pd.Index, proximity_matrix: pd.DataFrame, n_nearest: int = 5 | ||
) -> pd.DataFrame: | ||
"""Returns matrix of Brave coefficients for the DataFrame.""" | ||
brave_matrix = pd.DataFrame( | ||
np.zeros((len(df_columns), len(df_columns))), | ||
columns=df_columns, | ||
index=df_columns, | ||
) | ||
groups = self._get_n_nearest( | ||
proximity_matrix, df_columns.tolist(), corr=True, number_close=n_nearest | ||
) | ||
brave_matrix = get_brave_matrix(df.columns, proximity_matrix, self.n_nearest) | ||
|
||
possible_edges_list = [] | ||
for c1 in df_columns: | ||
for c2 in df_columns: | ||
a = b = c = d = 0.0 | ||
if c1 != c2: | ||
for g in groups: | ||
a += (c1 in g) & (c2 in g) | ||
b += (c1 in g) & (c2 not in g) | ||
c += (c1 not in g) & (c2 in g) | ||
d += (c1 not in g) & (c2 not in g) | ||
|
||
for c1 in df.columns: | ||
for c2 in df.columns: | ||
if ( | ||
brave_matrix.loc[c1, c2] | ||
> brave_matrix.max(numeric_only=True).max() * self.threshold | ||
): | ||
possible_edges_list.append((c1, c2)) | ||
divisor = (math.sqrt((a + c) * (b + d))) * ( | ||
math.sqrt((a + b) * (c + d)) | ||
) | ||
br = (a * len(groups) + (a + c) * (a + b)) / ( | ||
divisor if divisor != 0 else 0.0000000001 | ||
) | ||
brave_matrix.loc[c1, c2] = br | ||
|
||
self.possible_edges = possible_edges_list | ||
return brave_matrix |