Skip to content

Commit

Permalink
Final modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
zazass8 committed Oct 15, 2024
1 parent 97ae9cf commit 260517e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
9 changes: 7 additions & 2 deletions mlxtend/frequent_patterns/association_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import numpy as np
import pandas as pd

from ..frequent_patterns import fpcommon as fpc

_metrics = [
"antecedent support",
"consequent support",
Expand Down Expand Up @@ -112,6 +114,9 @@ def association_rules(
https://rasbt.github.io/mlxtend/user_guide/frequent_patterns/association_rules/
"""
# check for valid input
fpc.valid_input_check(df_or, null_values)

if not df.shape[0]:
raise ValueError(
"The input DataFrame `df` containing " "the frequent itemsets is empty."
Expand All @@ -125,8 +130,8 @@ def association_rules(
)

def kulczynski_helper(sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_):
conf_AC = sAC / sA
conf_CA = sAC / sC
conf_AC = sAC * (num_itemsets - disAC) / (sA * (num_itemsets - disA) - dis_int)
conf_CA = sAC * (num_itemsets - disAC) / (sC * (num_itemsets - disC) - dis_int_)
kulczynski = (conf_AC + conf_CA) / 2
return kulczynski

Expand Down
26 changes: 21 additions & 5 deletions mlxtend/frequent_patterns/fpcommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,15 @@ def setup_fptree(df, min_support):
return tree, disabled, rank


def generate_itemsets(
generator, df_or, disabled, min_support, num_itemsets, colname_map
):
def generate_itemsets(generator, df, disabled, min_support, num_itemsets, colname_map):
itemsets = []
supports = []
df = df_or.copy().values
for sup, iset in generator:
itemsets.append(frozenset(iset))
# select data of iset from disabled dataset
dec = disabled[:, iset]
# select data of iset from original dataset
_dec = df[:, iset]
_dec = df.values[:, iset]

# case if iset only has one element
if len(iset) == 1:
Expand Down Expand Up @@ -163,6 +160,19 @@ def valid_input_check(df, null_values=False):
"Please use a DataFrame with bool type",
DeprecationWarning,
)

# If null_values is True but no NaNs are found, raise an error
has_nans = pd.isna(df).any().any()
if null_values and not has_nans:
raise ValueError(
"null_values=True is not permitted when there are no NaN values in the DataFrame."
)
# If null_values is False but NaNs are found, raise an error
if not null_values and has_nans:
raise ValueError(
"NaN values are not permitted in the DataFrame when null_values=False."
)

# Pandas is much slower than numpy, so use np.where on Numpy arrays
if hasattr(df, "sparse"):
if df.size == 0:
Expand All @@ -185,6 +195,12 @@ def valid_input_check(df, null_values=False):
"The allowed values for a DataFrame"
" are True, False, 0, 1. Found value %s" % (val)
)

if null_values:
s = (
"The allowed values for a DataFrame"
" are True, False, 0, 1, NaN. Found value %s" % (val)
)
raise ValueError(s)


Expand Down
10 changes: 6 additions & 4 deletions mlxtend/frequent_patterns/tests/test_association_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@ def test_on_df_with_missing_entries():
],
}

df = pd.DataFrame(dict)
df_missing = pd.DataFrame(dict)

numpy_assert_raises(KeyError, association_rules, df, df, len(df))
numpy_assert_raises(KeyError, association_rules, df_missing, df, len(df))


def test_on_df_with_missing_entries_support_only():
Expand Down Expand Up @@ -339,8 +339,10 @@ def test_on_df_with_missing_entries_support_only():
],
}

df = pd.DataFrame(dict)
df_result = association_rules(df, df, len(df), support_only=True, min_threshold=0.1)
df_missing = pd.DataFrame(dict)
df_result = association_rules(
df_missing, df, len(df), support_only=True, min_threshold=0.1
)

assert df_result["support"].shape == (18,)
assert int(np.isnan(df_result["support"].values).any()) != 1
Expand Down

0 comments on commit 260517e

Please sign in to comment.