Skip to content

Commit 3ffc95e

Browse files
author
Devin Lu
committed
feat(eda): added target analysis given numerical target column
feat(eda): added basic numerical target analysis squash this feat(eda): added target analysis given numerical target column
1 parent d56861e commit 3ffc95e

File tree

4 files changed

+160
-29
lines changed

4 files changed

+160
-29
lines changed

dataprep/eda/create_diff_report/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
def create_diff_report(
2424
df_list: Union[List[pd.DataFrame], Dict[str, pd.DataFrame]],
25+
target: Optional[str] = None,
2526
config: Optional[Dict[str, Any]] = None,
2627
display: Optional[List[str]] = None,
2728
title: Optional[str] = "DataPrep Report",
@@ -63,7 +64,7 @@ def create_diff_report(
6364
_suppress_warnings()
6465
cfg = Config.from_dict(display, config)
6566

66-
components = format_diff_report(df_list, cfg, mode, progress)
67+
components = format_diff_report(df_list, cfg, mode, progress, target)
6768

6869
dict_stats = defaultdict(list)
6970

dataprep/eda/create_diff_report/diff_formatter.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def format_diff_report(
6868
cfg: Config,
6969
mode: Optional[str],
7070
progress: bool = True,
71+
target: Optional[str] = None
7172
) -> Dict[str, Any]:
7273
"""
7374
Format the data and figures needed by create_diff_report
@@ -110,13 +111,26 @@ def format_diff_report(
110111
if mode == "basic":
111112
# note: we need the type ignore comment for mypy otherwise it complains because
112113
# it doesn't realize that we converted df_list to a list if it's a dictionary
113-
report = format_basic(df_list, cfg) # type: ignore
114+
if target:
115+
validate_target(target, df_list)
116+
report = format_basic(df_list, target, cfg) # type: ignore
114117
else:
115118
raise ValueError(f"Unknown mode: {mode}")
116119
return report
117120

121+
def validate_target(target: str, df_list: List[pd.DataFrame]):
122+
"""
123+
Helper function, verify that target column exists
124+
"""
125+
exists = False
126+
for df in df_list:
127+
if target in df.columns:
128+
exists = True
129+
break
130+
if not exists:
131+
raise ValueError(f'Sorry, {target} is not a valid column')
118132

119-
def format_basic(df_list: List[pd.DataFrame], cfg: Config) -> Dict[str, Any]:
133+
def format_basic(df_list: List[pd.DataFrame], target: Optional[str], cfg: Config) -> Dict[str, Any]:
120134
"""
121135
Format basic version.
122136
@@ -158,7 +172,7 @@ def format_basic(df_list: List[pd.DataFrame], cfg: Config) -> Dict[str, Any]:
158172
# data = dask.compute(data)
159173
delayed_results.append(data)
160174

161-
res_plots = dask.delayed(_format_plots)(cfg=cfg, df_list=df_list)
175+
res_plots = dask.delayed(_format_plots)(cfg=cfg, df_list=df_list, target=target)
162176

163177
dask_results["df_computations"] = delayed_results
164178
dask_results["plots"] = res_plots
@@ -211,7 +225,7 @@ def basic_computations(df: EDAFrame, cfg: Config) -> Dict[str, Any]:
211225

212226

213227
def compute_plot_data(
214-
df_list: List[dd.DataFrame], cfg: Config, dtype: Optional[DTypeDef]
228+
pd_list: List[pd.DataFrame], cfg: Config, dtype: Optional[DTypeDef], target: Optional[str]
215229
) -> Intermediate:
216230
"""
217231
Compute function for create_diff_report's plots
@@ -229,6 +243,10 @@ def compute_plot_data(
229243
"""
230244
# pylint: disable=too-many-branches, too-many-locals
231245

246+
df_list = list(map(to_dask, pd_list))
247+
for i, _ in enumerate(df_list):
248+
df_list[i].columns = df_list[i].columns.astype(str)
249+
232250
dfs = Dfs(df_list)
233251
dfs_cols = dfs.columns.apply("to_list").data
234252

@@ -277,7 +295,7 @@ def compute_plot_data(
277295
elif is_dtype(dtp, DateTime_v1()):
278296
plot_data.append((col, dtp, dask.compute(*datum), orig)) # workaround
279297

280-
return Intermediate(data=plot_data, stats=stats, visual_type="comparison_grid")
298+
return Intermediate(data=plot_data, stats=stats, visual_type="comparison_grid", target=target, df_list=pd_list)
281299

282300

283301
def _compute_variables(df: EDAFrame, cfg: Config) -> Dict[str, Any]:
@@ -407,14 +425,11 @@ def _format_variables(df: EDAFrame, cfg: Config, data: Dict[str, Any]) -> Dict[s
407425

408426

409427
def _format_plots(
410-
df_list: Union[List[pd.DataFrame], Dict[str, pd.DataFrame]], cfg: Config
428+
df_list: Union[List[pd.DataFrame], Dict[str, pd.DataFrame]], cfg: Config, target: Optional[str]
411429
) -> Dict[str, Any]:
412430
"""Formatting of plots section"""
413-
df_list = list(map(to_dask, df_list))
414-
for i, _ in enumerate(df_list):
415-
df_list[i].columns = df_list[i].columns.astype(str)
416431

417-
itmdt = compute_plot_data(df_list=df_list, cfg=cfg, dtype=None)
432+
itmdt = compute_plot_data(pd_list=df_list, cfg=cfg, dtype=None, target=target)
418433
return render_diff(itmdt, cfg=cfg)
419434

420435

dataprep/eda/diff/render.py

Lines changed: 128 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
"""
22
This module implements the visualization for the plot_diff function.
33
""" # pylint: disable=too-many-lines
4+
from turtle import color
45
from typing import Any, Dict, List, Tuple, Optional
5-
6+
from sklearn.preprocessing import MinMaxScaler
67
import math
78
import numpy as np
89
import pandas as pd
10+
import dask.array as da
11+
import matplotlib.pyplot as plt
912
from bokeh.models import (
1013
HoverTool,
1114
Panel,
1215
FactorRange,
1316
)
14-
from bokeh.plotting import Figure, figure
17+
from bokeh.plotting import Figure, figure, show
1518
from bokeh.transform import dodge
1619
from bokeh.layouts import row
20+
from bokeh.models.ranges import Range1d
21+
from bokeh.models import LinearAxis
1722

1823
from ..configs import Config
1924
from ..dtypes import Continuous, DateTime, Nominal, is_dtype
@@ -78,6 +83,8 @@ def bar_viz(
7883
orig: List[str],
7984
df_labels: List[str],
8085
baseline: int,
86+
target: Optional[str] = None,
87+
df_list: Optional[List[pd.DataFrame]] = None
8188
) -> Figure:
8289
"""
8390
Render a bar chart
@@ -94,6 +101,12 @@ def bar_viz(
94101
("Source", "@orig"),
95102
]
96103

104+
col1_min = df[0][col].min()
105+
col2_min = df[1][col].min()
106+
col1_max = df[0][col].max()
107+
col2_max = df[1][col].max()
108+
y_inc = 0.05
109+
97110
if show_yticks:
98111
if len(df[baseline]) > 10:
99112
plot_width = 28 * len(df[baseline])
@@ -106,12 +119,15 @@ def bar_viz(
106119
tools="hover",
107120
x_range=list(df[baseline].index),
108121
y_axis_type=yscale,
122+
y_range=(min(col1_min, col2_min) * (1 - y_inc), max(col1_max, col2_max) * (1 + y_inc))
109123
)
110-
124+
row_names = None
111125
offset = np.linspace(-0.08 * len(df), 0.08 * len(df), len(df)) if len(df) > 1 else [0]
112126
for i, (nrow, data) in enumerate(zip(nrows, df)):
113127
data["pct"] = data[col] / nrow * 100
114128
data.index = [str(val) for val in data.index]
129+
if row_names is None:
130+
row_names = data.index
115131
data["orig"] = orig[i]
116132

117133
fig.vbar(
@@ -126,7 +142,6 @@ def bar_viz(
126142
tweak_figure(fig, "bar", show_yticks)
127143

128144
fig.yaxis.axis_label = "Count"
129-
130145
x_axis_label = ""
131146
if ttl_grps > len(df[baseline]):
132147
x_axis_label += f"Top {len(df[baseline])} of {ttl_grps} {col}"
@@ -142,6 +157,21 @@ def bar_viz(
142157

143158
if show_yticks and yscale == "linear":
144159
_format_axis(fig, 0, df[baseline].max(), "y")
160+
161+
df1, df2 = df_list[0], df_list[1]
162+
if target != col and target and col in df1.columns and col in df2.columns:
163+
col1, col2 = df_list[0][col], df_list[1][col]
164+
row_avgs_1 = []
165+
row_avgs_2 = []
166+
for names in row_names:
167+
row_avgs_1.append(df_list[0][target][col1 == names].mean())
168+
row_avgs_2.append(df_list[1][target][col2 == names].mean())
169+
170+
row_avgs_1 = [0 if math.isnan(x) else x for x in row_avgs_1]
171+
row_avgs_2 = [0 if math.isnan(x) else x for x in row_avgs_2]
172+
fig.extra_y_ranges = {"Averages": Range1d(start=min(row_avgs_1 + row_avgs_2) * (1 - y_inc), end=max(row_avgs_1 + row_avgs_2) * (1 + y_inc))}
173+
fig.multi_line([row_names, row_names], [row_avgs_1, row_avgs_2], color=['navy', 'firebrick'], y_range_name="Averages", line_width=4)
174+
fig.add_layout(LinearAxis(y_range_name="Averages"), 'right')
145175
return fig
146176

147177

@@ -155,28 +185,56 @@ def hist_viz(
155185
show_yticks: bool,
156186
df_labels: List[str],
157187
orig: Optional[List[str]] = None,
188+
target: Optional[str] = None,
189+
df_list: Optional[List[pd.DataFrame]] = None
158190
) -> Figure:
159191
"""
160192
Render a histogram
161193
"""
162194
# pylint: disable=too-many-arguments,too-many-locals
163-
164195
tooltips = [
165196
("Bin", "@intvl"),
166197
("Frequency", "@freq"),
167198
("Percent", "@pct{0.2f}%"),
168199
("Source", "@orig"),
169200
]
201+
df1, df2 = df_list[0], df_list[1]
202+
y_inc = 0.05
203+
tooltips = [
204+
("Bin", "@intvl"),
205+
("Frequency", "@freq"),
206+
("Percent", "@pct{0.2f}%"),
207+
("Source", "@orig"),
208+
]
209+
fig = None
210+
211+
y_start, y_end = None, None
212+
counts_list = []
213+
if target and target != col and col in df1.columns and col in df2.columns:
214+
for hst in hist:
215+
counts, bins = hst
216+
counts_list.append(counts)
217+
218+
counts_min_1 = min(counts_list[0])
219+
counts_min_2 = min(counts_list[1])
220+
221+
counts_max_1 = max(counts_list[0])
222+
counts_max_2 = max(counts_list[1])
223+
224+
y_start, y_end = min(counts_min_1, counts_min_2), max(counts_max_1, counts_max_2)
225+
226+
170227
fig = Figure(
171228
plot_height=plot_height,
172229
plot_width=plot_width,
173230
title=col,
174231
toolbar_location=None,
175-
y_axis_type=yscale,
232+
y_axis_type=yscale
176233
)
177-
234+
bins_list = []
178235
for i, hst in enumerate(hist):
179236
counts, bins = hst
237+
bins_list.append(bins)
180238
if sum(counts) == 0:
181239
fig.rect(x=0, y=0, width=0, height=0)
182240
continue
@@ -192,16 +250,34 @@ def hist_viz(
192250
}
193251
)
194252
bottom = 0 if yscale == "linear" or df.empty else counts.min() / 2
195-
fig.quad(
196-
source=df,
197-
left="left",
198-
right="right",
199-
bottom=bottom,
200-
alpha=0.5,
201-
top="freq",
202-
fill_color=CATEGORY10[i],
203-
line_color=CATEGORY10[i],
204-
)
253+
if y_start is not None and y_end is not None:
254+
# fig.y_range = (y_start * (1 - y_inc), y_end * (1 + y_inc))
255+
fig.extra_y_ranges = {"Counts": Range1d(start=y_start * (1 - y_inc), end=y_end * (1 + y_inc))}
256+
fig.quad(
257+
source=df,
258+
left="left",
259+
right="right",
260+
bottom=bottom,
261+
alpha=0.5,
262+
top="freq",
263+
fill_color=CATEGORY10[i],
264+
line_color=CATEGORY10[i],
265+
y_range_name="Counts"
266+
)
267+
else:
268+
fig.quad(
269+
source=df,
270+
left="left",
271+
right="right",
272+
bottom=bottom,
273+
alpha=0.5,
274+
top="freq",
275+
fill_color=CATEGORY10[i],
276+
line_color=CATEGORY10[i]
277+
)
278+
# if col == 'LotFrontage':
279+
# breakpoint()
280+
205281
hover = HoverTool(tooltips=tooltips, attachment="vertical", mode="vline")
206282
fig.add_tools(hover)
207283

@@ -224,6 +300,34 @@ def hist_viz(
224300
fig.xaxis.axis_label = x_axis_label
225301
fig.xaxis.axis_label_standoff = 0
226302

303+
if target and target != col and col in df1.columns and col in df2.columns:
304+
col1, col2 = df1[col], df2[col]
305+
source1, source2 = col1, col2
306+
col1 = col1[~np.isnan(col1)]
307+
col2 = col2[~np.isnan(col2)]
308+
num_bins1 = len(bins_list[0]) - 1
309+
num_bins2 = len(bins_list[1]) - 1
310+
bins_1, bins_2 = bins_list[0], bins_list[1]
311+
312+
df1_source_bins_series = pd.cut(source1, bins=bins_1, labels=False)
313+
df1_bin_averages = [None] * num_bins1
314+
315+
df2_source_bins_series = pd.cut(source2, bins=bins_2, labels=False)
316+
df2_bin_averages = [None] * num_bins2
317+
318+
for b in range(num_bins1):
319+
df1_bin_averages[b] = df1[target][df1_source_bins_series == b].mean()
320+
for b in range(num_bins2):
321+
df2_bin_averages[b] = df2[target][df2_source_bins_series == b].mean()
322+
323+
df1_bin_averages = [0 if math.isnan(x) else x for x in df1_bin_averages]
324+
df2_bin_averages = [0 if math.isnan(x) else x for x in df2_bin_averages]
325+
max_range = max(df1_bin_averages + df2_bin_averages)
326+
min_range = min(df1_bin_averages + df2_bin_averages)
327+
328+
fig.extra_y_ranges['Averages'] = Range1d(start=min_range * (1 - y_inc), end=max_range * (1 + y_inc))
329+
fig.multi_line([bins_1, bins_2], [df1_bin_averages, df2_bin_averages], color=['navy', 'firebrick'], y_range_name="Averages", line_width=4)
330+
fig.add_layout(LinearAxis(y_range_name="Averages", axis_label='Bin Averages'), 'right')
227331
return fig
228332

229333

@@ -610,6 +714,9 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
610714
nrows = itmdt["stats"]["nrows"]
611715
titles: List[str] = []
612716

717+
df_list = itmdt.df_list
718+
target = itmdt.target
719+
613720
for col, dtp, data, orig in itmdt["data"]:
614721
fig = None
615722
if is_dtype(dtp, Nominal()):
@@ -626,6 +733,8 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
626733
orig,
627734
df_labels,
628735
baseline if len(df) > 1 else 0,
736+
target,
737+
df_list
629738
)
630739
elif is_dtype(dtp, Continuous()):
631740
if cfg.diff.density:
@@ -643,6 +752,8 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
643752
False,
644753
df_labels,
645754
orig,
755+
target,
756+
df_list
646757
)
647758
elif is_dtype(dtp, DateTime()):
648759
df, timeunit = data
@@ -760,7 +871,6 @@ def render_diff(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
760871
cfg
761872
Config instance
762873
"""
763-
764874
if itmdt.visual_type == "comparison_grid":
765875
visual_elem = render_comparison_grid(itmdt, cfg)
766876
if itmdt.visual_type == "comparison_continuous":

dataprep/eda/intermediate.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
3030
visual_type = kwargs.pop("visual_type")
3131
super().__init__(**kwargs)
3232
self.visual_type = visual_type
33+
if 'target' in kwargs:
34+
self.target = kwargs.pop('target')
35+
36+
if 'df_list' in kwargs:
37+
self.df_list = kwargs.pop('df_list')
3338
else:
3439
raise ValueError("Unsupported initialization")
3540

0 commit comments

Comments
 (0)