Skip to content

Commit

Permalink
Merge pull request #146 from realratchet/master
Browse files Browse the repository at this point in the history
Fix imputations
realratchet authored Mar 8, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
2 parents 32fc7b3 + 09f4353 commit 3bdf7ff
Showing 2 changed files with 17 additions and 6 deletions.
21 changes: 16 additions & 5 deletions tablite/imputation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math

from tablite.base import BaseTable, Column
from tablite.utils import sub_cls_check
from tablite.utils import sub_cls_check, summary_statistics
from tablite.config import Config
from tablite import sort_utils

@@ -90,11 +90,11 @@ def imputation(T, targets, missing=None, method="carry forward", sources=None, t
methods = ["nearest neighbour", "mean", "mode", "carry forward"]

if method == "carry forward":
return carry_forward(T, targets, missing, tqdm=_tqdm, pbar=None)
return carry_forward(T, targets, missing, tqdm=tqdm, pbar=pbar)
elif method in {"mean", "mode"}:
return stats_method(T, targets, missing, method, tqdm=_tqdm, pbar=None)
return stats_method(T, targets, missing, method, tqdm=tqdm, pbar=pbar)
elif method == "nearest neighbour":
return nearest_neighbour(T, sources, missing, targets, tqdm=_tqdm, pbar=None)
return nearest_neighbour(T, sources, missing, targets, tqdm=tqdm, pbar=pbar)
else:
raise ValueError(f"method {method} not recognised amonst known methods: {list(methods)})")

@@ -136,7 +136,18 @@ def stats_method(T, targets, missing, method, tqdm=_tqdm, pbar=None):
if name in targets:
col = T.columns[name]
assert isinstance(col, Column)
stats = col.statistics()

hist_values, hist_counts = col.histogram()

for m in missing:
try:
idx = hist_values.index(m)
hist_counts[idx] = 0
except ValueError:
pass

stats = summary_statistics(hist_values, hist_counts)

new_value = stats[method]
col.replace(mapping={m: new_value for m in missing})
new[name] = col
2 changes: 1 addition & 1 deletion tablite/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
major, minor, patch = 2023, 10, 7
major, minor, patch = 2023, 10, 8
__version_info__ = (major, minor, patch)
__version__ = ".".join(str(i) for i in __version_info__)

0 comments on commit 3bdf7ff

Please sign in to comment.