Feature: Filling gaps (#78)
Roman223 authored Aug 15, 2023
1 parent fc3362c commit 43fba0f
Showing 3 changed files with 53 additions and 12 deletions.
56 changes: 49 additions & 7 deletions bamt/networks/base.py
@@ -486,13 +486,13 @@ def __init__(self, type: str):
weights[tuple_key] = input_dict["weights"][str(tuple_key)]
self.weights = weights

def fit_parameters(self, data: pd.DataFrame, dropna: bool = True, n_jobs: int = -1):
def fit_parameters(self, data: pd.DataFrame, n_jobs: int = -1):
"""
Base function for parameter learning
"""
if dropna:
data = data.dropna()
data.reset_index(inplace=True, drop=True)
if data.isnull().values.any():
logger_network.error("Dataframe contains NaNs.")
return

if not os.path.isdir(STORAGE):
os.makedirs(STORAGE)
@@ -539,13 +539,15 @@ def sample(
evidence: Optional[Dict[str, Union[str, int, float]]] = None,
as_df: bool = True,
predict: bool = False,
parall_count: int = 1,
parall_count: int = -1,
filter_neg: bool = True,
) -> Union[None, pd.DataFrame, List[Dict[str, Union[str, int, float]]]]:
"""
Sampling from Bayesian Network
n: int, number of samples
evidence: observed values for nodes, supplied by the user
parall_count: number of threads. Defaults to -1.
filter_neg: whether to filter out negative values.
"""
from joblib import Parallel, delayed

@@ -632,8 +634,10 @@ def wrapper():
positive_columns = [
c for c in cont_nodes if self.descriptor["signs"][c] == "pos"
]
seq_df = seq_df[(seq_df[positive_columns] >= 0).all(axis=1)]
seq_df.reset_index(inplace=True, drop=True)
if filter_neg:
seq_df = seq_df[(seq_df[positive_columns] >= 0).all(axis=1)]
seq_df.reset_index(inplace=True, drop=True)

seq = seq_df.to_dict("records")
sample_output = pd.DataFrame.from_dict(seq, orient="columns")

@@ -816,6 +820,44 @@ def find_family(
plot_(
plot_to, [self[name] for name in structure["nodes"]], structure["edges"]
)

def fill_gaps(self, df: pd.DataFrame, **kwargs):
"""
Fill NaNs with values sampled from the fitted network.
:param df: dataframe containing NaNs
:param kwargs: the same parameters as bn.predict
:return: the filled DataFrame and a list of rows that could not be filled (predict can occasionally return np.nan)
"""
if not self.distributions:
logger_network.error("To call this method you must train parameters.")

# append a placeholder "mimic" row of NaNs so df.iloc[[i, -1]] always returns a DataFrame
nan_row = [np.nan for _ in range(df.shape[1])]
df.loc["mimic"] = nan_row

def fill_row(df, i):
row = df.iloc[[i, -1], :].drop(["mimic"], axis=0)

evidences = row.dropna(axis=1)

return row.index[0], self.predict(evidences, progress_bar=False, **kwargs)

failed = []
for index in range(df.shape[0] - 1):
if df.iloc[index].isna().any():
true_pos, result = fill_row(df, index)
if any(pd.isna(v[0]) for v in result.values()):
failed.append(true_pos)
continue
else:
for column, value in result.items():
df.loc[true_pos, column] = value[0]
else:
continue
df.drop(failed, inplace=True)
return df.drop(["mimic"]), failed

def get_dist(self, node_name: str, pvals: Optional[dict] = None):
"""
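Taken together, the base.py changes make NaN handling explicit and add gap filling: fit_parameters now rejects frames that contain NaNs, sample can optionally keep negative draws for "pos" continuous nodes, and the new fill_gaps imputes missing cells from the fitted network. A minimal usage sketch, assuming bn is an already-structured bamt network (e.g. a HybridBN built as in tests/test_networks.py) and the CSV path is a placeholder:

import pandas as pd

# Placeholder input containing missing values.
data = pd.read_csv("data_with_gaps.csv")

# fit_parameters refuses frames with NaNs, so train on a cleaned copy.
clean = data.dropna().reset_index(drop=True)
bn.fit_parameters(clean)

# Negative draws for "pos" continuous nodes are filtered by default;
# pass filter_neg=False to keep them.
samples = bn.sample(1000, filter_neg=True)

# Impute the remaining gaps in the original frame. failed lists the rows
# where predict itself returned np.nan; those rows are dropped from the result.
filled, failed = bn.fill_gaps(data)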
6 changes: 1 addition & 5 deletions bamt/preprocessors.py
@@ -106,15 +106,11 @@ def scan(self, data: DataFrame):
if not columns_disc:
logger_preprocessor.info("No one column is discrete")

def apply(self, data: DataFrame, dropna: bool = True) -> Tuple[DataFrame, Dict]:
def apply(self, data: DataFrame) -> Tuple[DataFrame, Dict]:
"""
Apply pipeline
data: data to apply on
dropna: drop NaNs with pandas dropna
"""
if dropna:
data = data.dropna()
data.reset_index(inplace=True, drop=True)
df = data.copy()
self.nodes_types = self.get_nodes_types(data)
if list(self.nodes_types.keys()) != data.columns.to_list():
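With the dropna flag removed from apply, cleaning the frame becomes the caller's job; a minimal sketch of the new contract, assuming p is a Preprocessor pipeline built as in the tests and the CSV path is a placeholder:

import pandas as pd

data = pd.read_csv("data.csv")  # placeholder path

# NaN handling is now the caller's responsibility.
data.dropna(inplace=True)
data.reset_index(inplace=True, drop=True)

# apply() no longer takes a dropna argument.
discretized_data, est = p.apply(data)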
3 changes: 3 additions & 0 deletions tests/test_networks.py
@@ -59,6 +59,9 @@ def prepare_bn_and_data(self):
[("encoder", encoder), ("discretizer", discretizer)]
)

hack_data.dropna(inplace=True)
hack_data.reset_index(inplace=True, drop=True)

discretized_data, est = p.apply(hack_data)

self.bn = HybridBN(has_logit=True)
